Update bridge info when portal metadata changes

This commit is contained in:
Tulir Asokan
2020-06-15 14:43:38 +03:00
parent 482a52cb5e
commit 8a99e67c6d
8 changed files with 116 additions and 23 deletions
+4 -3
View File
@@ -21,7 +21,6 @@ mxtg_config = Config(mxtg_config_path, None, None)
mxtg_config.load() mxtg_config.load()
config.set_main_option("sqlalchemy.url", mxtg_config["appservice.database"].replace("%", "%%")) config.set_main_option("sqlalchemy.url", mxtg_config["appservice.database"].replace("%", "%%"))
AlchemySessionContainer.create_table_classes(None, "telethon_", Base) AlchemySessionContainer.create_table_classes(None, "telethon_", Base)
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
@@ -55,7 +54,8 @@ def run_migrations_offline():
""" """
url = config.get_main_option("sqlalchemy.url") url = config.get_main_option("sqlalchemy.url")
context.configure( context.configure(
url=url, target_metadata=target_metadata, literal_binds=True) url=url, target_metadata=target_metadata, literal_binds=True,
render_as_batch=True)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
@@ -76,7 +76,8 @@ def run_migrations_online():
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(
connection=connection, connection=connection,
target_metadata=target_metadata target_metadata=target_metadata,
render_as_batch=True
) )
with context.begin_transaction(): with context.begin_transaction():
@@ -0,0 +1,32 @@
"""Store Matrix avatar URL in database
Revision ID: 3e3745baa458
Revises: dff56c93da8d
Create Date: 2020-06-15 14:32:10.454033
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3e3745baa458'
down_revision = 'dff56c93da8d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('portal', schema=None) as batch_op:
batch_op.add_column(sa.Column('avatar_url', sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('portal', schema=None) as batch_op:
batch_op.drop_column('avatar_url')
# ### end Alembic commands ###
+3 -2
View File
@@ -17,7 +17,7 @@ from typing import Optional
from sqlalchemy import Column, Integer, String, Boolean, Text, func, sql from sqlalchemy import Column, Integer, String, Boolean, Text, func, sql
from mautrix.types import RoomID from mautrix.types import RoomID, ContentURI
from mautrix.util.db import Base from mautrix.util.db import Base
from ..types import TelegramID from ..types import TelegramID
@@ -33,7 +33,8 @@ class Portal(Base):
megagroup: bool = Column(Boolean) megagroup: bool = Column(Boolean)
# Matrix portal information # Matrix portal information
mxid: RoomID = Column(String, unique=True, nullable=True) mxid: Optional[RoomID] = Column(String, unique=True, nullable=True)
avatar_url: Optional[ContentURI] = Column(String, nullable=True)
encrypted: bool = Column(Boolean, nullable=False, server_default=sql.expression.false()) encrypted: bool = Column(Boolean, nullable=False, server_default=sql.expression.false())
config: str = Column(Text, nullable=True) config: str = Column(Text, nullable=True)
+17 -7
View File
@@ -31,7 +31,7 @@ from telethon.tl.types import (Channel, ChannelFull, Chat, ChatFull, ChatInviteE
from mautrix.errors import MatrixRequestError, IntentError from mautrix.errors import MatrixRequestError, IntentError
from mautrix.appservice import AppService, IntentAPI from mautrix.appservice import AppService, IntentAPI
from mautrix.types import (RoomID, RoomAlias, UserID, EventID, EventType, MessageEventContent, from mautrix.types import (RoomID, RoomAlias, UserID, EventID, EventType, MessageEventContent,
PowerLevelStateEventContent) PowerLevelStateEventContent, ContentURI)
from mautrix.util.simple_template import SimpleTemplate from mautrix.util.simple_template import SimpleTemplate
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
@@ -90,6 +90,7 @@ class BasePortal(ABC):
about: Optional[str] about: Optional[str]
photo_id: Optional[str] photo_id: Optional[str]
local_config: Dict[str, Any] local_config: Dict[str, Any]
avatar_url: Optional[ContentURI]
encrypted: bool encrypted: bool
deleted: bool deleted: bool
backfilling: bool backfilling: bool
@@ -108,8 +109,8 @@ class BasePortal(ABC):
mxid: Optional[RoomID] = None, username: Optional[str] = None, mxid: Optional[RoomID] = None, username: Optional[str] = None,
megagroup: Optional[bool] = False, title: Optional[str] = None, megagroup: Optional[bool] = False, title: Optional[str] = None,
about: Optional[str] = None, photo_id: Optional[str] = None, about: Optional[str] = None, photo_id: Optional[str] = None,
local_config: Optional[str] = None, encrypted: Optional[bool] = False, local_config: Optional[str] = None, avatar_url: Optional[ContentURI] = None,
db_instance: DBPortal = None) -> None: encrypted: Optional[bool] = False, db_instance: DBPortal = None) -> None:
self.mxid = mxid self.mxid = mxid
self.tgid = tgid self.tgid = tgid
self.tg_receiver = tg_receiver or tgid self.tg_receiver = tg_receiver or tgid
@@ -120,6 +121,7 @@ class BasePortal(ABC):
self.about = about self.about = about
self.photo_id = photo_id self.photo_id = photo_id
self.local_config = json.loads(local_config or "{}") self.local_config = json.loads(local_config or "{}")
self.avatar_url = avatar_url
self.encrypted = encrypted self.encrypted = encrypted
self._db_instance = db_instance self._db_instance = db_instance
self._main_intent = None self._main_intent = None
@@ -335,12 +337,14 @@ class BasePortal(ABC):
return DBPortal(tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type, return DBPortal(tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
mxid=self.mxid, username=self.username, megagroup=self.megagroup, mxid=self.mxid, username=self.username, megagroup=self.megagroup,
title=self.title, about=self.about, photo_id=self.photo_id, title=self.title, about=self.about, photo_id=self.photo_id,
config=json.dumps(self.local_config), encrypted=self.encrypted) config=json.dumps(self.local_config), avatar_url=self.avatar_url,
encrypted=self.encrypted)
def save(self) -> None: def save(self) -> None:
self.db_instance.edit(mxid=self.mxid, username=self.username, title=self.title, self.db_instance.edit(mxid=self.mxid, username=self.username, title=self.title,
about=self.about, photo_id=self.photo_id, megagroup=self.megagroup, about=self.about, photo_id=self.photo_id, megagroup=self.megagroup,
config=json.dumps(self.local_config), encrypted=self.encrypted) config=json.dumps(self.local_config), avatar_url=self.avatar_url,
encrypted=self.encrypted)
def delete(self) -> None: def delete(self) -> None:
try: try:
@@ -362,7 +366,8 @@ class BasePortal(ABC):
peer_type=db_portal.peer_type, mxid=db_portal.mxid, username=db_portal.username, peer_type=db_portal.peer_type, mxid=db_portal.mxid, username=db_portal.username,
megagroup=db_portal.megagroup, title=db_portal.title, about=db_portal.about, megagroup=db_portal.megagroup, title=db_portal.title, about=db_portal.about,
photo_id=db_portal.photo_id, local_config=db_portal.config, photo_id=db_portal.photo_id, local_config=db_portal.config,
encrypted=db_portal.encrypted, db_instance=db_portal) avatar_url=db_portal.avatar_url, encrypted=db_portal.encrypted,
db_instance=db_portal)
# endregion # endregion
# region Class instance lookup # region Class instance lookup
@@ -509,6 +514,10 @@ class BasePortal(ABC):
def _migrate_and_save_telegram(self, new_id: TelegramID) -> None: def _migrate_and_save_telegram(self, new_id: TelegramID) -> None:
pass pass
@abstractmethod
async def _update_bridge_info(self) -> None:
pass
@abstractmethod @abstractmethod
def handle_matrix_power_levels(self, sender: 'u.User', new_levels: Dict[UserID, int], def handle_matrix_power_levels(self, sender: 'u.User', new_levels: Dict[UserID, int],
old_levels: Dict[UserID, int], event_id: Optional[EventID] old_levels: Dict[UserID, int], event_id: Optional[EventID]
@@ -520,7 +529,8 @@ class BasePortal(ABC):
pass pass
@abstractmethod @abstractmethod
async def _send_delivery_receipt(self, event_id: EventID) -> None: async def _send_delivery_receipt(self, event_id: EventID, room_id: Optional[RoomID] = None
) -> None:
pass pass
# endregion # endregion
+6 -1
View File
@@ -500,13 +500,17 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
self.title = title self.title = title
self.save() self.save()
await self._send_delivery_receipt(event_id) await self._send_delivery_receipt(event_id)
await self._update_bridge_info()
async def handle_matrix_avatar(self, sender: 'u.User', url: ContentURI, event_id: EventID async def handle_matrix_avatar(self, sender: 'u.User', url: ContentURI, event_id: EventID
) -> None: ) -> None:
if self.peer_type not in ("chat", "channel"): if self.peer_type not in ("chat", "channel"):
# Invalid peer type # Invalid peer type
return return
elif self.avatar_url == url:
return
self.avatar_url = url
file = await self.main_intent.download_media(url) file = await self.main_intent.download_media(url)
mime = magic.from_buffer(file, mime=True) mime = magic.from_buffer(file, mime=True)
ext = sane_mimetypes.guess_extension(mime) ext = sane_mimetypes.guess_extension(mime)
@@ -529,6 +533,7 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
self.save() self.save()
break break
await self._send_delivery_receipt(event_id) await self._send_delivery_receipt(event_id)
await self._update_bridge_info()
async def handle_matrix_upgrade(self, sender: UserID, new_room: RoomID, event_id: EventID async def handle_matrix_upgrade(self, sender: UserID, new_room: RoomID, event_id: EventID
) -> None: ) -> None:
@@ -558,7 +563,7 @@ class PortalMatrix(BasePortal, MautrixBasePortal, ABC):
return return
await self.update_matrix_room(user, entity, direct=self.peer_type == "user") await self.update_matrix_room(user, entity, direct=self.peer_type == "user")
self.log.info(f"{sender} upgraded room from {old_room} to {self.mxid}") self.log.info(f"{sender} upgraded room from {old_room} to {self.mxid}")
await self._send_delivery_receipt(event_id) await self._send_delivery_receipt(event_id, room_id=old_room)
def migrate_and_save_matrix(self, new_id: RoomID) -> None: def migrate_and_save_matrix(self, new_id: RoomID) -> None:
try: try:
+49 -9
View File
@@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Optional, Tuple, Union, Callable, Awaitable, TYPE_CHECKING from typing import List, Optional, Tuple, Union, Dict, Any, TYPE_CHECKING
from abc import ABC from abc import ABC
import asyncio import asyncio
@@ -45,6 +45,9 @@ if TYPE_CHECKING:
config: Optional['Config'] = None config: Optional['Config'] = None
StateBridge = EventType.find("m.bridge", EventType.Class.STATE)
StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE)
class PortalMetadata(BasePortal, ABC): class PortalMetadata(BasePortal, ABC):
_room_create_lock: asyncio.Lock _room_create_lock: asyncio.Lock
@@ -226,6 +229,7 @@ class PortalMetadata(BasePortal, ABC):
changed = await self._update_avatar(user, entity.photo) or changed changed = await self._update_avatar(user, entity.photo) or changed
if changed: if changed:
self.save() self.save()
await self._update_bridge_info()
if self.sync_matrix_state: if self.sync_matrix_state:
await self.sync_matrix_members() await self.sync_matrix_members()
@@ -253,6 +257,38 @@ class PortalMetadata(BasePortal, ABC):
except Exception: except Exception:
self.log.exception("Fatal error creating Matrix room") self.log.exception("Fatal error creating Matrix room")
@property
def bridge_info_state_key(self) -> str:
return f"net.maunium.telegram://telegram/{self.tgid}"
@property
def bridge_info(self) -> Dict[str, Any]:
return {
"bridgebot": self.az.bot_mxid,
"creator": self.main_intent.mxid,
"protocol": {
"id": "telegram",
"displayname": "Telegram",
"avatar_url": config["appservice.bot_avatar"],
},
"channel": {
"id": str(self.tgid),
"displayname": self.title,
"avatar_url": self.avatar_url,
}
}
async def _update_bridge_info(self) -> None:
try:
self.log.debug("Updating bridge info...")
await self.main_intent.send_state_event(self.mxid, StateBridge,
self.bridge_info, self.bridge_info_state_key)
# TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec
await self.main_intent.send_state_event(self.mxid, StateHalfShotBridge,
self.bridge_info, self.bridge_info_state_key)
except Exception:
self.log.warning("Failed to update bridge info", exc_info=True)
async def _create_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User], async def _create_matrix_room(self, user: 'AbstractUser', entity: Union[TypeChat, User],
invites: InviteList) -> Optional[RoomID]: invites: InviteList) -> Optional[RoomID]:
direct = self.peer_type == "user" direct = self.peer_type == "user"
@@ -333,14 +369,14 @@ class PortalMetadata(BasePortal, ABC):
"type": EventType.ROOM_POWER_LEVELS.serialize(), "type": EventType.ROOM_POWER_LEVELS.serialize(),
"content": power_levels.serialize(), "content": power_levels.serialize(),
}, { }, {
"type": "m.bridge", "type": str(StateBridge),
"state_key": f"net.maunium.telegram://telegram/{self.tgid}", "state_key": self.bridge_info_state_key,
"content": bridge_info "content": self.bridge_info,
}, { }, {
# TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec # TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec
"type": "uk.half-shot.bridge", "type": str(StateHalfShotBridge),
"state_key": f"net.maunium.telegram://telegram/{self.tgid}", "state_key": self.bridge_info_state_key,
"content": bridge_info "content": self.bridge_info,
}] }]
if config["bridge.encryption.default"] and self.matrix.e2ee: if config["bridge.encryption.default"] and self.matrix.e2ee:
self.encrypted = True self.encrypted = True
@@ -620,6 +656,7 @@ class PortalMetadata(BasePortal, ABC):
if changed: if changed:
self.save() self.save()
await self._update_bridge_info()
async def _update_username(self, username: str, save: bool = False) -> bool: async def _update_username(self, username: str, save: bool = False) -> bool:
if self.username == username: if self.username == username:
@@ -702,6 +739,7 @@ class PortalMetadata(BasePortal, ABC):
await self._try_set_state(sender, EventType.ROOM_AVATAR, await self._try_set_state(sender, EventType.ROOM_AVATAR,
RoomAvatarStateEventContent(url=None)) RoomAvatarStateEventContent(url=None))
self.photo_id = "" self.photo_id = ""
self.avatar_url = None
if save: if save:
self.save() self.save()
return True return True
@@ -710,6 +748,7 @@ class PortalMetadata(BasePortal, ABC):
await self._try_set_state(sender, EventType.ROOM_AVATAR, await self._try_set_state(sender, EventType.ROOM_AVATAR,
RoomAvatarStateEventContent(url=file.mxc)) RoomAvatarStateEventContent(url=file.mxc))
self.photo_id = photo_id self.photo_id = photo_id
self.avatar_url = file.mxc
if save: if save:
self.save() self.save()
return True return True
@@ -762,10 +801,11 @@ class PortalMetadata(BasePortal, ABC):
# endregion # endregion
async def _send_delivery_receipt(self, event_id: EventID) -> None: async def _send_delivery_receipt(self, event_id: EventID, room_id: Optional[RoomID] = None
) -> None:
if event_id and config["bridge.delivery_receipts"]: if event_id and config["bridge.delivery_receipts"]:
try: try:
await self.az.intent.mark_read(self.mxid, event_id) await self.az.intent.mark_read(room_id or self.mxid, event_id)
except Exception: except Exception:
self.log.exception("Failed to send delivery receipt for %s", event_id) self.log.exception("Failed to send delivery receipt for %s", event_id)
+4
View File
@@ -555,10 +555,13 @@ class PortalTelegram(BasePortal, ABC):
return return
if isinstance(action, MessageActionChatEditTitle): if isinstance(action, MessageActionChatEditTitle):
await self._update_title(action.title, sender=sender, save=True) await self._update_title(action.title, sender=sender, save=True)
await self._update_bridge_info()
elif isinstance(action, MessageActionChatEditPhoto): elif isinstance(action, MessageActionChatEditPhoto):
await self._update_avatar(source, action.photo, sender=sender, save=True) await self._update_avatar(source, action.photo, sender=sender, save=True)
await self._update_bridge_info()
elif isinstance(action, MessageActionChatDeletePhoto): elif isinstance(action, MessageActionChatDeletePhoto):
await self._update_avatar(source, ChatPhotoEmpty(), sender=sender, save=True) await self._update_avatar(source, ChatPhotoEmpty(), sender=sender, save=True)
await self._update_bridge_info()
elif isinstance(action, MessageActionChatAddUser): elif isinstance(action, MessageActionChatAddUser):
for user_id in action.users: for user_id in action.users:
await self._add_telegram_user(TelegramID(user_id), source) await self._add_telegram_user(TelegramID(user_id), source)
@@ -572,6 +575,7 @@ class PortalTelegram(BasePortal, ABC):
# TODO encrypt # TODO encrypt
await sender.intent_for(self).send_emote(self.mxid, await sender.intent_for(self).send_emote(self.mxid,
"upgraded this group to a supergroup.") "upgraded this group to a supergroup.")
await self._update_bridge_info()
elif isinstance(action, MessageActionGameScore): elif isinstance(action, MessageActionGameScore):
# TODO handle game score # TODO handle game score
pass pass
+1 -1
View File
@@ -4,6 +4,6 @@ ruamel.yaml>=0.15.35,<0.17
python-magic>=0.4,<0.5 python-magic>=0.4,<0.5
commonmark>=0.8,<0.10 commonmark>=0.8,<0.10
aiohttp>=3,<4 aiohttp>=3,<4
mautrix>=0.5.2,<0.6 mautrix>=0.5.5,<0.6
telethon>=1.13,<1.15 telethon>=1.13,<1.15
telethon-session-sqlalchemy>=0.2.14,<0.3 telethon-session-sqlalchemy>=0.2.14,<0.3