Maybe improve channel leave handling

This commit is contained in:
Tulir Asokan
2022-06-27 15:11:59 +03:00
parent 1d0e8c7e0c
commit 8f68801aa9
4 changed files with 44 additions and 15 deletions
+15
View File
@@ -38,6 +38,7 @@ from telethon.tl.types import (
PeerChat, PeerChat,
PeerUser, PeerUser,
TypeUpdate, TypeUpdate,
UpdateChannel,
UpdateChannelUserTyping, UpdateChannelUserTyping,
UpdateChatParticipantAdmin, UpdateChatParticipantAdmin,
UpdateChatParticipants, UpdateChatParticipants,
@@ -354,6 +355,8 @@ class AbstractUser(ABC):
await self.update_pinned_dialogs(update) await self.update_pinned_dialogs(update)
elif isinstance(update, UpdateNotifySettings): elif isinstance(update, UpdateNotifySettings):
await self.update_notify_settings(update) await self.update_notify_settings(update)
elif isinstance(update, UpdateChannel):
await self.update_channel(update)
else: else:
self.log.trace("Unhandled update: %s", update) self.log.trace("Unhandled update: %s", update)
@@ -584,6 +587,18 @@ class AbstractUser(ABC):
return return
await portal.handle_telegram_reactions(self, TelegramID(update.msg_id), update.reactions) await portal.handle_telegram_reactions(self, TelegramID(update.msg_id), update.reactions)
async def update_channel(self, update: UpdateChannel) -> None:
portal = await po.Portal.get_by_tgid(TelegramID(update.channel_id))
if not portal:
return
if getattr(update, "mau_telethon_is_leave", False):
self.log.debug("UpdateChannel has mau_telethon_is_leave, leaving portal")
await portal.delete_telegram_user(self.tgid, sender=None)
elif chan := getattr(update, "mau_channel", None):
self.log.debug("Updating channel info with data fetched by Telethon")
await portal.update_info(self, chan)
await portal.invite_to_matrix(self.mxid)
async def update_message(self, original_update: UpdateMessage) -> None: async def update_message(self, original_update: UpdateMessage) -> None:
update, sender, portal = await self.get_message_details(original_update) update, sender, portal = await self.get_message_details(original_update)
if not portal: if not portal:
+12 -9
View File
@@ -136,6 +136,10 @@ class PgSession(MemorySession):
q, self.session_id, entity_id, row.pts, row.qts, ts, row.seq, row.unread_count q, self.session_id, entity_id, row.pts, row.qts, ts, row.seq, row.unread_count
) )
async def delete_update_state(self, entity_id: int) -> None:
q = "DELETE FROM telethon_update_state WHERE session_id=$1 AND entity_id=$2"
await self.db.execute(q, self.session_id, entity_id)
async def get_update_states(self) -> Iterable[tuple[int, updates.State], ...]: async def get_update_states(self) -> Iterable[tuple[int, updates.State], ...]:
q = ( q = (
"SELECT entity_id, pts, qts, date, seq, unread_count FROM telethon_update_state " "SELECT entity_id, pts, qts, date, seq, unread_count FROM telethon_update_state "
@@ -196,25 +200,24 @@ class PgSession(MemorySession):
async def _select_entity( async def _select_entity(
self, constraint: str, *args: str | int | tuple[int, ...] self, constraint: str, *args: str | int | tuple[int, ...]
) -> tuple[int, int] | None: ) -> tuple[int, int] | None:
row = await self.db.fetchrow( q = f"SELECT id, hash FROM telethon_entities WHERE session_id=$1 AND {constraint}"
f"SELECT id, hash FROM telethon_entities WHERE {constraint}", *args row = await self.db.fetchrow(q, self.session_id, *args)
)
if row is None: if row is None:
return None return None
return row["id"], row["hash"] return row["id"], row["hash"]
async def get_entity_rows_by_phone(self, key: str | int) -> tuple[int, int] | None: async def get_entity_rows_by_phone(self, key: str | int) -> tuple[int, int] | None:
return await self._select_entity("phone=$1", str(key)) return await self._select_entity("phone=$2", str(key))
async def get_entity_rows_by_username(self, key: str) -> tuple[int, int] | None: async def get_entity_rows_by_username(self, key: str) -> tuple[int, int] | None:
return await self._select_entity("username=$1", key) return await self._select_entity("username=$2", key)
async def get_entity_rows_by_name(self, key: str) -> tuple[int, int] | None: async def get_entity_rows_by_name(self, key: str) -> tuple[int, int] | None:
return await self._select_entity("name=$1", key) return await self._select_entity("name=$2", key)
async def get_entity_rows_by_id(self, key: int, exact: bool = True) -> tuple[int, int] | None: async def get_entity_rows_by_id(self, key: int, exact: bool = True) -> tuple[int, int] | None:
if exact: if exact:
return await self._select_entity("id=$1", key) return await self._select_entity("id=$2", key)
ids = ( ids = (
utils.get_peer_id(PeerUser(key)), utils.get_peer_id(PeerUser(key)),
@@ -222,6 +225,6 @@ class PgSession(MemorySession):
utils.get_peer_id(PeerChannel(key)), utils.get_peer_id(PeerChannel(key)),
) )
if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH): if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
return await self._select_entity("id=ANY($1)", ids) return await self._select_entity("id=ANY($2)", ids)
else: else:
return await self._select_entity(f"id IN ($1, $2, $3)", *ids) return await self._select_entity(f"id IN ($2, $3, $4)", *ids)
+16 -5
View File
@@ -1008,7 +1008,13 @@ class Portal(DBPortal, BasePortal):
) -> None: ) -> None:
puppet = await p.Puppet.get_by_tgid(user_id) puppet = await p.Puppet.get_by_tgid(user_id)
if source: if source:
entity: User = await source.client.get_entity(PeerUser(user_id)) try:
entity: User = await source.client.get_entity(PeerUser(user_id))
except ValueError:
self.log.warning(
f"Couldn't get info of {user_id} through {source.tgid} to add them to the room"
)
return
await puppet.update_info(source, entity) await puppet.update_info(source, entity)
await puppet.intent_for(self).ensure_joined(self.mxid) await puppet.intent_for(self).ensure_joined(self.mxid)
@@ -1017,8 +1023,10 @@ class Portal(DBPortal, BasePortal):
await user.register_portal(self) await user.register_portal(self)
await self.invite_to_matrix(user.mxid) await self.invite_to_matrix(user.mxid)
async def _delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet) -> None: async def delete_telegram_user(self, user_id: TelegramID, sender: p.Puppet | None) -> None:
puppet = await p.Puppet.get_by_tgid(user_id) puppet = await p.Puppet.get_by_tgid(user_id)
if sender is None:
sender = puppet
user = await u.User.get_by_tgid(user_id) user = await u.User.get_by_tgid(user_id)
kick_message = ( kick_message = (
f"Kicked by {sender.displayname}" f"Kicked by {sender.displayname}"
@@ -1034,8 +1042,11 @@ class Portal(DBPortal, BasePortal):
self.mxid, puppet.mxid, extra_content=puppet_extra_content self.mxid, puppet.mxid, extra_content=puppet_extra_content
) )
except MForbidden: except MForbidden:
await self.main_intent.kick_user(self.mxid, puppet.mxid, kick_message) try:
else: await self.main_intent.kick_user(self.mxid, puppet.mxid, kick_message)
except MForbidden as e:
self.log.warning(f"Failed to kick {puppet.mxid}: {e}")
elif not await self.az.state_store.is_joined(self.mxid, puppet.intent_for(self).mxid):
await puppet.intent_for(self).leave_room(self.mxid, extra_content=puppet_extra_content) await puppet.intent_for(self).leave_room(self.mxid, extra_content=puppet_extra_content)
if user: if user:
await user.unregister_portal(*self.tgid_full) await user.unregister_portal(*self.tgid_full)
@@ -2827,7 +2838,7 @@ class Portal(DBPortal, BasePortal):
elif isinstance(action, (MessageActionChatJoinedByLink, MessageActionChatJoinedByRequest)): elif isinstance(action, (MessageActionChatJoinedByLink, MessageActionChatJoinedByRequest)):
await self._add_telegram_user(sender.id, source) await self._add_telegram_user(sender.id, source)
elif isinstance(action, MessageActionChatDeleteUser): elif isinstance(action, MessageActionChatDeleteUser):
await self._delete_telegram_user(TelegramID(action.user_id), sender) await self.delete_telegram_user(TelegramID(action.user_id), sender)
elif isinstance(action, MessageActionChatMigrateTo): elif isinstance(action, MessageActionChatMigrateTo):
await self._migrate_and_save_telegram(TelegramID(action.channel_id)) await self._migrate_and_save_telegram(TelegramID(action.channel_id))
# TODO encrypt # TODO encrypt
+1 -1
View File
@@ -5,7 +5,7 @@ aiohttp>=3,<4
yarl>=1,<2 yarl>=1,<2
mautrix>=0.16.10,<0.17 mautrix>=0.16.10,<0.17
#telethon>=1.24,<1.25 #telethon>=1.24,<1.25
tulir-telethon==1.25.0a17 tulir-telethon==1.25.0a19
asyncpg>=0.20,<0.26 asyncpg>=0.20,<0.26
mako>=1,<2 mako>=1,<2
setuptools setuptools