multiple changes
- altered db schema to delete subscriptions along with topics - extracted logic for creating and deleting topics into separate functions - added functions to list topics and subscriptions - delete topic and corresponding subscription if same topic is added with different token - added more error messages and reactions - added handling for authentication errors (subscription and topic are removed) - inform all subscribed rooms when subscription is cancelled due to an action in another room - added documentation for methods - code style
This commit is contained in:
parent
a55124c23a
commit
92187280e9
610
ntfy/bot.py
610
ntfy/bot.py
|
|
@ -1,33 +1,70 @@
|
|||
import asyncio
|
||||
import html
|
||||
from html import escape
|
||||
import json
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any, Dict, Tuple, List, Awaitable, Callable
|
||||
|
||||
from aiohttp import ClientTimeout
|
||||
from maubot import MessageEvent, Plugin
|
||||
from maubot.handlers import command
|
||||
from mautrix.types import (EventType, Format, MessageType,
|
||||
TextMessageEventContent)
|
||||
TextMessageEventContent, RoomID, EventID)
|
||||
from mautrix.util.async_db import UpgradeTable
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
from mautrix.util.formatter import parse_html
|
||||
|
||||
from .config import Config
|
||||
from .db import DB, Topic, upgrade_table
|
||||
from .emoji import EMOJI_FALLBACK, WHITE_CHECK_MARK, parse_tags
|
||||
from .emoji import (EMOJI_FALLBACK, WHITE_CHECK_MARK, REPEAT, NO_ENTRY,
|
||||
WARNING, parse_tags)
|
||||
from .exceptions import SubscriptionError
|
||||
|
||||
|
||||
async def build_notice(html_content) -> TextMessageEventContent:
|
||||
"""
|
||||
Build a notice message.
|
||||
"""
|
||||
text_content = await parse_html(html_content.strip())
|
||||
return TextMessageEventContent(
|
||||
msgtype=MessageType.NOTICE,
|
||||
format=Format.HTML,
|
||||
formatted_body=html_content,
|
||||
body=text_content,
|
||||
)
|
||||
|
||||
|
||||
def ensure_permission(func: Callable):
|
||||
"""
|
||||
Decorator function to ensure that the user has the required permission
|
||||
to execute a command.
|
||||
"""
|
||||
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
evt = args[0]
|
||||
if evt.sender in self.config["admins"]:
|
||||
return await func(self, *args, **kwargs)
|
||||
levels = await self.client.get_state_event(evt.room_id,
|
||||
EventType.ROOM_POWER_LEVELS)
|
||||
user_level = levels.get_user_level(evt.sender)
|
||||
notice = await build_notice(
|
||||
"You don't have the permission to manage the ntfy bot.")
|
||||
await evt.reply(notice)
|
||||
return
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class NtfyBot(Plugin):
|
||||
db: DB
|
||||
config: Config
|
||||
tasks: Dict[int, asyncio.Task] = {}
|
||||
reactions: Dict[RoomID, Dict[str, Dict[str, EventID]]] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
await super().start()
|
||||
self.config.load_and_update()
|
||||
self.db = DB(self.database, self.log)
|
||||
if EMOJI_FALLBACK:
|
||||
self.log.warn(
|
||||
self.log.warning(
|
||||
"Please install the `emoji` package for full emoji support")
|
||||
await self.resubscribe()
|
||||
|
||||
|
|
@ -40,26 +77,30 @@ class NtfyBot(Plugin):
|
|||
self.config.load_and_update()
|
||||
|
||||
async def resubscribe(self) -> None:
|
||||
"""
|
||||
Clear all subscriptions and resubscribe to all topics.
|
||||
:return:
|
||||
"""
|
||||
await self.clear_subscriptions()
|
||||
await self.subscribe_to_topics()
|
||||
|
||||
async def clear_subscriptions(self) -> None:
|
||||
tasks = list(self.tasks.values())
|
||||
if not tasks:
|
||||
return None
|
||||
|
||||
for task in tasks:
|
||||
if not task.done():
|
||||
self.log.debug("cancelling subscription task...")
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self.log.exception("Subscription task errored", exc_info=exc)
|
||||
self.tasks.clear()
|
||||
"""
|
||||
Cancel all subscription tasks and clear the task list.
|
||||
:return:
|
||||
"""
|
||||
for topic_id in self.tasks.keys():
|
||||
await self.cancel_topic_subscription_task(topic_id)
|
||||
|
||||
async def cancel_topic_subscription_task(self, topic_id: int):
|
||||
"""
|
||||
Cancel a topic subscription task and handle any exceptions.
|
||||
:param topic_id: The topic id to cancel the subscription task for
|
||||
:return:
|
||||
"""
|
||||
task = self.tasks.get(topic_id)
|
||||
if task and not task.done():
|
||||
self.log.debug("Cancelling subscription task for topic %d",
|
||||
async def can_use_command(self, evt: MessageEvent) -> bool:
|
||||
if evt.sender in self.config["admins"]:
|
||||
return True
|
||||
|
|
@ -70,123 +111,505 @@ class NtfyBot(Plugin):
|
|||
return False
|
||||
return True
|
||||
|
||||
@command.new(name=lambda self: self.config["command_prefix"], help="Manage ntfy subscriptions.", require_subcommand=True)
|
||||
@command.new(name=lambda self: self.config["command_prefix"],
|
||||
help="Manage ntfy subscriptions.", require_subcommand=True)
|
||||
async def ntfy(self) -> None:
|
||||
pass
|
||||
|
||||
@ntfy.subcommand("subscribe", aliases=("sub",), help="Subscribe this room to a ntfy topic.")
|
||||
@command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
|
||||
@command.argument("token", "access token", required=False, matches="tk_[a-zA-Z0-9]{29}")
|
||||
async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any], token: Tuple[str, Any]) -> None:
|
||||
# see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex
|
||||
if not await self.can_use_command(evt):
|
||||
return None
|
||||
@ntfy.subcommand("list-topics", aliases=("topics",),
|
||||
help="List all ntfy topics in the database.")
|
||||
@ensure_permission
|
||||
async def list_topics(self, evt: MessageEvent) -> None:
|
||||
"""
|
||||
List all topics in the database.
|
||||
:param evt: The message event
|
||||
:return: None
|
||||
"""
|
||||
# Get all topics from the database
|
||||
topics = await self.db.get_topics()
|
||||
|
||||
# Build the message content
|
||||
if topics:
|
||||
html_content = "<span>Currently available ntfy topics:</span><ul>"
|
||||
for topic in topics:
|
||||
html_content += "<li><code>%s/%s</code></li>" % (
|
||||
escape(topic.server), escape(topic.topic))
|
||||
html_content += "</ul>"
|
||||
else:
|
||||
html_content = "<span>No topics available yet.</span>"
|
||||
|
||||
# Send the message to the user
|
||||
content = await build_notice(html_content)
|
||||
await evt.reply(content)
|
||||
|
||||
@ntfy.subcommand("add-topic", aliases=("add",),
|
||||
help="Add a new topic to the database.")
|
||||
@command.argument("topic", "topic URL",
|
||||
matches="(([a-zA-Z0-9-]{1,63}\\.)+"
|
||||
"[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
|
||||
@command.argument("token", "access token", required=False,
|
||||
matches="tk_[a-zA-Z0-9]{29}")
|
||||
@ensure_permission
|
||||
async def add_topic(self, evt: MessageEvent, topic: Tuple[str, Any],
|
||||
token: Tuple[str, Any]) -> None:
|
||||
"""
|
||||
Add a new topic to the database.
|
||||
:param evt: The message event
|
||||
:param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
|
||||
:param token: The access token for the topic
|
||||
(e.g. `tk_be6uc2ca1x1orakcwg1j3hp0ylot6`)
|
||||
:return: None
|
||||
"""
|
||||
# Look up the topic in the database
|
||||
server, topic = topic[0].split("/")
|
||||
token = token if token else None
|
||||
db_topic = await self.db.get_topic(server, topic)
|
||||
|
||||
# Check if the topic already exists
|
||||
if db_topic:
|
||||
# If it already exists, check if the token is the same
|
||||
if db_topic.token == token:
|
||||
await self.add_reaction(evt, NO_ENTRY)
|
||||
notice = await build_notice(
|
||||
"The topic <code>%s/%s</code> already exists." %
|
||||
(escape(server), escape(topic)))
|
||||
await evt.reply(notice)
|
||||
return
|
||||
# Else remove the topic and delete all corresponding subscriptions
|
||||
else:
|
||||
subscriptions = await self.db.get_subscriptions_by_topic_id(
|
||||
db_topic.id)
|
||||
await self.cancel_topic_subscription_task(db_topic.id)
|
||||
await self.db.remove_topic(db_topic.id)
|
||||
db_topic = None
|
||||
await self.add_reaction(evt, WARNING)
|
||||
notice = await build_notice(
|
||||
"The topic <code>%s/%s</code> already exists "
|
||||
"with a different token. The old entry with the different "
|
||||
"token and any associated subscriptions have been removed "
|
||||
"and would need to be re-subscribed." %
|
||||
(escape(server), escape(topic)))
|
||||
# Inform user about the result
|
||||
await evt.reply(notice)
|
||||
# Broadcast a notice to all subscribed rooms
|
||||
broadcast_notice = await build_notice(
|
||||
"The topic <code>%s/%s</code> has been updated with a new "
|
||||
"access token. Please re-subscribe to this topic." %
|
||||
(escape(server), escape(topic)))
|
||||
room_ids = [sub.room_id for sub in subscriptions
|
||||
if sub.room_id != evt.room_id]
|
||||
await self.broadcast_to_rooms(room_ids, broadcast_notice)
|
||||
|
||||
# Create the topic if it doesn't exist
|
||||
db_topic = await self.db.create_topic(
|
||||
Topic(id=-1, server=server, topic=topic, token=token,
|
||||
last_event_id=None))
|
||||
|
||||
# Mark the command as successful
|
||||
await self.add_reaction(evt, WHITE_CHECK_MARK)
|
||||
|
||||
@ntfy.subcommand("remove-topic", aliases=("remove",),
|
||||
help="Remove a topic from the database.")
|
||||
@command.argument("topic", "topic URL",
|
||||
matches="(([a-zA-Z0-9-]{1,63}\\.)+"
|
||||
"[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
|
||||
@ensure_permission
|
||||
async def remove_topic(self, evt: MessageEvent,
|
||||
topic: Tuple[str, Any]) -> None:
|
||||
"""
|
||||
Remove a topic from the database.
|
||||
:param evt: The message event
|
||||
:param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
|
||||
:return: None
|
||||
"""
|
||||
# Lookup topic from db
|
||||
server, topic = topic[0].split("/")
|
||||
db_topic = await self.db.get_topic(server, topic)
|
||||
|
||||
# Delete the topic if it exists and cancel all corresponding
|
||||
# subscriptions
|
||||
if db_topic:
|
||||
subscriptions = await self.db.get_subscriptions_by_topic_id(
|
||||
db_topic.id)
|
||||
await self.cancel_topic_subscription_task(db_topic.id)
|
||||
await self.db.remove_topic(db_topic.id)
|
||||
await self.add_reaction(evt, WHITE_CHECK_MARK)
|
||||
notice = await build_notice(
|
||||
"The topic <code>%s/%s</code> has been deleted from the "
|
||||
"database and all associated subscriptions have been "
|
||||
"cancelled." % (escape(server), escape(topic)))
|
||||
# Broadcast a notice to all subscribed rooms
|
||||
broadcast_notice = await build_notice(
|
||||
"The topic <code>%s/%s</code> has been removed from the "
|
||||
"database. The room has been unsubscribed from this topic." %
|
||||
(escape(server), escape(topic)))
|
||||
room_ids = [sub.room_id for sub in subscriptions
|
||||
if sub.room_id != evt.room_id]
|
||||
await self.broadcast_to_rooms(room_ids, broadcast_notice)
|
||||
|
||||
# Inform the user if the topic doesn't exist
|
||||
else:
|
||||
await self.add_reaction(evt, NO_ENTRY)
|
||||
notice = await build_notice(
|
||||
"The topic <code>%s/%s</code> does not exist in the database." %
|
||||
(escape(server), escape(topic)))
|
||||
|
||||
# Inform user about the result
|
||||
await evt.reply(notice)
|
||||
|
||||
@ntfy.subcommand("list-subscriptions", aliases=("subscriptions",),
|
||||
help="List all active subscriptions for this room.")
|
||||
@ensure_permission
|
||||
async def list_subscriptions(self, evt: MessageEvent) -> None:
|
||||
"""
|
||||
List all active subscriptions for this room.
|
||||
:param evt: The message event
|
||||
:return: None
|
||||
"""
|
||||
# Get all subscriptions for this room from the database
|
||||
subscriptions = await self.db.get_subscriptions_by_room_id(evt.room_id)
|
||||
|
||||
# Build the message content
|
||||
html_content = ("<span>This room is currently subscribed to these "
|
||||
"topics:</span><ul>")
|
||||
for sub in subscriptions:
|
||||
html_content += "<li><code>%s/%s</code></li>" % (
|
||||
escape(sub.server), escape(sub.topic))
|
||||
html_content += "</ul>"
|
||||
|
||||
# Send the message to the user
|
||||
content = await build_notice(html_content)
|
||||
await evt.reply(content)
|
||||
|
||||
@ntfy.subcommand("subscribe", aliases=("sub",),
|
||||
help="Subscribe this room to a ntfy topic.")
|
||||
@command.argument("topic", "topic URL",
|
||||
matches="(([a-zA-Z0-9-]{1,63}\\.)+"
|
||||
"[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
|
||||
@ensure_permission
|
||||
async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any]):
|
||||
"""
|
||||
Subscribe this room to a ntfy topic.
|
||||
See https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61
|
||||
for the valid topic regex.
|
||||
:param evt:
|
||||
:param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
|
||||
:return: None
|
||||
"""
|
||||
# Check if the topic already exists
|
||||
server, topic = topic[0].split("/")
|
||||
db_topic = await self.db.get_topic(server, topic)
|
||||
|
||||
# If it doesn't exist, tell user to add it first
|
||||
if not db_topic:
|
||||
db_topic = await self.db.create_topic(Topic(id=-1, server=server, topic=topic, token=token, last_event_id=None))
|
||||
existing_subscriptions = await self.db.get_subscriptions(db_topic.id)
|
||||
await self.add_reaction(evt, NO_ENTRY)
|
||||
notice = await build_notice(
|
||||
"The topic <code>%s/%s</code> does not exist in the"
|
||||
"database. Please add it first: <code>%s add-subscription "
|
||||
"\\<topic URL\\> [access token]</code>" %
|
||||
(escape(server), escape(topic), self.config["command_prefix"]))
|
||||
await evt.reply(notice)
|
||||
return
|
||||
|
||||
# Check if the room is already subscribed to the topic
|
||||
existing_subscriptions = await self.db.get_subscriptions_by_topic_id(
|
||||
db_topic.id)
|
||||
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
|
||||
if sub:
|
||||
await evt.reply("This room is already subscribed to %s/%s", server, topic)
|
||||
notice = await build_notice(
|
||||
"This room is already subscribed to <code>%s/%s</code>." %
|
||||
(escape(server), escape(topic)))
|
||||
await evt.reply(notice)
|
||||
|
||||
# Subscribe the room to the topic if it isn't already subscribed
|
||||
else:
|
||||
await self.db.add_subscription(db_topic.id, evt.room_id)
|
||||
await evt.reply("Subscribed this room to %s/%s", server, topic)
|
||||
await evt.react(WHITE_CHECK_MARK)
|
||||
if not existing_subscriptions:
|
||||
self.subscribe_to_topic(db_topic)
|
||||
self.subscribe_to_topic(db_topic, evt)
|
||||
|
||||
@ntfy.subcommand("unsubscribe", aliases=("unsub",), help="Unsubscribe this room from a ntfy topic.")
|
||||
@command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
|
||||
async def unsubscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None:
|
||||
# see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex
|
||||
if not await self.can_use_command(evt):
|
||||
return None
|
||||
# Mark the command as successful
|
||||
await self.add_reaction(evt, WHITE_CHECK_MARK)
|
||||
|
||||
@ntfy.subcommand("unsubscribe", aliases=("unsub",),
|
||||
help="Unsubscribe this room from a ntfy topic.")
|
||||
@command.argument("topic", "topic URL",
|
||||
matches="(([a-zA-Z0-9-]{1,63}\\.)+"
|
||||
"[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
|
||||
@ensure_permission
|
||||
async def unsubscribe(self, evt: MessageEvent,
|
||||
topic: Tuple[str, Any]) -> None:
|
||||
"""
|
||||
Unsubscribe this room from a ntfy topic.
|
||||
See https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex.
|
||||
:param evt:
|
||||
:param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
|
||||
:return:
|
||||
"""
|
||||
# Check if the topic exists
|
||||
server, topic = topic[0].split("/")
|
||||
db_topic = await self.db.get_topic(server, topic)
|
||||
notice = await build_notice(
|
||||
"This room is not subscribed to <code>%s/%s</code>." %
|
||||
(escape(db_topic.server), escape(db_topic.topic)))
|
||||
if not db_topic:
|
||||
await evt.reply("This room is not subscribed to %s/%s", server, topic)
|
||||
await evt.reply(notice)
|
||||
return
|
||||
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
|
||||
if not sub:
|
||||
await evt.reply("This room is not subscribed to %s/%s", server, topic)
|
||||
await evt.reply(notice)
|
||||
return
|
||||
|
||||
subscriptions = await self.db.get_subscriptions_by_topic_id(db_topic.id)
|
||||
|
||||
# Unsubscribe the room from the topic
|
||||
await self.db.remove_subscription(db_topic.id, evt.room_id)
|
||||
if not await self.db.get_subscriptions(db_topic.id):
|
||||
self.tasks[db_topic.id].cancel()
|
||||
|
||||
# Cancel the subscription task if there are no more subscriptions
|
||||
if len(subscriptions) <= 1:
|
||||
await self.cancel_topic_subscription_task(db_topic.id)
|
||||
|
||||
# Clear the last event id if there are no more subscriptions
|
||||
if not await self.db.get_subscriptions_by_topic_id(db_topic.id):
|
||||
await self.db.clear_topic_id(db_topic.id)
|
||||
await evt.reply("Unsubscribed this room from %s/%s", server, topic)
|
||||
await evt.react(WHITE_CHECK_MARK)
|
||||
|
||||
# Notify the user that the room has been unsubscribed
|
||||
notice = await build_notice(
|
||||
"Unsubscribed this room from <code>%s/%s</code>." %
|
||||
(escape(db_topic.server), escape(db_topic.topic)))
|
||||
await evt.reply(notice)
|
||||
await self.add_reaction(evt, WHITE_CHECK_MARK)
|
||||
|
||||
async def subscribe_to_topics(self) -> None:
|
||||
"""
|
||||
Subscribe to all topics in the database.
|
||||
:return:
|
||||
"""
|
||||
topics = await self.db.get_topics()
|
||||
for topic in topics:
|
||||
self.subscribe_to_topic(topic)
|
||||
|
||||
def subscribe_to_topic(self, topic: Topic) -> None:
|
||||
def log_task_exc(task: asyncio.Task) -> None:
|
||||
def subscribe_to_topic(self, topic: Topic,
|
||||
evt: MessageEvent | None = None) -> None:
|
||||
"""
|
||||
Subscribe to a topic.
|
||||
:param topic:
|
||||
:param evt:
|
||||
:return:
|
||||
"""
|
||||
|
||||
# Create a callback to log the exception if the task fails
|
||||
def log_task_exc(t: asyncio.Task) -> None:
|
||||
t2 = self.tasks.pop(topic.id, None)
|
||||
if t2 != task:
|
||||
self.log.warn("stored task doesn't match callback")
|
||||
if task.done() and not task.cancelled():
|
||||
exc = task.exception()
|
||||
if t2 != t:
|
||||
self.log.warning("Stored task doesn't match callback")
|
||||
# Log the exception if the task failed
|
||||
if t.done() and not t.cancelled():
|
||||
exc = t.exception()
|
||||
# Handle subscription errors
|
||||
if isinstance(exc, SubscriptionError):
|
||||
self.log.exception("Failed to subscribe to %s/%s",
|
||||
topic.server, topic.topic)
|
||||
self.tasks[topic.id] = self.loop.create_task(
|
||||
self.handle_subscription_error(exc, topic, evt, [
|
||||
self.remove_reaction(evt, WHITE_CHECK_MARK),
|
||||
self.add_reaction(evt, NO_ENTRY),
|
||||
])
|
||||
)
|
||||
else:
|
||||
# Try to resubscribe if the task errored
|
||||
if evt:
|
||||
self.loop.create_task(self.add_reaction(evt, REPEAT))
|
||||
self.log.exception(
|
||||
"Subscription task errored, resubscribing", exc_info=exc)
|
||||
"Subscription task errored, resubscribing",
|
||||
exc_info=exc)
|
||||
# Sleep for 10 seconds before resubscribing
|
||||
self.tasks[topic.id] = self.loop.create_task(
|
||||
asyncio.sleep(10.0))
|
||||
# Subscribe to the topic again
|
||||
self.tasks[topic.id].add_done_callback(
|
||||
lambda _: self.subscribe_to_topic(topic))
|
||||
lambda _: self.subscribe_to_topic(topic, evt))
|
||||
|
||||
self.log.info("Subscribing to %s/%s", topic.server, topic.topic)
|
||||
# Prepare the URL for the topic
|
||||
self.log.info("Subscribing to %s/%s" % (topic.server, topic.topic))
|
||||
url = "%s/%s/json" % (topic.server, topic.topic)
|
||||
# Prepend the URL with `https://` if necessary
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
# Add the last event id to the URL if it exists
|
||||
if topic.last_event_id:
|
||||
url += "?since=%s" % topic.last_event_id
|
||||
|
||||
# Subscribe to the topic
|
||||
self.log.debug("Subscribing to URL %s", url)
|
||||
task = self.loop.create_task(
|
||||
self.run_topic_subscription(topic, url))
|
||||
self.run_topic_subscription(topic, url, evt))
|
||||
self.tasks[topic.id] = task
|
||||
task.add_done_callback(log_task_exc)
|
||||
|
||||
async def run_topic_subscription(self, topic: Topic, url: str) -> None:
|
||||
headers = {"Authorization": f"Bearer {topic.token}"}
|
||||
async def run_topic_subscription(self, topic: Topic, url: str,
|
||||
evt: MessageEvent | None = None) -> None:
|
||||
"""
|
||||
Subscribe to a topic.
|
||||
:param topic:
|
||||
:param url:
|
||||
:param evt:
|
||||
:return:
|
||||
"""
|
||||
# Prepare authentication headers in case a token is provided
|
||||
headers = {
|
||||
"Authorization": f"Bearer {topic.token}"} if topic.token else None
|
||||
|
||||
# Subscribe to the topic
|
||||
async with self.http.get(url, timeout=ClientTimeout(),
|
||||
headers=headers if topic.token else None) as resp:
|
||||
headers=headers) as resp:
|
||||
|
||||
# Loop through the response content
|
||||
while True:
|
||||
line = await resp.content.readline()
|
||||
# convert to string and remove trailing newline
|
||||
|
||||
# Convert to string and remove trailing newline
|
||||
line = line.decode("utf-8").strip()
|
||||
|
||||
# Break if the line is empty
|
||||
if not line:
|
||||
break
|
||||
|
||||
# Parse the line as JSON
|
||||
self.log.trace("Received notification: %s", line)
|
||||
message = json.loads(line)
|
||||
if message["event"] != "message":
|
||||
|
||||
# Check if the message is an event
|
||||
if "event" not in message or message["event"] != "message":
|
||||
# Check if the message is an error, else continue
|
||||
if "error" in message:
|
||||
raise SubscriptionError(message)
|
||||
continue
|
||||
self.log.debug("Received message event: %s", line)
|
||||
# persist the received message id
|
||||
|
||||
# Persist the received last event id
|
||||
await self.db.update_topic_id(topic.id, message["id"])
|
||||
|
||||
# build matrix message
|
||||
# Build matrix message
|
||||
html_content = self.build_message_content(
|
||||
topic.server, message)
|
||||
text_content = await parse_html(html_content.strip())
|
||||
|
||||
content = TextMessageEventContent(
|
||||
msgtype=MessageType.TEXT,
|
||||
msgtype=MessageType.NOTICE,
|
||||
format=Format.HTML,
|
||||
formatted_body=html_content,
|
||||
body=text_content,
|
||||
)
|
||||
|
||||
subscriptions = await self.db.get_subscriptions(topic.id)
|
||||
for sub in subscriptions:
|
||||
# Broadcast the message to all subscribed rooms
|
||||
subscriptions = await self.db.get_subscriptions_by_topic_id(
|
||||
topic.id)
|
||||
room_ids = [sub.room_id for sub in subscriptions]
|
||||
await self.broadcast_to_rooms(room_ids, content)
|
||||
|
||||
async def broadcast_to_rooms(self, room_ids: List[RoomID],
|
||||
content: TextMessageEventContent):
|
||||
"""
|
||||
Broadcast a message to multiple rooms concurrently.
|
||||
:param room_ids:
|
||||
:param content:
|
||||
:return:
|
||||
"""
|
||||
self.log.info("Broadcasting message to %d rooms: %s",
|
||||
len(room_ids), content.body)
|
||||
|
||||
async def send(rid):
|
||||
try:
|
||||
await self.client.send_message(sub.room_id, content)
|
||||
await self.client.send_message(rid, content)
|
||||
except Exception as e:
|
||||
self.log.exception(
|
||||
"Failed to send matrix message to room with id '%s'!",
|
||||
rid, exc_info=e)
|
||||
|
||||
async with asyncio.TaskGroup() as task_group:
|
||||
for room_id in room_ids:
|
||||
task_group.create_task(send(room_id))
|
||||
|
||||
async def handle_subscription_error(self,
|
||||
subscription_error: SubscriptionError,
|
||||
topic: Topic,
|
||||
evt: MessageEvent,
|
||||
callbacks: List[Awaitable] = None) \
|
||||
-> None:
|
||||
"""
|
||||
Handles a subscription error by removing the topic and all related
|
||||
subscriptions from the database and notifying all subscribed rooms.
|
||||
:param subscription_error:
|
||||
:param topic:
|
||||
:param evt:
|
||||
:param callbacks: List of async functions to run afterward
|
||||
:return:
|
||||
"""
|
||||
task = self.tasks[topic.id]
|
||||
message = subscription_error.message
|
||||
code = subscription_error.code
|
||||
|
||||
self.log.error("ntfy server responded: '%s' %s",
|
||||
message if message else "unknown error",
|
||||
"(code %s)" % code if code else "")
|
||||
|
||||
# Get affected rooms
|
||||
subscriptions = await self.db.get_subscriptions_by_topic_id(topic.id)
|
||||
# Filter out the current room if there is one
|
||||
if evt:
|
||||
room_ids = [sub.room_id for sub in subscriptions
|
||||
if sub.room_id != evt.room_id]
|
||||
else:
|
||||
room_ids = [sub.room_id for sub in subscriptions]
|
||||
|
||||
match code:
|
||||
# Handle 'unauthorized' and 'limit reached: too many auth failures'
|
||||
case 40101 | 42909:
|
||||
|
||||
# Remove topic and all related subscriptions from database
|
||||
result = await self.db.remove_topic(topic.id)
|
||||
self.log.info("Removed topic: %s", topic.id)
|
||||
|
||||
# Notify all subscribed rooms that the subscription has
|
||||
# been cancelled
|
||||
notice = await build_notice(
|
||||
"<span>Authentication failed, please check the "
|
||||
"access token. Subscription to <code>'%s/%s'</code> has "
|
||||
"been cancelled and topic has been removed.</span>" %
|
||||
(escape(topic.topic), escape(topic.server)))
|
||||
case _:
|
||||
notice = await build_notice(
|
||||
"<span>Subscription to <code>'%s/%s'</code> failed: "
|
||||
"<code>'%s'</code></span>" %
|
||||
(escape(topic.topic), escape(topic.server),
|
||||
escape(message["error"])))
|
||||
|
||||
# Notify the user that the subscription has been cancelled
|
||||
try:
|
||||
await evt.reply(notice)
|
||||
except Exception as exc:
|
||||
self.log.exception(
|
||||
"Failed to send matrix message!", exc_info=exc)
|
||||
"Failed to send matrix message to room with id '%s'!",
|
||||
evt.room_id, exc_info=exc)
|
||||
# Broadcast the error message to all subscribed rooms
|
||||
await self.broadcast_to_rooms(room_ids, notice)
|
||||
|
||||
# Cancel the task
|
||||
if task.cancel():
|
||||
self.log.info("Subscription task cancelled")
|
||||
|
||||
# Run the callbacks concurrently
|
||||
if callbacks:
|
||||
async with asyncio.TaskGroup() as callback_tasks:
|
||||
for callback in callbacks:
|
||||
callback_tasks.create_task(callback)
|
||||
|
||||
def build_message_content(self, server: str, message) -> str:
|
||||
"""
|
||||
Build the message content for a ntfy message.
|
||||
:param server:
|
||||
:param message:
|
||||
:return:
|
||||
"""
|
||||
topic = message["topic"]
|
||||
body = message["message"]
|
||||
title = message.get("title", None)
|
||||
|
|
@ -201,33 +624,35 @@ class NtfyBot(Plugin):
|
|||
else:
|
||||
emoji = tags = ""
|
||||
|
||||
html_content = "<span>Ntfy message in topic <code>%s/%s</code></span><blockquote>" % (
|
||||
html.escape(server), html.escape(topic))
|
||||
html_content = ("<span>Ntfy message in topic <code>%s/%s</code></span>"
|
||||
"<blockquote>") % (
|
||||
escape(server), escape(topic))
|
||||
# build title
|
||||
if title and click:
|
||||
html_content += "<h4>%s<a href=\"%s\">%s</a></h4>" % (
|
||||
emoji, html.escape(click), html.escape(title))
|
||||
emoji, escape(click), escape(title))
|
||||
emoji = ""
|
||||
elif title:
|
||||
html_content += "<h4>%s%s</h4>" % (emoji, html.escape(title))
|
||||
html_content += "<h4>%s%s</h4>" % (emoji, escape(title))
|
||||
emoji = ""
|
||||
|
||||
# build body
|
||||
if click and not title:
|
||||
html_content += "%s<a href=\"%s\">%s</a>" % (
|
||||
emoji, html.escape(click), html.escape(body).replace("\n", "<br />"))
|
||||
emoji, escape(click),
|
||||
escape(body).replace("\n", "<br />"))
|
||||
else:
|
||||
html_content += emoji + html.escape(body).replace("\n", "<br />")
|
||||
html_content += emoji + escape(body).replace("\n", "<br />")
|
||||
|
||||
# add non-emoji tags
|
||||
if tags:
|
||||
html_content += "<br/><small>Tags: <code>%s</code></small>" % html.escape(
|
||||
html_content += "<br/><small>Tags: <code>%s</code></small>" % escape(
|
||||
tags)
|
||||
|
||||
# build attachment
|
||||
if attachment:
|
||||
html_content += "<br/><a href=\"%s\">View %s</a>" % (html.escape(
|
||||
attachment["url"]), html.escape(attachment["name"]))
|
||||
html_content += "<br/><a href=\"%s\">View %s</a>" % (escape(
|
||||
attachment["url"]), escape(attachment["name"]))
|
||||
html_content += "</blockquote>"
|
||||
|
||||
return html_content
|
||||
|
|
@ -239,3 +664,32 @@ class NtfyBot(Plugin):
|
|||
@classmethod
|
||||
def get_db_upgrade_table(cls) -> UpgradeTable | None:
|
||||
return upgrade_table
|
||||
|
||||
async def add_reaction(self, evt: MessageEvent, emoji: str):
|
||||
"""
|
||||
Add a reaction to a message and store the event id.
|
||||
"""
|
||||
self.log.debug("Adding reaction %s to event %s",
|
||||
emoji, evt.event_id)
|
||||
reaction_id = await evt.react(emoji)
|
||||
self.reactions.update(
|
||||
{evt.room_id: {evt.event_id: {emoji: reaction_id}}}
|
||||
)
|
||||
|
||||
async def remove_reaction(self, evt: MessageEvent, emoji: str) -> None:
|
||||
"""
|
||||
Remove a reaction from a message.
|
||||
:param evt: The original message event
|
||||
:param emoji: The emoji sent as a reaction
|
||||
"""
|
||||
self.log.debug("Removing reaction %s from event %s",
|
||||
emoji, evt.event_id)
|
||||
try:
|
||||
reaction_id = self.reactions[evt.room_id][evt.event_id][emoji]
|
||||
if reaction_id:
|
||||
result = await self.client.redact(evt.room_id, reaction_id)
|
||||
if result:
|
||||
self.reactions[evt.room_id][evt.event_id].pop(emoji)
|
||||
except KeyError:
|
||||
self.log.warning("Reaction %s not found in self.reactions[%s][%s]",
|
||||
emoji, evt.room_id, evt.event_id)
|
||||
|
|
|
|||
55
ntfy/db.py
55
ntfy/db.py
|
|
@ -37,13 +37,33 @@ async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
|
|||
)"""
|
||||
)
|
||||
|
||||
@upgrade_table.register(description="Cascade delete of topics")
|
||||
async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
|
||||
await conn.execute(
|
||||
"""CREATE TABLE subscriptions_new (
|
||||
topic_id INTEGER,
|
||||
room_id TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (topic_id, room_id),
|
||||
FOREIGN KEY (topic_id) REFERENCES topics (id) ON DELETE CASCADE
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""INSERT INTO subscriptions_new SELECT * FROM subscriptions"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""DROP TABLE subscriptions"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""ALTER TABLE subscriptions_new RENAME TO subscriptions"""
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class Topic:
|
||||
id: int
|
||||
server: str
|
||||
topic: str
|
||||
last_event_id: str
|
||||
last_event_id: str | None = None
|
||||
token: str | None = None
|
||||
|
||||
subscriptions: List[Subscription] = attr.ib(factory=lambda: [])
|
||||
|
|
@ -71,6 +91,8 @@ class Topic:
|
|||
class Subscription:
|
||||
topic_id: int
|
||||
room_id: RoomID
|
||||
server: str
|
||||
topic: str
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: Record | None) -> Topic | None:
|
||||
|
|
@ -78,9 +100,13 @@ class Subscription:
|
|||
return None
|
||||
topic_id = row["topic_id"]
|
||||
room_id = row["room_id"]
|
||||
server = row["server"]
|
||||
topic = row["topic"]
|
||||
return cls(
|
||||
topic_id=topic_id,
|
||||
room_id=room_id,
|
||||
server=server,
|
||||
topic=topic,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -145,6 +171,13 @@ class DB:
|
|||
)
|
||||
return topic
|
||||
|
||||
async def remove_topic(self, topic_id: int):
|
||||
query = """
|
||||
DELETE FROM topics
|
||||
WHERE id = $1
|
||||
"""
|
||||
return await self.db.execute(query, topic_id)
|
||||
|
||||
async def get_topic(self, server: str, topic: str) -> Topic | None:
|
||||
query = """
|
||||
SELECT id, server, topic, token, last_event_id
|
||||
|
|
@ -165,10 +198,12 @@ class DB:
|
|||
row = await self.db.fetchrow(query, topic_id, room_id)
|
||||
return (Subscription.from_row(row), Topic.from_row(row))
|
||||
|
||||
async def get_subscriptions(self, topic_id: int) -> List[Subscription]:
|
||||
async def get_subscriptions_by_topic_id(self, topic_id: int) \
|
||||
-> List[Subscription]:
|
||||
query = """
|
||||
SELECT topic_id, room_id
|
||||
SELECT topic_id, room_id, server, topic
|
||||
FROM subscriptions
|
||||
INNER JOIN topics ON topics.id = subscriptions.topic_id
|
||||
WHERE topic_id = $1
|
||||
"""
|
||||
rows = await self.db.fetch(query, topic_id)
|
||||
|
|
@ -177,6 +212,20 @@ class DB:
|
|||
subscriptions.append(Subscription.from_row(row))
|
||||
return subscriptions
|
||||
|
||||
async def get_subscriptions_by_room_id(self, room_id: RoomID) \
|
||||
-> List[Subscription]:
|
||||
query = """
|
||||
SELECT topic_id, room_id, server, topic
|
||||
FROM subscriptions
|
||||
INNER JOIN topics ON topics.id = subscriptions.topic_id
|
||||
WHERE room_id = $1
|
||||
"""
|
||||
rows = await self.db.fetch(query, room_id)
|
||||
subscriptions = []
|
||||
for row in rows:
|
||||
subscriptions.append(Subscription.from_row(row))
|
||||
return subscriptions
|
||||
|
||||
async def add_subscription(self, topic_id: int, room_id: RoomID) -> None:
|
||||
query = """
|
||||
INSERT INTO subscriptions (topic_id, room_id)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,9 @@ except ImportError:
|
|||
EMOJI_FALLBACK = True
|
||||
|
||||
WHITE_CHECK_MARK = emoji.emojize(":white_check_mark:", language="alias")
|
||||
REPEAT = emoji.emojize(":repeat:", language="alias")
|
||||
NO_ENTRY = emoji.emojize(":no_entry:", language="alias")
|
||||
WARNING = emoji.emojize(":warning:", language="alias")
|
||||
|
||||
|
||||
def parse_tags(log: TraceLogger, tags: List[str]) -> Tuple[List[str], List[str]]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue