maubot-ntfy/ntfy/bot.py

696 lines
28 KiB
Python

import asyncio
import functools
import json
from html import escape
from typing import Any, Awaitable, Callable, Dict, List, Tuple

from aiohttp import ClientTimeout
from maubot import MessageEvent, Plugin
from maubot.handlers import command
from mautrix.types import (EventID, EventType, Format, MessageType, RoomID,
                           TextMessageEventContent)
from mautrix.util.async_db import UpgradeTable
from mautrix.util.config import BaseProxyConfig
from mautrix.util.formatter import parse_html

from .config import Config
from .db import DB, Topic, upgrade_table
from .emoji import (EMOJI_FALLBACK, NO_ENTRY, REPEAT, WARNING,
                    WHITE_CHECK_MARK, parse_tags)
from .exceptions import SubscriptionError
async def build_notice(html_content) -> TextMessageEventContent:
    """
    Wrap an HTML snippet in a notice-type message event content.

    The HTML is sent unchanged as ``formatted_body``; the plain-text
    ``body`` fallback is derived from it after trimming surrounding
    whitespace.

    :param html_content: The HTML to send
    :return: The assembled notice content
    """
    plain_body = await parse_html(html_content.strip())
    notice = TextMessageEventContent(
        msgtype=MessageType.NOTICE,
        format=Format.HTML,
        formatted_body=html_content,
        body=plain_body,
    )
    return notice
def ensure_permission(func: Callable):
    """
    Decorator ensuring the sender of a command may manage the ntfy bot.

    The sender is allowed if they are listed in the ``admins`` config option
    or if their power level in the room is at least 50. Anyone else gets a
    notice saying they lack permission, and the wrapped handler is not
    called.

    :param func: The command handler to guard. It is called as
        ``func(self, evt, *args, **kwargs)`` with the message event first.
    :return: The guarded handler, carrying ``func``'s name and docstring
    """
    required_level = 50

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        # Guarded handlers receive the triggering event as first argument
        evt = args[0]
        if evt.sender in self.config["admins"]:
            return await func(self, *args, **kwargs)
        levels = await self.client.get_state_event(
            evt.room_id, EventType.ROOM_POWER_LEVELS)
        # Non-admins are allowed with a sufficient power level in this room
        if levels.get_user_level(evt.sender) >= required_level:
            return await func(self, *args, **kwargs)
        notice = await build_notice(
            "You don't have the permission to manage the ntfy bot.")
        await evt.reply(notice)
        return None

    return wrapper
class NtfyBot(Plugin):
db: DB
config: Config
tasks: Dict[int, asyncio.Task] = {}
reactions: Dict[RoomID, Dict[str, Dict[str, EventID]]] = {}
async def start(self) -> None:
await super().start()
self.config.load_and_update()
self.db = DB(self.database, self.log)
if EMOJI_FALLBACK:
self.log.warning(
"Please install the `emoji` package for full emoji support")
await self.resubscribe()
async def stop(self) -> None:
await super().stop()
await self.clear_subscriptions()
async def on_external_config_update(self) -> None:
self.log.info("Refreshing configuration")
self.config.load_and_update()
async def resubscribe(self) -> None:
"""
Clear all subscriptions and resubscribe to all topics.
:return:
"""
await self.clear_subscriptions()
await self.subscribe_to_topics()
async def clear_subscriptions(self) -> None:
"""
Cancel all subscription tasks and clear the task list.
:return:
"""
for topic_id in self.tasks.keys():
await self.cancel_topic_subscription_task(topic_id)
async def cancel_topic_subscription_task(self, topic_id: int):
"""
Cancel a topic subscription task and handle any exceptions.
:param topic_id: The topic id to cancel the subscription task for
:return:
"""
task = self.tasks.get(topic_id)
if task and not task.done():
self.log.debug("Cancelling subscription task for topic %d",
async def can_use_command(self, evt: MessageEvent) -> bool:
if evt.sender in self.config["admins"]:
return True
levels = await self.client.get_state_event(evt.room_id, EventType.ROOM_POWER_LEVELS)
user_level = levels.get_user_level(evt.sender)
if user_level < 50:
await evt.reply("You don't have the permission to manage ntfy subscriptions in this room.")
return False
return True
@command.new(name=lambda self: self.config["command_prefix"],
help="Manage ntfy subscriptions.", require_subcommand=True)
async def ntfy(self) -> None:
pass
@ntfy.subcommand("list-topics", aliases=("topics",),
help="List all ntfy topics in the database.")
@ensure_permission
async def list_topics(self, evt: MessageEvent) -> None:
"""
List all topics in the database.
:param evt: The message event
:return: None
"""
# Get all topics from the database
topics = await self.db.get_topics()
# Build the message content
if topics:
html_content = "<span>Currently available ntfy topics:</span><ul>"
for topic in topics:
html_content += "<li><code>%s/%s</code></li>" % (
escape(topic.server), escape(topic.topic))
html_content += "</ul>"
else:
html_content = "<span>No topics available yet.</span>"
# Send the message to the user
content = await build_notice(html_content)
await evt.reply(content)
@ntfy.subcommand("add-topic", aliases=("add",),
help="Add a new topic to the database.")
@command.argument("topic", "topic URL",
matches="(([a-zA-Z0-9-]{1,63}\\.)+"
"[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
@command.argument("token", "access token", required=False,
matches="tk_[a-zA-Z0-9]{29}")
@ensure_permission
async def add_topic(self, evt: MessageEvent, topic: Tuple[str, Any],
token: Tuple[str, Any]) -> None:
"""
Add a new topic to the database.
:param evt: The message event
:param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
:param token: The access token for the topic
(e.g. `tk_be6uc2ca1x1orakcwg1j3hp0ylot6`)
:return: None
"""
# Look up the topic in the database
server, topic = topic[0].split("/")
token = token if token else None
db_topic = await self.db.get_topic(server, topic)
# Check if the topic already exists
if db_topic:
# If it already exists, check if the token is the same
if db_topic.token == token:
await self.add_reaction(evt, NO_ENTRY)
notice = await build_notice(
"The topic <code>%s/%s</code> already exists." %
(escape(server), escape(topic)))
await evt.reply(notice)
return
# Else remove the topic and delete all corresponding subscriptions
else:
subscriptions = await self.db.get_subscriptions_by_topic_id(
db_topic.id)
await self.cancel_topic_subscription_task(db_topic.id)
await self.db.remove_topic(db_topic.id)
db_topic = None
await self.add_reaction(evt, WARNING)
notice = await build_notice(
"The topic <code>%s/%s</code> already exists "
"with a different token. The old entry with the different "
"token and any associated subscriptions have been removed "
"and would need to be re-subscribed." %
(escape(server), escape(topic)))
# Inform user about the result
await evt.reply(notice)
# Broadcast a notice to all subscribed rooms
broadcast_notice = await build_notice(
"The topic <code>%s/%s</code> has been updated with a new "
"access token. Please re-subscribe to this topic." %
(escape(server), escape(topic)))
room_ids = [sub.room_id for sub in subscriptions
if sub.room_id != evt.room_id]
await self.broadcast_to_rooms(room_ids, broadcast_notice)
# Create the topic if it doesn't exist
db_topic = await self.db.create_topic(
Topic(id=-1, server=server, topic=topic, token=token,
last_event_id=None))
# Mark the command as successful
await self.add_reaction(evt, WHITE_CHECK_MARK)
@ntfy.subcommand("remove-topic", aliases=("remove",),
help="Remove a topic from the database.")
@command.argument("topic", "topic URL",
matches="(([a-zA-Z0-9-]{1,63}\\.)+"
"[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
@ensure_permission
async def remove_topic(self, evt: MessageEvent,
topic: Tuple[str, Any]) -> None:
"""
Remove a topic from the database.
:param evt: The message event
:param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
:return: None
"""
# Lookup topic from db
server, topic = topic[0].split("/")
db_topic = await self.db.get_topic(server, topic)
# Delete the topic if it exists and cancel all corresponding
# subscriptions
if db_topic:
subscriptions = await self.db.get_subscriptions_by_topic_id(
db_topic.id)
await self.cancel_topic_subscription_task(db_topic.id)
await self.db.remove_topic(db_topic.id)
await self.add_reaction(evt, WHITE_CHECK_MARK)
notice = await build_notice(
"The topic <code>%s/%s</code> has been deleted from the "
"database and all associated subscriptions have been "
"cancelled." % (escape(server), escape(topic)))
# Broadcast a notice to all subscribed rooms
broadcast_notice = await build_notice(
"The topic <code>%s/%s</code> has been removed from the "
"database. The room has been unsubscribed from this topic." %
(escape(server), escape(topic)))
room_ids = [sub.room_id for sub in subscriptions
if sub.room_id != evt.room_id]
await self.broadcast_to_rooms(room_ids, broadcast_notice)
# Inform the user if the topic doesn't exist
else:
await self.add_reaction(evt, NO_ENTRY)
notice = await build_notice(
"The topic <code>%s/%s</code> does not exist in the database." %
(escape(server), escape(topic)))
# Inform user about the result
await evt.reply(notice)
@ntfy.subcommand("list-subscriptions", aliases=("subscriptions",),
help="List all active subscriptions for this room.")
@ensure_permission
async def list_subscriptions(self, evt: MessageEvent) -> None:
"""
List all active subscriptions for this room.
:param evt: The message event
:return: None
"""
# Get all subscriptions for this room from the database
subscriptions = await self.db.get_subscriptions_by_room_id(evt.room_id)
# Build the message content
html_content = ("<span>This room is currently subscribed to these "
"topics:</span><ul>")
for sub in subscriptions:
html_content += "<li><code>%s/%s</code></li>" % (
escape(sub.server), escape(sub.topic))
html_content += "</ul>"
# Send the message to the user
content = await build_notice(html_content)
await evt.reply(content)
@ntfy.subcommand("subscribe", aliases=("sub",),
help="Subscribe this room to a ntfy topic.")
@command.argument("topic", "topic URL",
matches="(([a-zA-Z0-9-]{1,63}\\.)+"
"[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
@ensure_permission
async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any]):
"""
Subscribe this room to a ntfy topic.
See https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61
for the valid topic regex.
:param evt:
:param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
:return: None
"""
# Check if the topic already exists
server, topic = topic[0].split("/")
db_topic = await self.db.get_topic(server, topic)
# If it doesn't exist, tell user to add it first
if not db_topic:
await self.add_reaction(evt, NO_ENTRY)
notice = await build_notice(
"The topic <code>%s/%s</code> does not exist in the"
"database. Please add it first: <code>%s add-subscription "
"\\<topic URL\\> [access token]</code>" %
(escape(server), escape(topic), self.config["command_prefix"]))
await evt.reply(notice)
return
# Check if the room is already subscribed to the topic
existing_subscriptions = await self.db.get_subscriptions_by_topic_id(
db_topic.id)
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
if sub:
notice = await build_notice(
"This room is already subscribed to <code>%s/%s</code>." %
(escape(server), escape(topic)))
await evt.reply(notice)
# Subscribe the room to the topic if it isn't already subscribed
else:
await self.db.add_subscription(db_topic.id, evt.room_id)
if not existing_subscriptions:
self.subscribe_to_topic(db_topic, evt)
# Mark the command as successful
await self.add_reaction(evt, WHITE_CHECK_MARK)
@ntfy.subcommand("unsubscribe", aliases=("unsub",),
help="Unsubscribe this room from a ntfy topic.")
@command.argument("topic", "topic URL",
matches="(([a-zA-Z0-9-]{1,63}\\.)+"
"[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
@ensure_permission
async def unsubscribe(self, evt: MessageEvent,
topic: Tuple[str, Any]) -> None:
"""
Unsubscribe this room from a ntfy topic.
See https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex.
:param evt:
:param topic: The topic URL (e.g. `ntfy.sh/my_topic`)
:return:
"""
# Check if the topic exists
server, topic = topic[0].split("/")
db_topic = await self.db.get_topic(server, topic)
notice = await build_notice(
"This room is not subscribed to <code>%s/%s</code>." %
(escape(db_topic.server), escape(db_topic.topic)))
if not db_topic:
await evt.reply(notice)
return
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
if not sub:
await evt.reply(notice)
return
subscriptions = await self.db.get_subscriptions_by_topic_id(db_topic.id)
# Unsubscribe the room from the topic
await self.db.remove_subscription(db_topic.id, evt.room_id)
# Cancel the subscription task if there are no more subscriptions
if len(subscriptions) <= 1:
await self.cancel_topic_subscription_task(db_topic.id)
# Clear the last event id if there are no more subscriptions
if not await self.db.get_subscriptions_by_topic_id(db_topic.id):
await self.db.clear_topic_id(db_topic.id)
# Notify the user that the room has been unsubscribed
notice = await build_notice(
"Unsubscribed this room from <code>%s/%s</code>." %
(escape(db_topic.server), escape(db_topic.topic)))
await evt.reply(notice)
await self.add_reaction(evt, WHITE_CHECK_MARK)
async def subscribe_to_topics(self) -> None:
"""
Subscribe to all topics in the database.
:return:
"""
topics = await self.db.get_topics()
for topic in topics:
self.subscribe_to_topic(topic)
def subscribe_to_topic(self, topic: Topic,
evt: MessageEvent | None = None) -> None:
"""
Subscribe to a topic.
:param topic:
:param evt:
:return:
"""
# Create a callback to log the exception if the task fails
def log_task_exc(t: asyncio.Task) -> None:
t2 = self.tasks.pop(topic.id, None)
if t2 != t:
self.log.warning("Stored task doesn't match callback")
# Log the exception if the task failed
if t.done() and not t.cancelled():
exc = t.exception()
# Handle subscription errors
if isinstance(exc, SubscriptionError):
self.log.exception("Failed to subscribe to %s/%s",
topic.server, topic.topic)
self.tasks[topic.id] = self.loop.create_task(
self.handle_subscription_error(exc, topic, evt, [
self.remove_reaction(evt, WHITE_CHECK_MARK),
self.add_reaction(evt, NO_ENTRY),
])
)
else:
# Try to resubscribe if the task errored
if evt:
self.loop.create_task(self.add_reaction(evt, REPEAT))
self.log.exception(
"Subscription task errored, resubscribing",
exc_info=exc)
# Sleep for 10 seconds before resubscribing
self.tasks[topic.id] = self.loop.create_task(
asyncio.sleep(10.0))
# Subscribe to the topic again
self.tasks[topic.id].add_done_callback(
lambda _: self.subscribe_to_topic(topic, evt))
# Prepare the URL for the topic
self.log.info("Subscribing to %s/%s" % (topic.server, topic.topic))
url = "%s/%s/json" % (topic.server, topic.topic)
# Prepend the URL with `https://` if necessary
if not url.startswith(("http://", "https://")):
url = "https://" + url
# Add the last event id to the URL if it exists
if topic.last_event_id:
url += "?since=%s" % topic.last_event_id
# Subscribe to the topic
self.log.debug("Subscribing to URL %s", url)
task = self.loop.create_task(
self.run_topic_subscription(topic, url, evt))
self.tasks[topic.id] = task
task.add_done_callback(log_task_exc)
async def run_topic_subscription(self, topic: Topic, url: str,
evt: MessageEvent | None = None) -> None:
"""
Subscribe to a topic.
:param topic:
:param url:
:param evt:
:return:
"""
# Prepare authentication headers in case a token is provided
headers = {
"Authorization": f"Bearer {topic.token}"} if topic.token else None
# Subscribe to the topic
async with self.http.get(url, timeout=ClientTimeout(),
headers=headers) as resp:
# Loop through the response content
while True:
line = await resp.content.readline()
# Convert to string and remove trailing newline
line = line.decode("utf-8").strip()
# Break if the line is empty
if not line:
break
# Parse the line as JSON
self.log.trace("Received notification: %s", line)
message = json.loads(line)
# Check if the message is an event
if "event" not in message or message["event"] != "message":
# Check if the message is an error, else continue
if "error" in message:
raise SubscriptionError(message)
continue
self.log.debug("Received message event: %s", line)
# Persist the received last event id
await self.db.update_topic_id(topic.id, message["id"])
# Build matrix message
html_content = self.build_message_content(
topic.server, message)
text_content = await parse_html(html_content.strip())
content = TextMessageEventContent(
msgtype=MessageType.NOTICE,
format=Format.HTML,
formatted_body=html_content,
body=text_content,
)
# Broadcast the message to all subscribed rooms
subscriptions = await self.db.get_subscriptions_by_topic_id(
topic.id)
room_ids = [sub.room_id for sub in subscriptions]
await self.broadcast_to_rooms(room_ids, content)
async def broadcast_to_rooms(self, room_ids: List[RoomID],
content: TextMessageEventContent):
"""
Broadcast a message to multiple rooms concurrently.
:param room_ids:
:param content:
:return:
"""
self.log.info("Broadcasting message to %d rooms: %s",
len(room_ids), content.body)
async def send(rid):
try:
await self.client.send_message(rid, content)
except Exception as e:
self.log.exception(
"Failed to send matrix message to room with id '%s'!",
rid, exc_info=e)
async with asyncio.TaskGroup() as task_group:
for room_id in room_ids:
task_group.create_task(send(room_id))
async def handle_subscription_error(self,
subscription_error: SubscriptionError,
topic: Topic,
evt: MessageEvent,
callbacks: List[Awaitable] = None) \
-> None:
"""
Handles a subscription error by removing the topic and all related
subscriptions from the database and notifying all subscribed rooms.
:param subscription_error:
:param topic:
:param evt:
:param callbacks: List of async functions to run afterward
:return:
"""
task = self.tasks[topic.id]
message = subscription_error.message
code = subscription_error.code
self.log.error("ntfy server responded: '%s' %s",
message if message else "unknown error",
"(code %s)" % code if code else "")
# Get affected rooms
subscriptions = await self.db.get_subscriptions_by_topic_id(topic.id)
# Filter out the current room if there is one
if evt:
room_ids = [sub.room_id for sub in subscriptions
if sub.room_id != evt.room_id]
else:
room_ids = [sub.room_id for sub in subscriptions]
match code:
# Handle 'unauthorized' and 'limit reached: too many auth failures'
case 40101 | 42909:
# Remove topic and all related subscriptions from database
result = await self.db.remove_topic(topic.id)
self.log.info("Removed topic: %s", topic.id)
# Notify all subscribed rooms that the subscription has
# been cancelled
notice = await build_notice(
"<span>Authentication failed, please check the "
"access token. Subscription to <code>'%s/%s'</code> has "
"been cancelled and topic has been removed.</span>" %
(escape(topic.topic), escape(topic.server)))
case _:
notice = await build_notice(
"<span>Subscription to <code>'%s/%s'</code> failed: "
"<code>'%s'</code></span>" %
(escape(topic.topic), escape(topic.server),
escape(message["error"])))
# Notify the user that the subscription has been cancelled
try:
await evt.reply(notice)
except Exception as exc:
self.log.exception(
"Failed to send matrix message to room with id '%s'!",
evt.room_id, exc_info=exc)
# Broadcast the error message to all subscribed rooms
await self.broadcast_to_rooms(room_ids, notice)
# Cancel the task
if task.cancel():
self.log.info("Subscription task cancelled")
# Run the callbacks concurrently
if callbacks:
async with asyncio.TaskGroup() as callback_tasks:
for callback in callbacks:
callback_tasks.create_task(callback)
def build_message_content(self, server: str, message) -> str:
"""
Build the message content for a ntfy message.
:param server:
:param message:
:return:
"""
topic = message["topic"]
body = message["message"]
title = message.get("title", None)
tags = message.get("tags", None)
click = message.get("click", None)
attachment = message.get("attachment", None)
if tags:
(emoji, non_emoji) = parse_tags(self.log, tags)
emoji = "".join(emoji) + " "
tags = ", ".join(non_emoji)
else:
emoji = tags = ""
html_content = ("<span>Ntfy message in topic <code>%s/%s</code></span>"
"<blockquote>") % (
escape(server), escape(topic))
# build title
if title and click:
html_content += "<h4>%s<a href=\"%s\">%s</a></h4>" % (
emoji, escape(click), escape(title))
emoji = ""
elif title:
html_content += "<h4>%s%s</h4>" % (emoji, escape(title))
emoji = ""
# build body
if click and not title:
html_content += "%s<a href=\"%s\">%s</a>" % (
emoji, escape(click),
escape(body).replace("\n", "<br />"))
else:
html_content += emoji + escape(body).replace("\n", "<br />")
# add non-emoji tags
if tags:
html_content += "<br/><small>Tags: <code>%s</code></small>" % escape(
tags)
# build attachment
if attachment:
html_content += "<br/><a href=\"%s\">View %s</a>" % (escape(
attachment["url"]), escape(attachment["name"]))
html_content += "</blockquote>"
return html_content
@classmethod
def get_config_class(cls) -> type[BaseProxyConfig]:
return Config
@classmethod
def get_db_upgrade_table(cls) -> UpgradeTable | None:
return upgrade_table
async def add_reaction(self, evt: MessageEvent, emoji: str):
"""
Add a reaction to a message and store the event id.
"""
self.log.debug("Adding reaction %s to event %s",
emoji, evt.event_id)
reaction_id = await evt.react(emoji)
self.reactions.update(
{evt.room_id: {evt.event_id: {emoji: reaction_id}}}
)
async def remove_reaction(self, evt: MessageEvent, emoji: str) -> None:
"""
Remove a reaction from a message.
:param evt: The original message event
:param emoji: The emoji sent as a reaction
"""
self.log.debug("Removing reaction %s from event %s",
emoji, evt.event_id)
try:
reaction_id = self.reactions[evt.room_id][evt.event_id][emoji]
if reaction_id:
result = await self.client.redact(evt.room_id, reaction_id)
if result:
self.reactions[evt.room_id][evt.event_id].pop(emoji)
except KeyError:
self.log.warning("Reaction %s not found in self.reactions[%s][%s]",
emoji, evt.room_id, evt.event_id)