diff --git a/README.md b/README.md index 4ddd8ce..e3caf0c 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,23 @@ # maubot-ntfy -This is a [maubot](https://maubot.xyz/) plugin to subscribe to [ntfy](https://ntfy.sh/) topics and send messages to a matrix room. +This is a [maubot](https://maubot.xyz/) plugin to subscribe +to [ntfy](https://ntfy.sh/) topics and send messages to a matrix room. ## Usage -Install as a maubot plugin and configure an instance. Until maubot supports installing python dependencies, you should install the `emoji` python package for full functionality. Alternatively, `@ntfy:catgirl.cloud` is available as well. +Install as a maubot plugin and configure an instance. Until maubot supports +installing python dependencies, you should install the `emoji` python package +for full functionality. Alternatively, `@ntfy:catgirl.cloud` is available as +well. -Use `!ntfy subscribe server/topic` (for example `!ntfy subscribe ntfy.sh/my_topic`) to subscribe the current room to the ntfy topic. Future messages will be sent to the room. +### Commands -To unsubscribe, use `!ntfy unsubscribe server/topic`. +| Command | Example | Description | +|---------------------------------------------------|-------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------| +| `!ntfy add-topic /` | `!ntfy add-topic ntfy.sh/my_topic` | Adds a topic to the database and makes it available globally within this bot instance. | +| `!ntfy remove-topic /` | `!ntfy remove-topic ntfy.sh/my_topic` | Removes an available topic from the database and to cancels all associated subscriptions. | +| `!ntfy list-topics` | | Lists all available topics for this bot instance. | +| `!ntfy subscribe /` | `!ntfy subscribe ntfy.sh/my_topic` | Subscribe the current room to the ntfy topic. Future messages will be sent to the room. | +| `!ntfy subscribe / ` | `!ntfy subscribe ntfy.sh/my_protected_topic tk_be6uc2ca1x1orakcwg1j3hp0ylot6` | Subscribes to a protected topic via [access token](https://docs.ntfy.sh/config/#access-tokens). | +| `!ntfy unsubscribe server/topic` | `!ntfy subscribe ntfy.sh/my_topic` | Unsubscribes a room from a topic. | +| `!ntfy list-subscriptions` | | Lists all active subscriptions for this bot room. | diff --git a/maubot.yaml b/maubot.yaml index ffce0e8..fc30a89 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,6 +1,6 @@ maubot: 0.3.0 id: cloud.catgirl.ntfy -version: 0.1.0 +version: 0.2.0 license: AGPL-3.0-or-later modules: - ntfy diff --git a/ntfy/bot.py b/ntfy/bot.py index 057785a..918b9ab 100644 --- a/ntfy/bot.py +++ b/ntfy/bot.py @@ -1,33 +1,70 @@ import asyncio -import html +from html import escape import json -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, List, Awaitable, Callable from aiohttp import ClientTimeout from maubot import MessageEvent, Plugin from maubot.handlers import command from mautrix.types import (EventType, Format, MessageType, - TextMessageEventContent) + TextMessageEventContent, RoomID, EventID) from mautrix.util.async_db import UpgradeTable from mautrix.util.config import BaseProxyConfig from mautrix.util.formatter import parse_html from .config import Config from .db import DB, Topic, upgrade_table -from .emoji import EMOJI_FALLBACK, WHITE_CHECK_MARK, parse_tags +from .emoji import (EMOJI_FALLBACK, WHITE_CHECK_MARK, REPEAT, NO_ENTRY, + WARNING, parse_tags) +from .exceptions import SubscriptionError + + +async def build_notice(html_content) -> TextMessageEventContent: + """ + Build a notice message. + """ + text_content = await parse_html(html_content.strip()) + return TextMessageEventContent( + msgtype=MessageType.NOTICE, + format=Format.HTML, + formatted_body=html_content, + body=text_content, + ) + + +def ensure_permission(func: Callable): + """ + Decorator function to ensure that the user has the required permission + to execute a command. + """ + + async def wrapper(self, *args, **kwargs): + evt = args[0] + if evt.sender in self.config["admins"]: + return await func(self, *args, **kwargs) + levels = await self.client.get_state_event(evt.room_id, + EventType.ROOM_POWER_LEVELS) + user_level = levels.get_user_level(evt.sender) + notice = await build_notice( + "You don't have the permission to manage the ntfy bot.") + await evt.reply(notice) + return + + return wrapper class NtfyBot(Plugin): db: DB config: Config tasks: Dict[int, asyncio.Task] = {} + reactions: Dict[RoomID, Dict[str, Dict[str, EventID]]] = {} async def start(self) -> None: await super().start() self.config.load_and_update() self.db = DB(self.database, self.log) if EMOJI_FALLBACK: - self.log.warn( + self.log.warning( "Please install the `emoji` package for full emoji support") await self.resubscribe() @@ -40,150 +77,538 @@ class NtfyBot(Plugin): self.config.load_and_update() async def resubscribe(self) -> None: + """ + Clear all subscriptions and resubscribe to all topics. + :return: + """ await self.clear_subscriptions() await self.subscribe_to_topics() async def clear_subscriptions(self) -> None: - tasks = list(self.tasks.values()) - if not tasks: - return None + """ + Cancel all subscription tasks and clear the task list. + :return: + """ + for topic_id in self.tasks.keys(): + await self.cancel_topic_subscription_task(topic_id) - for task in tasks: - if not task.done(): - self.log.debug("cancelling subscription task...") - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - except Exception as exc: - self.log.exception("Subscription task errored", exc_info=exc) - self.tasks.clear() + async def cancel_topic_subscription_task(self, topic_id: int): + """ + Cancel a topic subscription task and handle any exceptions. + :param topic_id: The topic id to cancel the subscription task for + :return: + """ + task = self.tasks.get(topic_id) + if task and not task.done(): + self.log.debug("Cancelling subscription task for topic %d", + topic_id) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception as exc: + self.log.exception("Subscription task errored", exc_info=exc) - async def can_use_command(self, evt: MessageEvent) -> bool: - if evt.sender in self.config["admins"]: - return True - levels = await self.client.get_state_event(evt.room_id, EventType.ROOM_POWER_LEVELS) - user_level = levels.get_user_level(evt.sender) - if user_level < 50: - await evt.reply("You don't have the permission to manage ntfy subscriptions in this room.") - return False - return True - - @command.new(name=lambda self: self.config["command_prefix"], help="Manage ntfy subscriptions.", require_subcommand=True) + @command.new(name=lambda self: self.config["command_prefix"], + help="Manage ntfy subscriptions.", require_subcommand=True) async def ntfy(self) -> None: pass - @ntfy.subcommand("subscribe", aliases=("sub",), help="Subscribe this room to a ntfy topic.") - @command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") - async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None: - # see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex - if not await self.can_use_command(evt): - return None + @ntfy.subcommand("list-topics", aliases=("topics",), + help="List all ntfy topics in the database.") + @ensure_permission + async def list_topics(self, evt: MessageEvent) -> None: + """ + List all topics in the database. + :param evt: The message event + :return: None + """ + # Get all topics from the database + topics = await self.db.get_topics() + + # Build the message content + if topics: + html_content = "Currently available ntfy topics:
    " + for topic in topics: + html_content += "
  • %s/%s
  • " % ( + escape(topic.server), escape(topic.topic)) + html_content += "
" + else: + html_content = "No topics available yet." + + # Send the message to the user + content = await build_notice(html_content) + await evt.reply(content) + + @ntfy.subcommand("add-topic", aliases=("add",), + help="Add a new topic to the database.") + @command.argument("topic", "topic URL", + matches="(([a-zA-Z0-9-]{1,63}\\.)+" + "[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") + @command.argument("token", "access token", required=False, + matches="tk_[a-zA-Z0-9]{29}") + @ensure_permission + async def add_topic(self, evt: MessageEvent, topic: Tuple[str, Any], + token: Tuple[str, Any]) -> None: + """ + Add a new topic to the database. + :param evt: The message event + :param topic: The topic URL (e.g. `ntfy.sh/my_topic`) + :param token: The access token for the topic + (e.g. `tk_be6uc2ca1x1orakcwg1j3hp0ylot6`) + :return: None + """ + # Look up the topic in the database + server, topic = topic[0].split("/") + token = token if token else None + db_topic = await self.db.get_topic(server, topic) + + # Check if the topic already exists + if db_topic: + # If it already exists, check if the token is the same + if db_topic.token == token: + await self.add_reaction(evt, NO_ENTRY) + notice = await build_notice( + "The topic %s/%s already exists." % + (escape(server), escape(topic))) + await evt.reply(notice) + return + # Else remove the topic and delete all corresponding subscriptions + else: + subscriptions = await self.db.get_subscriptions_by_topic_id( + db_topic.id) + await self.cancel_topic_subscription_task(db_topic.id) + await self.db.remove_topic(db_topic.id) + db_topic = None + await self.add_reaction(evt, WARNING) + notice = await build_notice( + "The topic %s/%s already exists " + "with a different token. The old entry with the different " + "token and any associated subscriptions have been removed " + "and would need to be re-subscribed." % + (escape(server), escape(topic))) + # Inform user about the result + await evt.reply(notice) + # Broadcast a notice to all subscribed rooms + broadcast_notice = await build_notice( + "The topic %s/%s has been updated with a new " + "access token. Please re-subscribe to this topic." % + (escape(server), escape(topic))) + room_ids = [sub.room_id for sub in subscriptions + if sub.room_id != evt.room_id] + await self.broadcast_to_rooms(room_ids, broadcast_notice) + + # Create the topic if it doesn't exist + db_topic = await self.db.create_topic( + Topic(id=-1, server=server, topic=topic, token=token, + last_event_id=None)) + + # Mark the command as successful + await self.add_reaction(evt, WHITE_CHECK_MARK) + + @ntfy.subcommand("remove-topic", aliases=("remove",), + help="Remove a topic from the database.") + @command.argument("topic", "topic URL", + matches="(([a-zA-Z0-9-]{1,63}\\.)+" + "[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") + @ensure_permission + async def remove_topic(self, evt: MessageEvent, + topic: Tuple[str, Any]) -> None: + """ + Remove a topic from the database. + :param evt: The message event + :param topic: The topic URL (e.g. `ntfy.sh/my_topic`) + :return: None + """ + # Lookup topic from db server, topic = topic[0].split("/") db_topic = await self.db.get_topic(server, topic) + + # Delete the topic if it exists and cancel all corresponding + # subscriptions + if db_topic: + subscriptions = await self.db.get_subscriptions_by_topic_id( + db_topic.id) + await self.cancel_topic_subscription_task(db_topic.id) + await self.db.remove_topic(db_topic.id) + await self.add_reaction(evt, WHITE_CHECK_MARK) + notice = await build_notice( + "The topic %s/%s has been deleted from the " + "database and all associated subscriptions have been " + "cancelled." % (escape(server), escape(topic))) + # Broadcast a notice to all subscribed rooms + broadcast_notice = await build_notice( + "The topic %s/%s has been removed from the " + "database. The room has been unsubscribed from this topic." % + (escape(server), escape(topic))) + room_ids = [sub.room_id for sub in subscriptions + if sub.room_id != evt.room_id] + await self.broadcast_to_rooms(room_ids, broadcast_notice) + + # Inform the user if the topic doesn't exist + else: + await self.add_reaction(evt, NO_ENTRY) + notice = await build_notice( + "The topic %s/%s does not exist in the database." % + (escape(server), escape(topic))) + + # Inform user about the result + await evt.reply(notice) + + @ntfy.subcommand("list-subscriptions", aliases=("subscriptions",), + help="List all active subscriptions for this room.") + @ensure_permission + async def list_subscriptions(self, evt: MessageEvent) -> None: + """ + List all active subscriptions for this room. + :param evt: The message event + :return: None + """ + # Get all subscriptions for this room from the database + subscriptions = await self.db.get_subscriptions_by_room_id(evt.room_id) + + # Build the message content + html_content = ("This room is currently subscribed to these " + "topics:
    ") + for sub in subscriptions: + html_content += "
  • %s/%s
  • " % ( + escape(sub.server), escape(sub.topic)) + html_content += "
" + + # Send the message to the user + content = await build_notice(html_content) + await evt.reply(content) + + @ntfy.subcommand("subscribe", aliases=("sub",), + help="Subscribe this room to a ntfy topic.") + @command.argument("topic", "topic URL", + matches="(([a-zA-Z0-9-]{1,63}\\.)+" + "[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") + @ensure_permission + async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any]): + """ + Subscribe this room to a ntfy topic. + See https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 + for the valid topic regex. + :param evt: + :param topic: The topic URL (e.g. `ntfy.sh/my_topic`) + :return: None + """ + # Check if the topic already exists + server, topic = topic[0].split("/") + db_topic = await self.db.get_topic(server, topic) + + # If it doesn't exist, tell user to add it first if not db_topic: - db_topic = await self.db.create_topic(Topic(id=-1, server=server, topic=topic, last_event_id=None)) - existing_subscriptions = await self.db.get_subscriptions(db_topic.id) + await self.add_reaction(evt, NO_ENTRY) + notice = await build_notice( + "The topic %s/%s does not exist in the" + "database. Please add it first: %s add-subscription " + "\\ [access token]" % + (escape(server), escape(topic), self.config["command_prefix"])) + await evt.reply(notice) + return + + # Check if the room is already subscribed to the topic + existing_subscriptions = await self.db.get_subscriptions_by_topic_id( + db_topic.id) sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id) if sub: - await evt.reply("This room is already subscribed to %s/%s", server, topic) + notice = await build_notice( + "This room is already subscribed to %s/%s." % + (escape(server), escape(topic))) + await evt.reply(notice) + + # Subscribe the room to the topic if it isn't already subscribed else: await self.db.add_subscription(db_topic.id, evt.room_id) - await evt.reply("Subscribed this room to %s/%s", server, topic) - await evt.react(WHITE_CHECK_MARK) if not existing_subscriptions: - self.subscribe_to_topic(db_topic) + self.subscribe_to_topic(db_topic, evt) - @ntfy.subcommand("unsubscribe", aliases=("unsub",), help="Unsubscribe this room from a ntfy topic.") - @command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") - async def unsubscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None: - # see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex - if not await self.can_use_command(evt): - return None + # Mark the command as successful + await self.add_reaction(evt, WHITE_CHECK_MARK) + + @ntfy.subcommand("unsubscribe", aliases=("unsub",), + help="Unsubscribe this room from a ntfy topic.") + @command.argument("topic", "topic URL", + matches="(([a-zA-Z0-9-]{1,63}\\.)+" + "[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") + @ensure_permission + async def unsubscribe(self, evt: MessageEvent, + topic: Tuple[str, Any]) -> None: + """ + Unsubscribe this room from a ntfy topic. + See https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex. + :param evt: + :param topic: The topic URL (e.g. `ntfy.sh/my_topic`) + :return: + """ + # Check if the topic exists server, topic = topic[0].split("/") db_topic = await self.db.get_topic(server, topic) + notice = await build_notice( + "This room is not subscribed to %s/%s." % + (escape(db_topic.server), escape(db_topic.topic))) if not db_topic: - await evt.reply("This room is not subscribed to %s/%s", server, topic) + await evt.reply(notice) return sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id) if not sub: - await evt.reply("This room is not subscribed to %s/%s", server, topic) + await evt.reply(notice) return + + subscriptions = await self.db.get_subscriptions_by_topic_id(db_topic.id) + + # Unsubscribe the room from the topic await self.db.remove_subscription(db_topic.id, evt.room_id) - if not await self.db.get_subscriptions(db_topic.id): - self.tasks[db_topic.id].cancel() + + # Cancel the subscription task if there are no more subscriptions + if len(subscriptions) <= 1: + await self.cancel_topic_subscription_task(db_topic.id) + + # Clear the last event id if there are no more subscriptions + if not await self.db.get_subscriptions_by_topic_id(db_topic.id): await self.db.clear_topic_id(db_topic.id) - await evt.reply("Unsubscribed this room from %s/%s", server, topic) - await evt.react(WHITE_CHECK_MARK) + + # Notify the user that the room has been unsubscribed + notice = await build_notice( + "Unsubscribed this room from %s/%s." % + (escape(db_topic.server), escape(db_topic.topic))) + await evt.reply(notice) + await self.add_reaction(evt, WHITE_CHECK_MARK) async def subscribe_to_topics(self) -> None: + """ + Subscribe to all topics in the database. + :return: + """ topics = await self.db.get_topics() for topic in topics: self.subscribe_to_topic(topic) - def subscribe_to_topic(self, topic: Topic) -> None: - def log_task_exc(task: asyncio.Task) -> None: - t2 = self.tasks.pop(topic.id, None) - if t2 != task: - self.log.warn("stored task doesn't match callback") - if task.done() and not task.cancelled(): - exc = task.exception() - self.log.exception( - "Subscription task errored, resubscribing", exc_info=exc) - self.tasks[topic.id] = self.loop.create_task( - asyncio.sleep(10.0)) - self.tasks[topic.id].add_done_callback( - lambda _: self.subscribe_to_topic(topic)) + def subscribe_to_topic(self, topic: Topic, + evt: MessageEvent | None = None) -> None: + """ + Subscribe to a topic. + :param topic: + :param evt: + :return: + """ - self.log.info("Subscribing to %s/%s", topic.server, topic.topic) + # Create a callback to log the exception if the task fails + def log_task_exc(t: asyncio.Task) -> None: + t2 = self.tasks.pop(topic.id, None) + if t2 != t: + self.log.warning("Stored task doesn't match callback") + # Log the exception if the task failed + if t.done() and not t.cancelled(): + exc = t.exception() + # Handle subscription errors + if isinstance(exc, SubscriptionError): + self.log.exception("Failed to subscribe to %s/%s", + topic.server, topic.topic) + self.tasks[topic.id] = self.loop.create_task( + self.handle_subscription_error(exc, topic, evt, [ + self.remove_reaction(evt, WHITE_CHECK_MARK), + self.add_reaction(evt, NO_ENTRY), + ]) + ) + else: + # Try to resubscribe if the task errored + if evt: + self.loop.create_task(self.add_reaction(evt, REPEAT)) + self.log.exception( + "Subscription task errored, resubscribing", + exc_info=exc) + # Sleep for 10 seconds before resubscribing + self.tasks[topic.id] = self.loop.create_task( + asyncio.sleep(10.0)) + # Subscribe to the topic again + self.tasks[topic.id].add_done_callback( + lambda _: self.subscribe_to_topic(topic, evt)) + + # Prepare the URL for the topic + self.log.info("Subscribing to %s/%s" % (topic.server, topic.topic)) url = "%s/%s/json" % (topic.server, topic.topic) + # Prepend the URL with `https://` if necessary if not url.startswith(("http://", "https://")): url = "https://" + url + # Add the last event id to the URL if it exists if topic.last_event_id: url += "?since=%s" % topic.last_event_id + # Subscribe to the topic self.log.debug("Subscribing to URL %s", url) task = self.loop.create_task( - self.run_topic_subscription(topic, url)) + self.run_topic_subscription(topic, url, evt)) self.tasks[topic.id] = task task.add_done_callback(log_task_exc) - async def run_topic_subscription(self, topic: Topic, url: str) -> None: - async with self.http.get(url, timeout=ClientTimeout()) as resp: + async def run_topic_subscription(self, topic: Topic, url: str, + evt: MessageEvent | None = None) -> None: + """ + Subscribe to a topic. + :param topic: + :param url: + :param evt: + :return: + """ + # Prepare authentication headers in case a token is provided + headers = { + "Authorization": f"Bearer {topic.token}"} if topic.token else None + + # Subscribe to the topic + async with self.http.get(url, timeout=ClientTimeout(), + headers=headers) as resp: + + # Loop through the response content while True: line = await resp.content.readline() - # convert to string and remove trailing newline + + # Convert to string and remove trailing newline line = line.decode("utf-8").strip() + + # Break if the line is empty + if not line: + break + + # Parse the line as JSON self.log.trace("Received notification: %s", line) message = json.loads(line) - if message["event"] != "message": + + # Check if the message is an event + if "event" not in message or message["event"] != "message": + # Check if the message is an error, else continue + if "error" in message: + raise SubscriptionError(message) continue self.log.debug("Received message event: %s", line) - # persist the received message id + + # Persist the received last event id await self.db.update_topic_id(topic.id, message["id"]) - # build matrix message + # Build matrix message html_content = self.build_message_content( topic.server, message) text_content = await parse_html(html_content.strip()) - content = TextMessageEventContent( - msgtype=MessageType.TEXT, + msgtype=MessageType.NOTICE, format=Format.HTML, formatted_body=html_content, body=text_content, ) - subscriptions = await self.db.get_subscriptions(topic.id) - for sub in subscriptions: - try: - await self.client.send_message(sub.room_id, content) - except Exception as exc: - self.log.exception( - "Failed to send matrix message!", exc_info=exc) + # Broadcast the message to all subscribed rooms + subscriptions = await self.db.get_subscriptions_by_topic_id( + topic.id) + room_ids = [sub.room_id for sub in subscriptions] + await self.broadcast_to_rooms(room_ids, content) + + async def broadcast_to_rooms(self, room_ids: List[RoomID], + content: TextMessageEventContent): + """ + Broadcast a message to multiple rooms concurrently. + :param room_ids: + :param content: + :return: + """ + self.log.info("Broadcasting message to %d rooms: %s", + len(room_ids), content.body) + + async def send(rid): + try: + await self.client.send_message(rid, content) + except Exception as e: + self.log.exception( + "Failed to send matrix message to room with id '%s'!", + rid, exc_info=e) + + async with asyncio.TaskGroup() as task_group: + for room_id in room_ids: + task_group.create_task(send(room_id)) + + async def handle_subscription_error(self, + subscription_error: SubscriptionError, + topic: Topic, + evt: MessageEvent, + callbacks: List[Awaitable] = None) \ + -> None: + """ + Handles a subscription error by removing the topic and all related + subscriptions from the database and notifying all subscribed rooms. + :param subscription_error: + :param topic: + :param evt: + :param callbacks: List of async functions to run afterward + :return: + """ + task = self.tasks[topic.id] + message = subscription_error.message + code = subscription_error.code + + self.log.error("ntfy server responded: '%s' %s", + message if message else "unknown error", + "(code %s)" % code if code else "") + + # Get affected rooms + subscriptions = await self.db.get_subscriptions_by_topic_id(topic.id) + # Filter out the current room if there is one + if evt: + room_ids = [sub.room_id for sub in subscriptions + if sub.room_id != evt.room_id] + else: + room_ids = [sub.room_id for sub in subscriptions] + + match code: + # Handle 'unauthorized' and 'limit reached: too many auth failures' + case 40101 | 42909: + + # Remove topic and all related subscriptions from database + result = await self.db.remove_topic(topic.id) + self.log.info("Removed topic: %s", topic.id) + + # Notify all subscribed rooms that the subscription has + # been cancelled + notice = await build_notice( + "Authentication failed, please check the " + "access token. Subscription to '%s/%s' has " + "been cancelled and topic has been removed." % + (escape(topic.topic), escape(topic.server))) + case _: + notice = await build_notice( + "Subscription to '%s/%s' failed: " + "'%s'" % + (escape(topic.topic), escape(topic.server), + escape(message["error"]))) + + # Notify the user that the subscription has been cancelled + try: + await evt.reply(notice) + except Exception as exc: + self.log.exception( + "Failed to send matrix message to room with id '%s'!", + evt.room_id, exc_info=exc) + # Broadcast the error message to all subscribed rooms + await self.broadcast_to_rooms(room_ids, notice) + + # Cancel the task + if task.cancel(): + self.log.info("Subscription task cancelled") + + # Run the callbacks concurrently + if callbacks: + async with asyncio.TaskGroup() as callback_tasks: + for callback in callbacks: + callback_tasks.create_task(callback) def build_message_content(self, server: str, message) -> str: + """ + Build the message content for a ntfy message. + :param server: + :param message: + :return: + """ topic = message["topic"] body = message["message"] title = message.get("title", None) @@ -198,33 +623,35 @@ class NtfyBot(Plugin): else: emoji = tags = "" - html_content = "Ntfy message in topic %s/%s
" % ( - html.escape(server), html.escape(topic)) + html_content = ("Ntfy message in topic %s/%s" + "
") % ( + escape(server), escape(topic)) # build title if title and click: html_content += "

%s%s

" % ( - emoji, html.escape(click), html.escape(title)) + emoji, escape(click), escape(title)) emoji = "" elif title: - html_content += "

%s%s

" % (emoji, html.escape(title)) + html_content += "

%s%s

" % (emoji, escape(title)) emoji = "" # build body if click and not title: html_content += "%s%s" % ( - emoji, html.escape(click), html.escape(body).replace("\n", "
")) + emoji, escape(click), + escape(body).replace("\n", "
")) else: - html_content += emoji + html.escape(body).replace("\n", "
") + html_content += emoji + escape(body).replace("\n", "
") # add non-emoji tags if tags: - html_content += "
Tags: %s" % html.escape( + html_content += "
Tags: %s" % escape( tags) # build attachment if attachment: - html_content += "
View %s" % (html.escape( - attachment["url"]), html.escape(attachment["name"])) + html_content += "
View %s" % (escape( + attachment["url"]), escape(attachment["name"])) html_content += "
" return html_content @@ -236,3 +663,32 @@ class NtfyBot(Plugin): @classmethod def get_db_upgrade_table(cls) -> UpgradeTable | None: return upgrade_table + + async def add_reaction(self, evt: MessageEvent, emoji: str): + """ + Add a reaction to a message and store the event id. + """ + self.log.debug("Adding reaction %s to event %s", + emoji, evt.event_id) + reaction_id = await evt.react(emoji) + self.reactions.update( + {evt.room_id: {evt.event_id: {emoji: reaction_id}}} + ) + + async def remove_reaction(self, evt: MessageEvent, emoji: str) -> None: + """ + Remove a reaction from a message. + :param evt: The original message event + :param emoji: The emoji sent as a reaction + """ + self.log.debug("Removing reaction %s from event %s", + emoji, evt.event_id) + try: + reaction_id = self.reactions[evt.room_id][evt.event_id][emoji] + if reaction_id: + result = await self.client.redact(evt.room_id, reaction_id) + if result: + self.reactions[evt.room_id][evt.event_id].pop(emoji) + except KeyError: + self.log.warning("Reaction %s not found in self.reactions[%s][%s]", + emoji, evt.room_id, evt.event_id) diff --git a/ntfy/db.py b/ntfy/db.py index 7b0c36b..e40579e 100644 --- a/ntfy/db.py +++ b/ntfy/db.py @@ -20,6 +20,7 @@ async def upgrade_v1(conn: Connection, scheme: Scheme) -> None: id INTEGER {gen}, server TEXT NOT NULL, topic TEXT NOT NULL, + token TEXT, last_event_id TEXT, PRIMARY KEY (id), @@ -36,13 +37,34 @@ async def upgrade_v1(conn: Connection, scheme: Scheme) -> None: )""" ) +@upgrade_table.register(description="Cascade delete of topics") +async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: + await conn.execute( + """CREATE TABLE subscriptions_new ( + topic_id INTEGER, + room_id TEXT NOT NULL, + + PRIMARY KEY (topic_id, room_id), + FOREIGN KEY (topic_id) REFERENCES topics (id) ON DELETE CASCADE + )""" + ) + await conn.execute( + """INSERT INTO subscriptions_new SELECT * FROM subscriptions""" + ) + await conn.execute( + """DROP TABLE subscriptions""" + ) + await conn.execute( + """ALTER TABLE subscriptions_new RENAME TO subscriptions""" + ) @dataclass class Topic: id: int server: str topic: str - last_event_id: str + last_event_id: str | None = None + token: str | None = None subscriptions: List[Subscription] = attr.ib(factory=lambda: []) @@ -53,11 +75,13 @@ class Topic: id = row["id"] server = row["server"] topic = row["topic"] + token = row["token"] last_event_id = row["last_event_id"] return cls( id=id, server=server, topic=topic, + token=token, last_event_id=last_event_id, subscriptions=[] ) @@ -67,6 +91,8 @@ class Topic: class Subscription: topic_id: int room_id: RoomID + server: str + topic: str @classmethod def from_row(cls, row: Record | None) -> Topic | None: @@ -74,9 +100,13 @@ class Subscription: return None topic_id = row["topic_id"] room_id = row["room_id"] + server = row["server"] + topic = row["topic"] return cls( topic_id=topic_id, room_id=room_id, + server=server, + topic=topic, ) @@ -90,7 +120,7 @@ class DB: async def get_topics(self) -> List[Topic]: query = """ - SELECT id, server, topic, last_event_id, topic_id, room_id + SELECT id, server, topic, token, last_event_id, topic_id, room_id FROM topics INNER JOIN subscriptions ON topics.id = subscriptions.topic_id @@ -119,14 +149,15 @@ class DB: async def create_topic(self, topic: Topic) -> Topic: query = """ - INSERT INTO topics (server, topic, last_event_id) - VALUES ($1, $2, $3) RETURNING (id) + INSERT INTO topics (server, topic, token, last_event_id) + VALUES ($1, $2, $3, $4) RETURNING (id) """ if self.db.scheme == Scheme.SQLITE: cur = await self.db.execute( query.replace("RETURNING (id)", ""), topic.server, topic.topic, + topic.token, topic.last_event_id, ) topic.id = cur.lastrowid @@ -135,21 +166,30 @@ class DB: query, topic.server, topic.topic, + topic.token, topic.last_event_id, ) return topic + async def remove_topic(self, topic_id: int): + query = """ + DELETE FROM topics + WHERE id = $1 + """ + return await self.db.execute(query, topic_id) + async def get_topic(self, server: str, topic: str) -> Topic | None: query = """ - SELECT id, server, topic, last_event_id + SELECT id, server, topic, token, last_event_id FROM topics WHERE server = $1 AND topic = $2 """ return Topic.from_row(await self.db.fetchrow(query, server, topic)) - async def get_subscription(self, topic_id: int, room_id: RoomID) -> Tuple[Subscription | None, Topic | None]: + async def get_subscription(self, topic_id: int, room_id: RoomID) -> Tuple[ + Subscription | None, Topic | None]: query = """ - SELECT id, server, topic, last_event_id, topic_id, room_id + SELECT id, server, topic, token, last_event_id, topic_id, room_id FROM topics INNER JOIN subscriptions ON topics.id = subscriptions.topic_id AND subscriptions.room_id = $2 @@ -158,10 +198,12 @@ class DB: row = await self.db.fetchrow(query, topic_id, room_id) return (Subscription.from_row(row), Topic.from_row(row)) - async def get_subscriptions(self, topic_id: int) -> List[Subscription]: + async def get_subscriptions_by_topic_id(self, topic_id: int) \ + -> List[Subscription]: query = """ - SELECT topic_id, room_id + SELECT topic_id, room_id, server, topic FROM subscriptions + INNER JOIN topics ON topics.id = subscriptions.topic_id WHERE topic_id = $1 """ rows = await self.db.fetch(query, topic_id) @@ -170,6 +212,20 @@ class DB: subscriptions.append(Subscription.from_row(row)) return subscriptions + async def get_subscriptions_by_room_id(self, room_id: RoomID) \ + -> List[Subscription]: + query = """ + SELECT topic_id, room_id, server, topic + FROM subscriptions + INNER JOIN topics ON topics.id = subscriptions.topic_id + WHERE room_id = $1 + """ + rows = await self.db.fetch(query, room_id) + subscriptions = [] + for row in rows: + subscriptions.append(Subscription.from_row(row)) + return subscriptions + async def add_subscription(self, topic_id: int, room_id: RoomID) -> None: query = """ INSERT INTO subscriptions (topic_id, room_id) diff --git a/ntfy/emoji.py b/ntfy/emoji.py index f938abe..f796cfe 100644 --- a/ntfy/emoji.py +++ b/ntfy/emoji.py @@ -32,6 +32,9 @@ except ImportError: EMOJI_FALLBACK = True WHITE_CHECK_MARK = emoji.emojize(":white_check_mark:", language="alias") +REPEAT = emoji.emojize(":repeat:", language="alias") +NO_ENTRY = emoji.emojize(":no_entry:", language="alias") +WARNING = emoji.emojize(":warning:", language="alias") def parse_tags(log: TraceLogger, tags: List[str]) -> Tuple[List[str], List[str]]: