implement authentication via access token

This commit is contained in:
Marc Koch 2024-06-20 18:57:12 +02:00
parent 256aa8f315
commit a3505d3d2e
2 changed files with 19 additions and 9 deletions

View File

@ -76,14 +76,15 @@ class NtfyBot(Plugin):
@ntfy.subcommand("subscribe", aliases=("sub",), help="Subscribe this room to a ntfy topic.") @ntfy.subcommand("subscribe", aliases=("sub",), help="Subscribe this room to a ntfy topic.")
@command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") @command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None: @command.argument("token", "access token", required=False, matches="tk_[a-zA-Z0-9]{29}")
async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any], token: Tuple[str, Any]) -> None:
# see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex # see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex
if not await self.can_use_command(evt): if not await self.can_use_command(evt):
return None return None
server, topic = topic[0].split("/") server, topic = topic[0].split("/")
db_topic = await self.db.get_topic(server, topic) db_topic = await self.db.get_topic(server, topic)
if not db_topic: if not db_topic:
db_topic = await self.db.create_topic(Topic(id=-1, server=server, topic=topic, last_event_id=None)) db_topic = await self.db.create_topic(Topic(id=-1, server=server, topic=topic, token=token, last_event_id=None))
existing_subscriptions = await self.db.get_subscriptions(db_topic.id) existing_subscriptions = await self.db.get_subscriptions(db_topic.id)
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id) sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
if sub: if sub:
@ -150,7 +151,9 @@ class NtfyBot(Plugin):
task.add_done_callback(log_task_exc) task.add_done_callback(log_task_exc)
async def run_topic_subscription(self, topic: Topic, url: str) -> None: async def run_topic_subscription(self, topic: Topic, url: str) -> None:
async with self.http.get(url, timeout=ClientTimeout()) as resp: headers = {"Authorization": f"Bearer {topic.token}"}
async with self.http.get(url, timeout=ClientTimeout(),
headers=headers if topic.token else None) as resp:
while True: while True:
line = await resp.content.readline() line = await resp.content.readline()
# convert to string and remove trailing newline # convert to string and remove trailing newline

View File

@ -20,6 +20,7 @@ async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
id INTEGER {gen}, id INTEGER {gen},
server TEXT NOT NULL, server TEXT NOT NULL,
topic TEXT NOT NULL, topic TEXT NOT NULL,
token TEXT,
last_event_id TEXT, last_event_id TEXT,
PRIMARY KEY (id), PRIMARY KEY (id),
@ -43,6 +44,7 @@ class Topic:
server: str server: str
topic: str topic: str
last_event_id: str last_event_id: str
token: str | None = None
subscriptions: List[Subscription] = attr.ib(factory=lambda: []) subscriptions: List[Subscription] = attr.ib(factory=lambda: [])
@ -53,11 +55,13 @@ class Topic:
id = row["id"] id = row["id"]
server = row["server"] server = row["server"]
topic = row["topic"] topic = row["topic"]
token = row["token"]
last_event_id = row["last_event_id"] last_event_id = row["last_event_id"]
return cls( return cls(
id=id, id=id,
server=server, server=server,
topic=topic, topic=topic,
token=token,
last_event_id=last_event_id, last_event_id=last_event_id,
subscriptions=[] subscriptions=[]
) )
@ -90,7 +94,7 @@ class DB:
async def get_topics(self) -> List[Topic]: async def get_topics(self) -> List[Topic]:
query = """ query = """
SELECT id, server, topic, last_event_id, topic_id, room_id SELECT id, server, topic, token, last_event_id, topic_id, room_id
FROM topics FROM topics
INNER JOIN INNER JOIN
subscriptions ON topics.id = subscriptions.topic_id subscriptions ON topics.id = subscriptions.topic_id
@ -119,14 +123,15 @@ class DB:
async def create_topic(self, topic: Topic) -> Topic: async def create_topic(self, topic: Topic) -> Topic:
query = """ query = """
INSERT INTO topics (server, topic, last_event_id) INSERT INTO topics (server, topic, token, last_event_id)
VALUES ($1, $2, $3) RETURNING (id) VALUES ($1, $2, $3, $4) RETURNING (id)
""" """
if self.db.scheme == Scheme.SQLITE: if self.db.scheme == Scheme.SQLITE:
cur = await self.db.execute( cur = await self.db.execute(
query.replace("RETURNING (id)", ""), query.replace("RETURNING (id)", ""),
topic.server, topic.server,
topic.topic, topic.topic,
topic.token,
topic.last_event_id, topic.last_event_id,
) )
topic.id = cur.lastrowid topic.id = cur.lastrowid
@ -135,21 +140,23 @@ class DB:
query, query,
topic.server, topic.server,
topic.topic, topic.topic,
topic.token,
topic.last_event_id, topic.last_event_id,
) )
return topic return topic
async def get_topic(self, server: str, topic: str) -> Topic | None: async def get_topic(self, server: str, topic: str) -> Topic | None:
query = """ query = """
SELECT id, server, topic, last_event_id SELECT id, server, topic, token, last_event_id
FROM topics FROM topics
WHERE server = $1 AND topic = $2 WHERE server = $1 AND topic = $2
""" """
return Topic.from_row(await self.db.fetchrow(query, server, topic)) return Topic.from_row(await self.db.fetchrow(query, server, topic))
async def get_subscription(self, topic_id: int, room_id: RoomID) -> Tuple[Subscription | None, Topic | None]: async def get_subscription(self, topic_id: int, room_id: RoomID) -> Tuple[
Subscription | None, Topic | None]:
query = """ query = """
SELECT id, server, topic, last_event_id, topic_id, room_id SELECT id, server, topic, token, last_event_id, topic_id, room_id
FROM topics FROM topics
INNER JOIN INNER JOIN
subscriptions ON topics.id = subscriptions.topic_id AND subscriptions.room_id = $2 subscriptions ON topics.id = subscriptions.topic_id AND subscriptions.room_id = $2