diff --git a/ntfy/bot.py b/ntfy/bot.py index c5681ea..a019b02 100644 --- a/ntfy/bot.py +++ b/ntfy/bot.py @@ -1,33 +1,70 @@ import asyncio -import html +from html import escape import json -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, List, Awaitable, Callable from aiohttp import ClientTimeout from maubot import MessageEvent, Plugin from maubot.handlers import command from mautrix.types import (EventType, Format, MessageType, - TextMessageEventContent) + TextMessageEventContent, RoomID, EventID) from mautrix.util.async_db import UpgradeTable from mautrix.util.config import BaseProxyConfig from mautrix.util.formatter import parse_html from .config import Config from .db import DB, Topic, upgrade_table -from .emoji import EMOJI_FALLBACK, WHITE_CHECK_MARK, parse_tags +from .emoji import (EMOJI_FALLBACK, WHITE_CHECK_MARK, REPEAT, NO_ENTRY, + WARNING, parse_tags) +from .exceptions import SubscriptionError + + +async def build_notice(html_content) -> TextMessageEventContent: + """ + Build a notice message. + """ + text_content = await parse_html(html_content.strip()) + return TextMessageEventContent( + msgtype=MessageType.NOTICE, + format=Format.HTML, + formatted_body=html_content, + body=text_content, + ) + + +def ensure_permission(func: Callable): + """ + Decorator function to ensure that the user has the required permission + to execute a command. + """ + + async def wrapper(self, *args, **kwargs): + evt = args[0] + if evt.sender in self.config["admins"]: + return await func(self, *args, **kwargs) + levels = await self.client.get_state_event(evt.room_id, + EventType.ROOM_POWER_LEVELS) + user_level = levels.get_user_level(evt.sender) + notice = await build_notice( + "You don't have the permission to manage the ntfy bot.") + await evt.reply(notice) + return + + return wrapper class NtfyBot(Plugin): db: DB config: Config tasks: Dict[int, asyncio.Task] = {} + reactions: Dict[RoomID, Dict[str, Dict[str, EventID]]] = {} async def start(self) -> None: await super().start() self.config.load_and_update() self.db = DB(self.database, self.log) if EMOJI_FALLBACK: - self.log.warn( + self.log.warning( "Please install the `emoji` package for full emoji support") await self.resubscribe() @@ -40,26 +77,30 @@ class NtfyBot(Plugin): self.config.load_and_update() async def resubscribe(self) -> None: + """ + Clear all subscriptions and resubscribe to all topics. + :return: + """ await self.clear_subscriptions() await self.subscribe_to_topics() async def clear_subscriptions(self) -> None: - tasks = list(self.tasks.values()) - if not tasks: - return None - - for task in tasks: - if not task.done(): - self.log.debug("cancelling subscription task...") - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - except Exception as exc: - self.log.exception("Subscription task errored", exc_info=exc) - self.tasks.clear() + """ + Cancel all subscription tasks and clear the task list. + :return: + """ + for topic_id in self.tasks.keys(): + await self.cancel_topic_subscription_task(topic_id) + async def cancel_topic_subscription_task(self, topic_id: int): + """ + Cancel a topic subscription task and handle any exceptions. + :param topic_id: The topic id to cancel the subscription task for + :return: + """ + task = self.tasks.get(topic_id) + if task and not task.done(): + self.log.debug("Cancelling subscription task for topic %d", async def can_use_command(self, evt: MessageEvent) -> bool: if evt.sender in self.config["admins"]: return True @@ -70,123 +111,505 @@ class NtfyBot(Plugin): return False return True - @command.new(name=lambda self: self.config["command_prefix"], help="Manage ntfy subscriptions.", require_subcommand=True) + @command.new(name=lambda self: self.config["command_prefix"], + help="Manage ntfy subscriptions.", require_subcommand=True) async def ntfy(self) -> None: pass - @ntfy.subcommand("subscribe", aliases=("sub",), help="Subscribe this room to a ntfy topic.") - @command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") - @command.argument("token", "access token", required=False, matches="tk_[a-zA-Z0-9]{29}") - async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any], token: Tuple[str, Any]) -> None: - # see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex - if not await self.can_use_command(evt): - return None + @ntfy.subcommand("list-topics", aliases=("topics",), + help="List all ntfy topics in the database.") + @ensure_permission + async def list_topics(self, evt: MessageEvent) -> None: + """ + List all topics in the database. + :param evt: The message event + :return: None + """ + # Get all topics from the database + topics = await self.db.get_topics() + + # Build the message content + if topics: + html_content = "Currently available ntfy topics:
%s/%s%s/%s already exists." %
+ (escape(server), escape(topic)))
+ await evt.reply(notice)
+ return
+ # Else remove the topic and delete all corresponding subscriptions
+ else:
+ subscriptions = await self.db.get_subscriptions_by_topic_id(
+ db_topic.id)
+ await self.cancel_topic_subscription_task(db_topic.id)
+ await self.db.remove_topic(db_topic.id)
+ db_topic = None
+ await self.add_reaction(evt, WARNING)
+ notice = await build_notice(
+ "The topic %s/%s already exists "
+ "with a different token. The old entry with the different "
+ "token and any associated subscriptions have been removed "
+ "and would need to be re-subscribed." %
+ (escape(server), escape(topic)))
+ # Inform user about the result
+ await evt.reply(notice)
+ # Broadcast a notice to all subscribed rooms
+ broadcast_notice = await build_notice(
+ "The topic %s/%s has been updated with a new "
+ "access token. Please re-subscribe to this topic." %
+ (escape(server), escape(topic)))
+ room_ids = [sub.room_id for sub in subscriptions
+ if sub.room_id != evt.room_id]
+ await self.broadcast_to_rooms(room_ids, broadcast_notice)
+
+ # Create the topic if it doesn't exist
+ db_topic = await self.db.create_topic(
+ Topic(id=-1, server=server, topic=topic, token=token,
+ last_event_id=None))
+
+ # Mark the command as successful
+ await self.add_reaction(evt, WHITE_CHECK_MARK)
+
+ @ntfy.subcommand("remove-topic", aliases=("remove",),
+ help="Remove a topic from the database.")
+ @command.argument("topic", "topic URL",
+ matches="(([a-zA-Z0-9-]{1,63}\\.)+"
+ "[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
+ @ensure_permission
+ async def remove_topic(self, evt: MessageEvent,
+ topic: Tuple[str, Any]) -> None:
+ """
+ Remove a topic from the database.
+ :param evt: The message event
+ :param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
+ :return: None
+ """
+ # Lookup topic from db
server, topic = topic[0].split("/")
db_topic = await self.db.get_topic(server, topic)
+
+ # Delete the topic if it exists and cancel all corresponding
+ # subscriptions
+ if db_topic:
+ subscriptions = await self.db.get_subscriptions_by_topic_id(
+ db_topic.id)
+ await self.cancel_topic_subscription_task(db_topic.id)
+ await self.db.remove_topic(db_topic.id)
+ await self.add_reaction(evt, WHITE_CHECK_MARK)
+ notice = await build_notice(
+ "The topic %s/%s has been deleted from the "
+ "database and all associated subscriptions have been "
+ "cancelled." % (escape(server), escape(topic)))
+ # Broadcast a notice to all subscribed rooms
+ broadcast_notice = await build_notice(
+ "The topic %s/%s has been removed from the "
+ "database. The room has been unsubscribed from this topic." %
+ (escape(server), escape(topic)))
+ room_ids = [sub.room_id for sub in subscriptions
+ if sub.room_id != evt.room_id]
+ await self.broadcast_to_rooms(room_ids, broadcast_notice)
+
+ # Inform the user if the topic doesn't exist
+ else:
+ await self.add_reaction(evt, NO_ENTRY)
+ notice = await build_notice(
+ "The topic %s/%s does not exist in the database." %
+ (escape(server), escape(topic)))
+
+ # Inform user about the result
+ await evt.reply(notice)
+
+ @ntfy.subcommand("list-subscriptions", aliases=("subscriptions",),
+ help="List all active subscriptions for this room.")
+ @ensure_permission
+ async def list_subscriptions(self, evt: MessageEvent) -> None:
+ """
+ List all active subscriptions for this room.
+ :param evt: The message event
+ :return: None
+ """
+ # Get all subscriptions for this room from the database
+ subscriptions = await self.db.get_subscriptions_by_room_id(evt.room_id)
+
+ # Build the message content
+ html_content = ("This room is currently subscribed to these "
+ "topics:%s/%s%s/%s does not exist in the"
+ "database. Please add it first: %s add-subscription "
+ "\\ [access token] " %
+ (escape(server), escape(topic), self.config["command_prefix"]))
+ await evt.reply(notice)
+ return
+
+ # Check if the room is already subscribed to the topic
+ existing_subscriptions = await self.db.get_subscriptions_by_topic_id(
+ db_topic.id)
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
if sub:
- await evt.reply("This room is already subscribed to %s/%s", server, topic)
+ notice = await build_notice(
+ "This room is already subscribed to %s/%s." %
+ (escape(server), escape(topic)))
+ await evt.reply(notice)
+
+ # Subscribe the room to the topic if it isn't already subscribed
else:
await self.db.add_subscription(db_topic.id, evt.room_id)
- await evt.reply("Subscribed this room to %s/%s", server, topic)
- await evt.react(WHITE_CHECK_MARK)
if not existing_subscriptions:
- self.subscribe_to_topic(db_topic)
+ self.subscribe_to_topic(db_topic, evt)
- @ntfy.subcommand("unsubscribe", aliases=("unsub",), help="Unsubscribe this room from a ntfy topic.")
- @command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
- async def unsubscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None:
- # see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex
- if not await self.can_use_command(evt):
- return None
+ # Mark the command as successful
+ await self.add_reaction(evt, WHITE_CHECK_MARK)
+
+ @ntfy.subcommand("unsubscribe", aliases=("unsub",),
+ help="Unsubscribe this room from a ntfy topic.")
+ @command.argument("topic", "topic URL",
+ matches="(([a-zA-Z0-9-]{1,63}\\.)+"
+ "[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
+ @ensure_permission
+ async def unsubscribe(self, evt: MessageEvent,
+ topic: Tuple[str, Any]) -> None:
+ """
+ Unsubscribe this room from a ntfy topic.
+ See https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex.
+ :param evt:
+ :param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
+ :return:
+ """
+ # Check if the topic exists
server, topic = topic[0].split("/")
db_topic = await self.db.get_topic(server, topic)
+ notice = await build_notice(
+ "This room is not subscribed to %s/%s." %
+ (escape(db_topic.server), escape(db_topic.topic)))
if not db_topic:
- await evt.reply("This room is not subscribed to %s/%s", server, topic)
+ await evt.reply(notice)
return
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
if not sub:
- await evt.reply("This room is not subscribed to %s/%s", server, topic)
+ await evt.reply(notice)
return
+
+ subscriptions = await self.db.get_subscriptions_by_topic_id(db_topic.id)
+
+ # Unsubscribe the room from the topic
await self.db.remove_subscription(db_topic.id, evt.room_id)
- if not await self.db.get_subscriptions(db_topic.id):
- self.tasks[db_topic.id].cancel()
+
+ # Cancel the subscription task if there are no more subscriptions
+ if len(subscriptions) <= 1:
+ await self.cancel_topic_subscription_task(db_topic.id)
+
+ # Clear the last event id if there are no more subscriptions
+ if not await self.db.get_subscriptions_by_topic_id(db_topic.id):
await self.db.clear_topic_id(db_topic.id)
- await evt.reply("Unsubscribed this room from %s/%s", server, topic)
- await evt.react(WHITE_CHECK_MARK)
+
+ # Notify the user that the room has been unsubscribed
+ notice = await build_notice(
+ "Unsubscribed this room from %s/%s." %
+ (escape(db_topic.server), escape(db_topic.topic)))
+ await evt.reply(notice)
+ await self.add_reaction(evt, WHITE_CHECK_MARK)
async def subscribe_to_topics(self) -> None:
+ """
+ Subscribe to all topics in the database.
+ :return:
+ """
topics = await self.db.get_topics()
for topic in topics:
self.subscribe_to_topic(topic)
- def subscribe_to_topic(self, topic: Topic) -> None:
- def log_task_exc(task: asyncio.Task) -> None:
- t2 = self.tasks.pop(topic.id, None)
- if t2 != task:
- self.log.warn("stored task doesn't match callback")
- if task.done() and not task.cancelled():
- exc = task.exception()
- self.log.exception(
- "Subscription task errored, resubscribing", exc_info=exc)
- self.tasks[topic.id] = self.loop.create_task(
- asyncio.sleep(10.0))
- self.tasks[topic.id].add_done_callback(
- lambda _: self.subscribe_to_topic(topic))
+ def subscribe_to_topic(self, topic: Topic,
+ evt: MessageEvent | None = None) -> None:
+ """
+ Subscribe to a topic.
+ :param topic:
+ :param evt:
+ :return:
+ """
- self.log.info("Subscribing to %s/%s", topic.server, topic.topic)
+ # Create a callback to log the exception if the task fails
+ def log_task_exc(t: asyncio.Task) -> None:
+ t2 = self.tasks.pop(topic.id, None)
+ if t2 != t:
+ self.log.warning("Stored task doesn't match callback")
+ # Log the exception if the task failed
+ if t.done() and not t.cancelled():
+ exc = t.exception()
+ # Handle subscription errors
+ if isinstance(exc, SubscriptionError):
+ self.log.exception("Failed to subscribe to %s/%s",
+ topic.server, topic.topic)
+ self.tasks[topic.id] = self.loop.create_task(
+ self.handle_subscription_error(exc, topic, evt, [
+ self.remove_reaction(evt, WHITE_CHECK_MARK),
+ self.add_reaction(evt, NO_ENTRY),
+ ])
+ )
+ else:
+ # Try to resubscribe if the task errored
+ if evt:
+ self.loop.create_task(self.add_reaction(evt, REPEAT))
+ self.log.exception(
+ "Subscription task errored, resubscribing",
+ exc_info=exc)
+ # Sleep for 10 seconds before resubscribing
+ self.tasks[topic.id] = self.loop.create_task(
+ asyncio.sleep(10.0))
+ # Subscribe to the topic again
+ self.tasks[topic.id].add_done_callback(
+ lambda _: self.subscribe_to_topic(topic, evt))
+
+ # Prepare the URL for the topic
+ self.log.info("Subscribing to %s/%s" % (topic.server, topic.topic))
url = "%s/%s/json" % (topic.server, topic.topic)
+ # Prepend the URL with `https://` if necessary
if not url.startswith(("http://", "https://")):
url = "https://" + url
+ # Add the last event id to the URL if it exists
if topic.last_event_id:
url += "?since=%s" % topic.last_event_id
+ # Subscribe to the topic
self.log.debug("Subscribing to URL %s", url)
task = self.loop.create_task(
- self.run_topic_subscription(topic, url))
+ self.run_topic_subscription(topic, url, evt))
self.tasks[topic.id] = task
task.add_done_callback(log_task_exc)
- async def run_topic_subscription(self, topic: Topic, url: str) -> None:
- headers = {"Authorization": f"Bearer {topic.token}"}
+ async def run_topic_subscription(self, topic: Topic, url: str,
+ evt: MessageEvent | None = None) -> None:
+ """
+ Subscribe to a topic.
+ :param topic:
+ :param url:
+ :param evt:
+ :return:
+ """
+ # Prepare authentication headers in case a token is provided
+ headers = {
+ "Authorization": f"Bearer {topic.token}"} if topic.token else None
+
+ # Subscribe to the topic
async with self.http.get(url, timeout=ClientTimeout(),
- headers=headers if topic.token else None) as resp:
+ headers=headers) as resp:
+
+ # Loop through the response content
while True:
line = await resp.content.readline()
- # convert to string and remove trailing newline
+
+ # Convert to string and remove trailing newline
line = line.decode("utf-8").strip()
+
+ # Break if the line is empty
+ if not line:
+ break
+
+ # Parse the line as JSON
self.log.trace("Received notification: %s", line)
message = json.loads(line)
- if message["event"] != "message":
+
+ # Check if the message is an event
+ if "event" not in message or message["event"] != "message":
+ # Check if the message is an error, else continue
+ if "error" in message:
+ raise SubscriptionError(message)
continue
self.log.debug("Received message event: %s", line)
- # persist the received message id
+
+ # Persist the received last event id
await self.db.update_topic_id(topic.id, message["id"])
- # build matrix message
+ # Build matrix message
html_content = self.build_message_content(
topic.server, message)
text_content = await parse_html(html_content.strip())
-
content = TextMessageEventContent(
- msgtype=MessageType.TEXT,
+ msgtype=MessageType.NOTICE,
format=Format.HTML,
formatted_body=html_content,
body=text_content,
)
- subscriptions = await self.db.get_subscriptions(topic.id)
- for sub in subscriptions:
- try:
- await self.client.send_message(sub.room_id, content)
- except Exception as exc:
- self.log.exception(
- "Failed to send matrix message!", exc_info=exc)
+ # Broadcast the message to all subscribed rooms
+ subscriptions = await self.db.get_subscriptions_by_topic_id(
+ topic.id)
+ room_ids = [sub.room_id for sub in subscriptions]
+ await self.broadcast_to_rooms(room_ids, content)
+
+ async def broadcast_to_rooms(self, room_ids: List[RoomID],
+ content: TextMessageEventContent):
+ """
+ Broadcast a message to multiple rooms concurrently.
+ :param room_ids:
+ :param content:
+ :return:
+ """
+ self.log.info("Broadcasting message to %d rooms: %s",
+ len(room_ids), content.body)
+
+ async def send(rid):
+ try:
+ await self.client.send_message(rid, content)
+ except Exception as e:
+ self.log.exception(
+ "Failed to send matrix message to room with id '%s'!",
+ rid, exc_info=e)
+
+ async with asyncio.TaskGroup() as task_group:
+ for room_id in room_ids:
+ task_group.create_task(send(room_id))
+
+ async def handle_subscription_error(self,
+ subscription_error: SubscriptionError,
+ topic: Topic,
+ evt: MessageEvent,
+ callbacks: List[Awaitable] = None) \
+ -> None:
+ """
+ Handles a subscription error by removing the topic and all related
+ subscriptions from the database and notifying all subscribed rooms.
+ :param subscription_error:
+ :param topic:
+ :param evt:
+ :param callbacks: List of async functions to run afterward
+ :return:
+ """
+ task = self.tasks[topic.id]
+ message = subscription_error.message
+ code = subscription_error.code
+
+ self.log.error("ntfy server responded: '%s' %s",
+ message if message else "unknown error",
+ "(code %s)" % code if code else "")
+
+ # Get affected rooms
+ subscriptions = await self.db.get_subscriptions_by_topic_id(topic.id)
+ # Filter out the current room if there is one
+ if evt:
+ room_ids = [sub.room_id for sub in subscriptions
+ if sub.room_id != evt.room_id]
+ else:
+ room_ids = [sub.room_id for sub in subscriptions]
+
+ match code:
+ # Handle 'unauthorized' and 'limit reached: too many auth failures'
+ case 40101 | 42909:
+
+ # Remove topic and all related subscriptions from database
+ result = await self.db.remove_topic(topic.id)
+ self.log.info("Removed topic: %s", topic.id)
+
+ # Notify all subscribed rooms that the subscription has
+ # been cancelled
+ notice = await build_notice(
+ "Authentication failed, please check the "
+ "access token. Subscription to '%s/%s' has "
+ "been cancelled and topic has been removed." %
+ (escape(topic.topic), escape(topic.server)))
+ case _:
+ notice = await build_notice(
+ "Subscription to '%s/%s' failed: "
+ "'%s'" %
+ (escape(topic.topic), escape(topic.server),
+ escape(message["error"])))
+
+ # Notify the user that the subscription has been cancelled
+ try:
+ await evt.reply(notice)
+ except Exception as exc:
+ self.log.exception(
+ "Failed to send matrix message to room with id '%s'!",
+ evt.room_id, exc_info=exc)
+ # Broadcast the error message to all subscribed rooms
+ await self.broadcast_to_rooms(room_ids, notice)
+
+ # Cancel the task
+ if task.cancel():
+ self.log.info("Subscription task cancelled")
+
+ # Run the callbacks concurrently
+ if callbacks:
+ async with asyncio.TaskGroup() as callback_tasks:
+ for callback in callbacks:
+ callback_tasks.create_task(callback)
def build_message_content(self, server: str, message) -> str:
+ """
+ Build the message content for a ntfy message.
+ :param server:
+ :param message:
+ :return:
+ """
topic = message["topic"]
body = message["message"]
title = message.get("title", None)
@@ -201,33 +624,35 @@ class NtfyBot(Plugin):
else:
emoji = tags = ""
- html_content = "Ntfy message in topic %s/%s" % ( - html.escape(server), html.escape(topic)) + html_content = ("Ntfy message in topic%s/%s" + "") % ( + escape(server), escape(topic)) # build title if title and click: html_content += "" return html_content @@ -239,3 +664,32 @@ class NtfyBot(Plugin): @classmethod def get_db_upgrade_table(cls) -> UpgradeTable | None: return upgrade_table + + async def add_reaction(self, evt: MessageEvent, emoji: str): + """ + Add a reaction to a message and store the event id. + """ + self.log.debug("Adding reaction %s to event %s", + emoji, evt.event_id) + reaction_id = await evt.react(emoji) + self.reactions.update( + {evt.room_id: {evt.event_id: {emoji: reaction_id}}} + ) + + async def remove_reaction(self, evt: MessageEvent, emoji: str) -> None: + """ + Remove a reaction from a message. + :param evt: The original message event + :param emoji: The emoji sent as a reaction + """ + self.log.debug("Removing reaction %s from event %s", + emoji, evt.event_id) + try: + reaction_id = self.reactions[evt.room_id][evt.event_id][emoji] + if reaction_id: + result = await self.client.redact(evt.room_id, reaction_id) + if result: + self.reactions[evt.room_id][evt.event_id].pop(emoji) + except KeyError: + self.log.warning("Reaction %s not found in self.reactions[%s][%s]", + emoji, evt.room_id, evt.event_id) diff --git a/ntfy/db.py b/ntfy/db.py index 7a353c1..e40579e 100644 --- a/ntfy/db.py +++ b/ntfy/db.py @@ -37,13 +37,33 @@ async def upgrade_v1(conn: Connection, scheme: Scheme) -> None: )""" ) +@upgrade_table.register(description="Cascade delete of topics") +async def upgrade_v2(conn: Connection, scheme: Scheme) -> None: + await conn.execute( + """CREATE TABLE subscriptions_new ( + topic_id INTEGER, + room_id TEXT NOT NULL, + + PRIMARY KEY (topic_id, room_id), + FOREIGN KEY (topic_id) REFERENCES topics (id) ON DELETE CASCADE + )""" + ) + await conn.execute( + """INSERT INTO subscriptions_new SELECT * FROM subscriptions""" + ) + await conn.execute( + """DROP TABLE subscriptions""" + ) + await conn.execute( + """ALTER TABLE subscriptions_new RENAME TO subscriptions""" + ) @dataclass class Topic: id: int server: str topic: str - last_event_id: str + last_event_id: str | None = None token: str | None = None subscriptions: List[Subscription] = attr.ib(factory=lambda: []) @@ -71,6 +91,8 @@ class Topic: class Subscription: topic_id: int room_id: RoomID + server: str + topic: str @classmethod def from_row(cls, row: Record | None) -> Topic | None: @@ -78,9 +100,13 @@ class Subscription: return None topic_id = row["topic_id"] room_id = row["room_id"] + server = row["server"] + topic = row["topic"] return cls( topic_id=topic_id, room_id=room_id, + server=server, + topic=topic, ) @@ -145,6 +171,13 @@ class DB: ) return topic + async def remove_topic(self, topic_id: int): + query = """ + DELETE FROM topics + WHERE id = $1 + """ + return await self.db.execute(query, topic_id) + async def get_topic(self, server: str, topic: str) -> Topic | None: query = """ SELECT id, server, topic, token, last_event_id @@ -165,10 +198,12 @@ class DB: row = await self.db.fetchrow(query, topic_id, room_id) return (Subscription.from_row(row), Topic.from_row(row)) - async def get_subscriptions(self, topic_id: int) -> List[Subscription]: + async def get_subscriptions_by_topic_id(self, topic_id: int) \ + -> List[Subscription]: query = """ - SELECT topic_id, room_id + SELECT topic_id, room_id, server, topic FROM subscriptions + INNER JOIN topics ON topics.id = subscriptions.topic_id WHERE topic_id = $1 """ rows = await self.db.fetch(query, topic_id) @@ -177,6 +212,20 @@ class DB: subscriptions.append(Subscription.from_row(row)) return subscriptions + async def get_subscriptions_by_room_id(self, room_id: RoomID) \ + -> List[Subscription]: + query = """ + SELECT topic_id, room_id, server, topic + FROM subscriptions + INNER JOIN topics ON topics.id = subscriptions.topic_id + WHERE room_id = $1 + """ + rows = await self.db.fetch(query, room_id) + subscriptions = [] + for row in rows: + subscriptions.append(Subscription.from_row(row)) + return subscriptions + async def add_subscription(self, topic_id: int, room_id: RoomID) -> None: query = """ INSERT INTO subscriptions (topic_id, room_id) diff --git a/ntfy/emoji.py b/ntfy/emoji.py index f938abe..f796cfe 100644 --- a/ntfy/emoji.py +++ b/ntfy/emoji.py @@ -32,6 +32,9 @@ except ImportError: EMOJI_FALLBACK = True WHITE_CHECK_MARK = emoji.emojize(":white_check_mark:", language="alias") +REPEAT = emoji.emojize(":repeat:", language="alias") +NO_ENTRY = emoji.emojize(":no_entry:", language="alias") +WARNING = emoji.emojize(":warning:", language="alias") def parse_tags(log: TraceLogger, tags: List[str]) -> Tuple[List[str], List[str]]:%s%s
" % ( - emoji, html.escape(click), html.escape(title)) + emoji, escape(click), escape(title)) emoji = "" elif title: - html_content += "%s%s
" % (emoji, html.escape(title)) + html_content += "%s%s
" % (emoji, escape(title)) emoji = "" # build body if click and not title: html_content += "%s%s" % ( - emoji, html.escape(click), html.escape(body).replace("\n", "
")) + emoji, escape(click), + escape(body).replace("\n", "
")) else: - html_content += emoji + html.escape(body).replace("\n", "
") + html_content += emoji + escape(body).replace("\n", "
") # add non-emoji tags if tags: - html_content += "
Tags:%s" % html.escape( + html_content += "
Tags:%s" % escape( tags) # build attachment if attachment: - html_content += "
View %s" % (html.escape( - attachment["url"]), html.escape(attachment["name"])) + html_content += "
View %s" % (escape( + attachment["url"]), escape(attachment["name"])) html_content += "