Sync own read marker to Matrix when backfilling chats

This commit is contained in:
Tulir Asokan
2021-03-22 13:51:22 +02:00
parent 2e8d612078
commit fa35ed1cb6
+34 -13
View File
@@ -19,7 +19,7 @@ from collections import defaultdict
import logging import logging
import asyncio import asyncio
from telethon.tl.types import (TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage, PeerUser, from telethon.tl.types import (TypeUpdate, UpdateNewMessage, UpdateNewChannelMessage,
UpdateShortChatMessage, UpdateShortMessage, User as TLUser, Chat, UpdateShortChatMessage, UpdateShortMessage, User as TLUser, Chat,
ChatForbidden) ChatForbidden)
from telethon.tl.custom import Dialog from telethon.tl.custom import Dialog
@@ -35,7 +35,7 @@ from mautrix.util.logging import TraceLogger
from mautrix.util.opt_prometheus import Gauge from mautrix.util.opt_prometheus import Gauge
from .types import TelegramID from .types import TelegramID
from .db import User as DBUser, Portal as DBPortal from .db import User as DBUser, Portal as DBPortal, Message as DBMessage
from .abstract_user import AbstractUser from .abstract_user import AbstractUser
from . import portal as po, puppet as pu from . import portal as po, puppet as pu
@@ -376,6 +376,34 @@ class User(AbstractUser, BaseUser):
if portal.mxid if portal.mxid
} }
async def _sync_dialog(self, portal: po.Portal, dialog: Dialog, should_create: bool,
puppet: Optional[pu.Puppet]) -> None:
if portal.mxid:
try:
await portal.backfill(self, last_id=dialog.message.id)
except Exception:
self.log.exception(f"Error while backfilling {portal.tgid_log}")
try:
await portal.update_matrix_room(self, dialog.entity)
except Exception:
self.log.exception(f"Error while updating {portal.tgid_log}")
elif should_create:
try:
await portal.create_matrix_room(self, dialog.entity, invites=[self.mxid])
except Exception:
self.log.exception(f"Error while creating {portal.tgid_log}")
if portal.mxid and puppet and puppet.is_real_user:
tg_space = portal.tgid if portal.peer_type == "channel" else self.tgid
if dialog.unread_count == 0:
# This is usually more reliable than finding a specific message
# e.g. if the last read message is a service message that isn't in the message db
last_read = DBMessage.find_last(portal.mxid, tg_space)
else:
last_read = DBMessage.get_one_by_tgid(portal.tgid, tg_space,
dialog.dialog.read_inbox_max_id)
if last_read:
await puppet.intent.mark_read(last_read.mx_room, last_read.mxid)
async def sync_dialogs(self) -> None: async def sync_dialogs(self) -> None:
if self.is_bot: if self.is_bot:
return return
@@ -385,6 +413,7 @@ class User(AbstractUser, BaseUser):
index = 0 index = 0
self.log.debug(f"Syncing dialogs (update_limit={update_limit}, " self.log.debug(f"Syncing dialogs (update_limit={update_limit}, "
f"create_limit={create_limit})") f"create_limit={create_limit})")
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
dialog: Dialog dialog: Dialog
async for dialog in self.client.iter_dialogs(limit=update_limit, ignore_migrated=True, async for dialog in self.client.iter_dialogs(limit=update_limit, ignore_migrated=True,
archived=False): archived=False):
@@ -400,17 +429,9 @@ class User(AbstractUser, BaseUser):
continue continue
portal = po.Portal.get_by_entity(entity, receiver_id=self.tgid) portal = po.Portal.get_by_entity(entity, receiver_id=self.tgid)
self.portals[portal.tgid_full] = portal self.portals[portal.tgid_full] = portal
if portal.mxid: coro = self._sync_dialog(portal=portal, dialog=dialog, puppet=puppet,
update_task = portal.update_matrix_room(self, entity) should_create=not create_limit or index < create_limit)
backfill_task = portal.backfill(self, last_id=dialog.message.id) creators.append(self.loop.create_task(coro))
creators.append(self._catch(f"updating {portal.tgid_log}",
self.loop.create_task(update_task)))
creators.append(self._catch(f"backfilling {portal.tgid_log}",
self.loop.create_task(backfill_task)))
elif not create_limit or index < create_limit:
create_task = portal.create_matrix_room(self, entity, invites=[self.mxid])
creators.append(self._catch(f"creating {portal.tgid_log}",
self.loop.create_task(create_task)))
index += 1 index += 1
await self.save(portals=True) await self.save(portals=True)
await asyncio.gather(*creators) await asyncio.gather(*creators)