diff --git a/mautrix_telegram/db/message.py b/mautrix_telegram/db/message.py index 5d907d02..1180a1a8 100644 --- a/mautrix_telegram/db/message.py +++ b/mautrix_telegram/db/message.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, ClassVar from asyncpg import Record from attr import dataclass -from mautrix.types import EventID, RoomID +from mautrix.types import EventID, RoomID, UserID from mautrix.util.async_db import Database, Scheme from ..types import TelegramID @@ -39,6 +39,8 @@ class Message: edit_index: int redacted: bool = False content_hash: bytes | None = None + sender_mxid: UserID | None = None + sender: TelegramID | None = None @classmethod def _from_row(cls, row: Record | None) -> Message | None: @@ -46,7 +48,19 @@ class Message: return None return cls(**row) - columns: ClassVar[str] = "mxid, mx_room, tgid, tg_space, edit_index, redacted, content_hash" + columns: ClassVar[str] = ", ".join( + ( + "mxid", + "mx_room", + "tgid", + "tg_space", + "edit_index", + "redacted", + "content_hash", + "sender_mxid", + "sender", + ) + ) @classmethod async def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> list[Message]: @@ -158,12 +172,16 @@ class Message: self.edit_index, self.redacted, self.content_hash, + self.sender_mxid, + self.sender, ) async def insert(self) -> None: q = """ - INSERT INTO message (mxid, mx_room, tgid, tg_space, edit_index, redacted, content_hash) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO message ( + mxid, mx_room, tgid, tg_space, edit_index, redacted, content_hash, + sender_mxid, sender + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) """ await self.db.execute(q, *self._values) diff --git a/mautrix_telegram/db/upgrade/__init__.py b/mautrix_telegram/db/upgrade/__init__.py index fd13a923..540c4ebf 100644 --- a/mautrix_telegram/db/upgrade/__init__.py +++ b/mautrix_telegram/db/upgrade/__init__.py @@ -14,4 +14,5 @@ from . import ( v09_puppet_username_index, v10_more_backfill_fields, v11_backfill_queue, + v12_message_sender, ) diff --git a/mautrix_telegram/db/upgrade/v12_message_sender.py b/mautrix_telegram/db/upgrade/v12_message_sender.py new file mode 100644 index 00000000..c32a513e --- /dev/null +++ b/mautrix_telegram/db/upgrade/v12_message_sender.py @@ -0,0 +1,24 @@ +# mautrix-telegram - A Matrix-Telegram puppeting bridge +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from mautrix.util.async_db import Connection + +from . import upgrade_table + + +@upgrade_table.register(description="Store sender in message table") +async def upgrade_v12(conn: Connection) -> None: + await conn.execute("ALTER TABLE message ADD COLUMN sender_mxid TEXT") + await conn.execute("ALTER TABLE message ADD COLUMN sender BIGINT") diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index e92c272f..c2320a81 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -1598,7 +1598,14 @@ class Portal(DBPortal, BasePortal): link_preview=lp, ) await self._mark_matrix_handled( - sender, EventType.ROOM_MESSAGE, event_id, space, -1, resp, content.msgtype + sender, + sender_id, + EventType.ROOM_MESSAGE, + event_id, + space, + -1, + resp, + content.msgtype, ) return response = await client.send_message( @@ -1609,7 +1616,14 @@ class Portal(DBPortal, BasePortal): link_preview=lp, ) await self._mark_matrix_handled( - sender, EventType.ROOM_MESSAGE, event_id, space, 0, response, content.msgtype + sender, + sender_id, + EventType.ROOM_MESSAGE, + event_id, + space, + 0, + response, + content.msgtype, ) async def _handle_matrix_file( @@ -1736,12 +1750,20 @@ class Portal(DBPortal, BasePortal): raise else: await self._mark_matrix_handled( - sender, EventType.ROOM_MESSAGE, event_id, space, 0, response, content.msgtype + sender, + sender_id, + EventType.ROOM_MESSAGE, + event_id, + space, + 0, + response, + content.msgtype, ) async def _matrix_document_edit( self, sender: u.User, + sender_tgid: TelegramID, client: MautrixTelegramClient, content: MessageEventContent, space: TelegramID, @@ -1761,7 +1783,14 @@ class Portal(DBPortal, BasePortal): file=media, ) await self._mark_matrix_handled( - sender, EventType.ROOM_MESSAGE, event_id, space, -1, response, content.msgtype + sender, + sender_tgid, + EventType.ROOM_MESSAGE, + event_id, + space, + -1, + response, + content.msgtype, ) return True return False @@ -1792,7 +1821,7 @@ class Portal(DBPortal, BasePortal): async with self.send_lock(sender_id): if await self._matrix_document_edit( - sender, client, content, space, caption, media, event_id + sender, sender_id, client, content, space, caption, entities, media, event_id ): return try: @@ -1803,12 +1832,20 @@ class Portal(DBPortal, BasePortal): raise else: await self._mark_matrix_handled( - sender, EventType.ROOM_MESSAGE, event_id, space, 0, response, content.msgtype + sender, + sender_id, + EventType.ROOM_MESSAGE, + event_id, + space, + 0, + response, + content.msgtype, ) async def _mark_matrix_handled( self, sender: u.User, + sender_tgid: TelegramID, event_type: EventType, event_id: EventID, space: TelegramID, @@ -1828,6 +1865,8 @@ class Portal(DBPortal, BasePortal): mxid=event_id, edit_index=edit_index, content_hash=event_hash, + sender_mxid=sender.mxid, + sender=sender_tgid, ).insert() sender.send_remote_checkpoint( MessageSendCheckpointStatus.SUCCESS, @@ -1985,7 +2024,14 @@ class Portal(DBPortal, BasePortal): return False else: await self._mark_matrix_handled( - sender, EventType.ROOM_MESSAGE, event_id, space, 0, response[0], msgtype + sender, + sender.tgid, + EventType.ROOM_MESSAGE, + event_id, + space, + 0, + response[0], + msgtype, ) return True @@ -2427,6 +2473,7 @@ class Portal(DBPortal, BasePortal): tgid=TelegramID(evt.id), edit_index=prev_edit_msg.edit_index + 1, content_hash=event_hash, + sender=sender.id, ).insert() return @@ -2467,6 +2514,7 @@ class Portal(DBPortal, BasePortal): tgid=TelegramID(evt.id), edit_index=prev_edit_msg.edit_index + 1, content_hash=event_hash, + sender=sender.id, ).insert() await DBMessage.replace_temp_mxid(temporary_identifier, self.mxid, event_id) @@ -2822,6 +2870,7 @@ class Portal(DBPortal, BasePortal): tg_space=tg_space, edit_index=0, content_hash=event_hash, + sender=sender.id, ).insert() return @@ -2899,6 +2948,7 @@ class Portal(DBPortal, BasePortal): tg_space=tg_space, edit_index=0, content_hash=event_hash, + sender=sender.id, ) await dbm.insert() await DBMessage.replace_temp_mxid(temporary_identifier, self.mxid, event_id) @@ -3021,6 +3071,7 @@ class Portal(DBPortal, BasePortal): mxid=event_id, tg_space=source.tgid, edit_index=0, + sender=sender.id, ).insert() if self.config["bridge.always_read_joined_telegram_notice"]: double_puppet = await p.Puppet.get_by_tgid(source.tgid)