Handle getting logged out the same way in all cases

This commit is contained in:
Tulir Asokan
2023-01-03 21:45:24 +02:00
parent f1f0b86696
commit 9fae4f14d2
6 changed files with 91 additions and 35 deletions
+8 -4
View File
@@ -22,7 +22,7 @@ import logging
import platform import platform
import time import time
from telethon.errors import UnauthorizedError from telethon.errors import AuthKeyError, UnauthorizedError
from telethon.network import ( from telethon.network import (
Connection, Connection,
ConnectionTcpFull, ConnectionTcpFull,
@@ -235,14 +235,18 @@ class AbstractUser(ABC):
) )
self.client.add_event_handler(self._update_catch) self.client.add_event_handler(self._update_catch)
@abstractmethod
async def on_signed_out(self, err: UnauthorizedError | AuthKeyError) -> None:
pass
async def _telethon_update_error_callback(self, err: Exception) -> None: async def _telethon_update_error_callback(self, err: Exception) -> None:
if isinstance(err, (UnauthorizedError, AuthKeyError)):
asyncio.create_task(self.on_signed_out(err))
return
if self.config["telegram.exit_on_update_error"]: if self.config["telegram.exit_on_update_error"]:
self.log.critical(f"Stopping due to update handling error {type(err).__name__}") self.log.critical(f"Stopping due to update handling error {type(err).__name__}")
self.bridge.manual_stop(50) self.bridge.manual_stop(50)
else: else:
if isinstance(err, UnauthorizedError):
self.log.warning("Not recreating Telethon update loop")
return
self.log.info("Recreating Telethon update loop in 60 seconds") self.log.info("Recreating Telethon update loop in 60 seconds")
await asyncio.sleep(60) await asyncio.sleep(60)
self.log.debug("Now recreating Telethon update loop") self.log.debug("Now recreating Telethon update loop")
+10 -1
View File
@@ -19,7 +19,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Literal
import logging import logging
import time import time
from telethon.errors import ChannelInvalidError, ChannelPrivateError from telethon.errors import (
AuthKeyError,
ChannelInvalidError,
ChannelPrivateError,
UnauthorizedError,
)
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.patched import Message, MessageService from telethon.tl.patched import Message, MessageService
@@ -145,6 +150,10 @@ class Bot(AbstractUser):
await self.post_login() await self.post_login()
return self return self
async def on_signed_out(self, err: UnauthorizedError | AuthKeyError) -> None:
self.log.fatal("Relay bot got signed out, crashing bridge", exc_info=err)
self.bridge.manual_stop(51)
async def post_login(self) -> None: async def post_login(self) -> None:
await self.init_permissions() await self.init_permissions()
info = await self.client.get_me() info = await self.client.get_me()
+3 -3
View File
@@ -24,7 +24,7 @@ from asyncpg import Record
from attr import dataclass from attr import dataclass
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.async_db import Database from mautrix.util.async_db import Connection, Database
from ..types import TelegramID from ..types import TelegramID
@@ -169,8 +169,8 @@ class Backfill:
) )
@classmethod @classmethod
async def delete_all(cls, user_mxid: UserID) -> None: async def delete_all(cls, user_mxid: UserID, conn: Connection | None = None) -> None:
await cls.db.execute("DELETE FROM backfill_queue WHERE user_mxid=$1", user_mxid) await (conn or cls.db).execute("DELETE FROM backfill_queue WHERE user_mxid=$1", user_mxid)
@classmethod @classmethod
async def delete_for_portal(cls, tgid: int, tg_receiver: int) -> None: async def delete_for_portal(cls, tgid: int, tg_receiver: int) -> None:
+18 -3
View File
@@ -21,9 +21,10 @@ from asyncpg import Record
from attr import dataclass from attr import dataclass
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.async_db import Database, Scheme from mautrix.util.async_db import Connection, Database, Scheme
from ..types import TelegramID from ..types import TelegramID
from .backfill_queue import Backfill
fake_db = Database.create("") if TYPE_CHECKING else None fake_db = Database.create("") if TYPE_CHECKING else None
@@ -73,6 +74,20 @@ class User:
async def delete(self) -> None: async def delete(self) -> None:
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid) await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
async def remove_tgid(self) -> None:
async with self.db.acquire() as conn, conn.transaction():
if self.tgid:
await conn.execute('DELETE FROM contact WHERE "user"=$1', self.tgid)
await conn.execute('DELETE FROM user_portal WHERE "user"=$1', self.tgid)
await Backfill.delete_all(self.mxid, conn=conn)
self.tgid = None
self.tg_username = None
self.tg_phone = None
self.is_bot = False
self.is_premium = False
self.saved_contacts = 0
await self.save(conn=conn)
@property @property
def _values(self): def _values(self):
return ( return (
@@ -85,13 +100,13 @@ class User:
self.saved_contacts, self.saved_contacts,
) )
async def save(self) -> None: async def save(self, conn: Connection | None = None) -> None:
q = """ q = """
UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, is_premium=$6, UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, is_premium=$6,
saved_contacts=$7 saved_contacts=$7
WHERE mxid=$1 WHERE mxid=$1
""" """
await self.db.execute(q, *self._values) await (conn or self.db).execute(q, *self._values)
async def insert(self) -> None: async def insert(self) -> None:
q = """ q = """
+51 -23
View File
@@ -22,6 +22,7 @@ import time
from telethon.errors import ( from telethon.errors import (
AuthKeyDuplicatedError, AuthKeyDuplicatedError,
AuthKeyError,
RPCError, RPCError,
TakeoutInitDelayError, TakeoutInitDelayError,
UnauthorizedError, UnauthorizedError,
@@ -207,17 +208,30 @@ class User(DBUser, AbstractUser, BaseUser):
async with self._ensure_started_lock: async with self._ensure_started_lock:
return cast(User, await super().ensure_started(even_if_no_session)) return cast(User, await super().ensure_started(even_if_no_session))
async def on_signed_out(self, err: UnauthorizedError | AuthKeyError) -> None:
error_code = "tg-auth-error"
if isinstance(err, AuthKeyDuplicatedError):
error_code = "tg-auth-key-duplicated"
message = None
else:
message = str(err)
self.log.warning(f"User got signed out with {err}, deleting data...")
try:
await self.log_out(
state=BridgeStateEvent.BAD_CREDENTIALS,
error=error_code,
message=message,
delete=False,
)
except Exception:
self.log.exception("Error handling external logout")
async def start(self, delete_unless_authenticated: bool = False) -> User: async def start(self, delete_unless_authenticated: bool = False) -> User:
try: try:
await super().start() await super().start()
except AuthKeyDuplicatedError: except AuthKeyDuplicatedError as e:
self.log.warning("Got AuthKeyDuplicatedError in start()") self.log.warning("Got AuthKeyDuplicatedError in start()")
await self.push_bridge_state( await self.on_signed_out(e)
BridgeStateEvent.BAD_CREDENTIALS, error="tg-auth-key-duplicated"
)
await self.client.disconnect()
await self.client.session.delete()
self.client = None
if not delete_unless_authenticated: if not delete_unless_authenticated:
# The caller wants the client to be connected, so restart the connection. # The caller wants the client to be connected, so restart the connection.
await super().start() await super().start()
@@ -237,12 +251,7 @@ class User(DBUser, AbstractUser, BaseUser):
if delete_unless_authenticated or self.tgid: if delete_unless_authenticated or self.tgid:
self.log.error(f"Authorization error in start(): {type(e)}: {e}") self.log.error(f"Authorization error in start(): {type(e)}: {e}")
if self.tgid: if self.tgid:
await self.push_bridge_state( await self.on_signed_out(e)
BridgeStateEvent.BAD_CREDENTIALS,
error="tg-auth-error",
message=str(e),
ttl=3600,
)
except RPCError as e: except RPCError as e:
self.log.error(f"Unknown RPC error in start(): {type(e)}: {e}") self.log.error(f"Unknown RPC error in start(): {type(e)}: {e}")
if self.tgid: if self.tgid:
@@ -253,7 +262,7 @@ class User(DBUser, AbstractUser, BaseUser):
asyncio.create_task(self.post_login()) asyncio.create_task(self.post_login())
return self return self
# Not authenticated, delete data if necessary # Not authenticated, delete data if necessary
if delete_unless_authenticated: if delete_unless_authenticated and self.client is not None:
self.log.debug(f"Unauthenticated user {self.name} start()ed, deleting session...") self.log.debug(f"Unauthenticated user {self.name} start()ed, deleting session...")
await self.client.disconnect() await self.client.disconnect()
await self.client.session.delete() await self.client.session.delete()
@@ -567,7 +576,14 @@ class User(DBUser, AbstractUser, BaseUser):
except MatrixRequestError: except MatrixRequestError:
pass pass
async def log_out(self) -> bool: async def log_out(
self,
delete: bool = True,
do_logout: bool = True,
state: BridgeStateEvent = BridgeStateEvent.LOGGED_OUT,
error: str | None = None,
message: str | None = None,
) -> bool:
puppet = await pu.Puppet.get_by_tgid(self.tgid) puppet = await pu.Puppet.get_by_tgid(self.tgid)
if puppet.is_real_user: if puppet.is_real_user:
await puppet.switch_mxid(None, None) await puppet.switch_mxid(None, None)
@@ -575,19 +591,31 @@ class User(DBUser, AbstractUser, BaseUser):
await self.kick_from_portals() await self.kick_from_portals()
except Exception: except Exception:
self.log.exception("Failed to kick user from portals on logout") self.log.exception("Failed to kick user from portals on logout")
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
if self.tgid: if self.tgid:
try: try:
del self.by_tgid[self.tgid] del self.by_tgid[self.tgid]
except KeyError: except KeyError:
pass pass
self.tgid = None ok = False
ok = await self.client.log_out() if self.client is not None:
sess = self.client.session sess = self.client.session
await self.stop() # Try to send a logout request. If it succeeds, this also disconnects the client and
await sess.delete() # deletes the session, but we do those again later just to be safe.
await self.delete() if do_logout:
self.by_mxid.pop(self.mxid, None) ok = await self.client.log_out()
# Force-disconnect the client and set it to None
await self.stop()
await sess.delete()
# TODO send a management room notice for non-manual logouts?
await self.push_bridge_state(state, error=error, message=message)
if delete:
await self.delete()
self.by_mxid.pop(self.mxid, None)
self.log.info("User deleted")
else:
await self.remove_tgid()
self.log.info("User telegram ID cleared")
self._track_metric(METRIC_LOGGED_IN, False) self._track_metric(METRIC_LOGGED_IN, False)
return ok return ok
+1 -1
View File
@@ -5,7 +5,7 @@ aiohttp>=3,<4
yarl>=1,<2 yarl>=1,<2
mautrix>=0.18.8,<0.19 mautrix>=0.18.8,<0.19
#telethon>=1.25.4,<1.27 #telethon>=1.25.4,<1.27
tulir-telethon==1.27.0a4 tulir-telethon==1.27.0a5
asyncpg>=0.20,<0.28 asyncpg>=0.20,<0.28
mako>=1,<2 mako>=1,<2
setuptools setuptools