Handle getting logged out the same way in all cases
This commit is contained in:
@@ -22,7 +22,7 @@ import logging
|
|||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from telethon.errors import UnauthorizedError
|
from telethon.errors import AuthKeyError, UnauthorizedError
|
||||||
from telethon.network import (
|
from telethon.network import (
|
||||||
Connection,
|
Connection,
|
||||||
ConnectionTcpFull,
|
ConnectionTcpFull,
|
||||||
@@ -235,14 +235,18 @@ class AbstractUser(ABC):
|
|||||||
)
|
)
|
||||||
self.client.add_event_handler(self._update_catch)
|
self.client.add_event_handler(self._update_catch)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_signed_out(self, err: UnauthorizedError | AuthKeyError) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def _telethon_update_error_callback(self, err: Exception) -> None:
|
async def _telethon_update_error_callback(self, err: Exception) -> None:
|
||||||
|
if isinstance(err, (UnauthorizedError, AuthKeyError)):
|
||||||
|
asyncio.create_task(self.on_signed_out(err))
|
||||||
|
return
|
||||||
if self.config["telegram.exit_on_update_error"]:
|
if self.config["telegram.exit_on_update_error"]:
|
||||||
self.log.critical(f"Stopping due to update handling error {type(err).__name__}")
|
self.log.critical(f"Stopping due to update handling error {type(err).__name__}")
|
||||||
self.bridge.manual_stop(50)
|
self.bridge.manual_stop(50)
|
||||||
else:
|
else:
|
||||||
if isinstance(err, UnauthorizedError):
|
|
||||||
self.log.warning("Not recreating Telethon update loop")
|
|
||||||
return
|
|
||||||
self.log.info("Recreating Telethon update loop in 60 seconds")
|
self.log.info("Recreating Telethon update loop in 60 seconds")
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
self.log.debug("Now recreating Telethon update loop")
|
self.log.debug("Now recreating Telethon update loop")
|
||||||
|
|||||||
+10
-1
@@ -19,7 +19,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Literal
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from telethon.errors import ChannelInvalidError, ChannelPrivateError
|
from telethon.errors import (
|
||||||
|
AuthKeyError,
|
||||||
|
ChannelInvalidError,
|
||||||
|
ChannelPrivateError,
|
||||||
|
UnauthorizedError,
|
||||||
|
)
|
||||||
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
|
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
|
||||||
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
|
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
|
||||||
from telethon.tl.patched import Message, MessageService
|
from telethon.tl.patched import Message, MessageService
|
||||||
@@ -145,6 +150,10 @@ class Bot(AbstractUser):
|
|||||||
await self.post_login()
|
await self.post_login()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
async def on_signed_out(self, err: UnauthorizedError | AuthKeyError) -> None:
|
||||||
|
self.log.fatal("Relay bot got signed out, crashing bridge", exc_info=err)
|
||||||
|
self.bridge.manual_stop(51)
|
||||||
|
|
||||||
async def post_login(self) -> None:
|
async def post_login(self) -> None:
|
||||||
await self.init_permissions()
|
await self.init_permissions()
|
||||||
info = await self.client.get_me()
|
info = await self.client.get_me()
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from asyncpg import Record
|
|||||||
from attr import dataclass
|
from attr import dataclass
|
||||||
|
|
||||||
from mautrix.types import UserID
|
from mautrix.types import UserID
|
||||||
from mautrix.util.async_db import Database
|
from mautrix.util.async_db import Connection, Database
|
||||||
|
|
||||||
from ..types import TelegramID
|
from ..types import TelegramID
|
||||||
|
|
||||||
@@ -169,8 +169,8 @@ class Backfill:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_all(cls, user_mxid: UserID) -> None:
|
async def delete_all(cls, user_mxid: UserID, conn: Connection | None = None) -> None:
|
||||||
await cls.db.execute("DELETE FROM backfill_queue WHERE user_mxid=$1", user_mxid)
|
await (conn or cls.db).execute("DELETE FROM backfill_queue WHERE user_mxid=$1", user_mxid)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_for_portal(cls, tgid: int, tg_receiver: int) -> None:
|
async def delete_for_portal(cls, tgid: int, tg_receiver: int) -> None:
|
||||||
|
|||||||
@@ -21,9 +21,10 @@ from asyncpg import Record
|
|||||||
from attr import dataclass
|
from attr import dataclass
|
||||||
|
|
||||||
from mautrix.types import UserID
|
from mautrix.types import UserID
|
||||||
from mautrix.util.async_db import Database, Scheme
|
from mautrix.util.async_db import Connection, Database, Scheme
|
||||||
|
|
||||||
from ..types import TelegramID
|
from ..types import TelegramID
|
||||||
|
from .backfill_queue import Backfill
|
||||||
|
|
||||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||||
|
|
||||||
@@ -73,6 +74,20 @@ class User:
|
|||||||
async def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
|
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
|
||||||
|
|
||||||
|
async def remove_tgid(self) -> None:
|
||||||
|
async with self.db.acquire() as conn, conn.transaction():
|
||||||
|
if self.tgid:
|
||||||
|
await conn.execute('DELETE FROM contact WHERE "user"=$1', self.tgid)
|
||||||
|
await conn.execute('DELETE FROM user_portal WHERE "user"=$1', self.tgid)
|
||||||
|
await Backfill.delete_all(self.mxid, conn=conn)
|
||||||
|
self.tgid = None
|
||||||
|
self.tg_username = None
|
||||||
|
self.tg_phone = None
|
||||||
|
self.is_bot = False
|
||||||
|
self.is_premium = False
|
||||||
|
self.saved_contacts = 0
|
||||||
|
await self.save(conn=conn)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _values(self):
|
def _values(self):
|
||||||
return (
|
return (
|
||||||
@@ -85,13 +100,13 @@ class User:
|
|||||||
self.saved_contacts,
|
self.saved_contacts,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def save(self) -> None:
|
async def save(self, conn: Connection | None = None) -> None:
|
||||||
q = """
|
q = """
|
||||||
UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, is_premium=$6,
|
UPDATE "user" SET tgid=$2, tg_username=$3, tg_phone=$4, is_bot=$5, is_premium=$6,
|
||||||
saved_contacts=$7
|
saved_contacts=$7
|
||||||
WHERE mxid=$1
|
WHERE mxid=$1
|
||||||
"""
|
"""
|
||||||
await self.db.execute(q, *self._values)
|
await (conn or self.db).execute(q, *self._values)
|
||||||
|
|
||||||
async def insert(self) -> None:
|
async def insert(self) -> None:
|
||||||
q = """
|
q = """
|
||||||
|
|||||||
+51
-23
@@ -22,6 +22,7 @@ import time
|
|||||||
|
|
||||||
from telethon.errors import (
|
from telethon.errors import (
|
||||||
AuthKeyDuplicatedError,
|
AuthKeyDuplicatedError,
|
||||||
|
AuthKeyError,
|
||||||
RPCError,
|
RPCError,
|
||||||
TakeoutInitDelayError,
|
TakeoutInitDelayError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
@@ -207,17 +208,30 @@ class User(DBUser, AbstractUser, BaseUser):
|
|||||||
async with self._ensure_started_lock:
|
async with self._ensure_started_lock:
|
||||||
return cast(User, await super().ensure_started(even_if_no_session))
|
return cast(User, await super().ensure_started(even_if_no_session))
|
||||||
|
|
||||||
|
async def on_signed_out(self, err: UnauthorizedError | AuthKeyError) -> None:
|
||||||
|
error_code = "tg-auth-error"
|
||||||
|
if isinstance(err, AuthKeyDuplicatedError):
|
||||||
|
error_code = "tg-auth-key-duplicated"
|
||||||
|
message = None
|
||||||
|
else:
|
||||||
|
message = str(err)
|
||||||
|
self.log.warning(f"User got signed out with {err}, deleting data...")
|
||||||
|
try:
|
||||||
|
await self.log_out(
|
||||||
|
state=BridgeStateEvent.BAD_CREDENTIALS,
|
||||||
|
error=error_code,
|
||||||
|
message=message,
|
||||||
|
delete=False,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self.log.exception("Error handling external logout")
|
||||||
|
|
||||||
async def start(self, delete_unless_authenticated: bool = False) -> User:
|
async def start(self, delete_unless_authenticated: bool = False) -> User:
|
||||||
try:
|
try:
|
||||||
await super().start()
|
await super().start()
|
||||||
except AuthKeyDuplicatedError:
|
except AuthKeyDuplicatedError as e:
|
||||||
self.log.warning("Got AuthKeyDuplicatedError in start()")
|
self.log.warning("Got AuthKeyDuplicatedError in start()")
|
||||||
await self.push_bridge_state(
|
await self.on_signed_out(e)
|
||||||
BridgeStateEvent.BAD_CREDENTIALS, error="tg-auth-key-duplicated"
|
|
||||||
)
|
|
||||||
await self.client.disconnect()
|
|
||||||
await self.client.session.delete()
|
|
||||||
self.client = None
|
|
||||||
if not delete_unless_authenticated:
|
if not delete_unless_authenticated:
|
||||||
# The caller wants the client to be connected, so restart the connection.
|
# The caller wants the client to be connected, so restart the connection.
|
||||||
await super().start()
|
await super().start()
|
||||||
@@ -237,12 +251,7 @@ class User(DBUser, AbstractUser, BaseUser):
|
|||||||
if delete_unless_authenticated or self.tgid:
|
if delete_unless_authenticated or self.tgid:
|
||||||
self.log.error(f"Authorization error in start(): {type(e)}: {e}")
|
self.log.error(f"Authorization error in start(): {type(e)}: {e}")
|
||||||
if self.tgid:
|
if self.tgid:
|
||||||
await self.push_bridge_state(
|
await self.on_signed_out(e)
|
||||||
BridgeStateEvent.BAD_CREDENTIALS,
|
|
||||||
error="tg-auth-error",
|
|
||||||
message=str(e),
|
|
||||||
ttl=3600,
|
|
||||||
)
|
|
||||||
except RPCError as e:
|
except RPCError as e:
|
||||||
self.log.error(f"Unknown RPC error in start(): {type(e)}: {e}")
|
self.log.error(f"Unknown RPC error in start(): {type(e)}: {e}")
|
||||||
if self.tgid:
|
if self.tgid:
|
||||||
@@ -253,7 +262,7 @@ class User(DBUser, AbstractUser, BaseUser):
|
|||||||
asyncio.create_task(self.post_login())
|
asyncio.create_task(self.post_login())
|
||||||
return self
|
return self
|
||||||
# Not authenticated, delete data if necessary
|
# Not authenticated, delete data if necessary
|
||||||
if delete_unless_authenticated:
|
if delete_unless_authenticated and self.client is not None:
|
||||||
self.log.debug(f"Unauthenticated user {self.name} start()ed, deleting session...")
|
self.log.debug(f"Unauthenticated user {self.name} start()ed, deleting session...")
|
||||||
await self.client.disconnect()
|
await self.client.disconnect()
|
||||||
await self.client.session.delete()
|
await self.client.session.delete()
|
||||||
@@ -567,7 +576,14 @@ class User(DBUser, AbstractUser, BaseUser):
|
|||||||
except MatrixRequestError:
|
except MatrixRequestError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def log_out(self) -> bool:
|
async def log_out(
|
||||||
|
self,
|
||||||
|
delete: bool = True,
|
||||||
|
do_logout: bool = True,
|
||||||
|
state: BridgeStateEvent = BridgeStateEvent.LOGGED_OUT,
|
||||||
|
error: str | None = None,
|
||||||
|
message: str | None = None,
|
||||||
|
) -> bool:
|
||||||
puppet = await pu.Puppet.get_by_tgid(self.tgid)
|
puppet = await pu.Puppet.get_by_tgid(self.tgid)
|
||||||
if puppet.is_real_user:
|
if puppet.is_real_user:
|
||||||
await puppet.switch_mxid(None, None)
|
await puppet.switch_mxid(None, None)
|
||||||
@@ -575,19 +591,31 @@ class User(DBUser, AbstractUser, BaseUser):
|
|||||||
await self.kick_from_portals()
|
await self.kick_from_portals()
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to kick user from portals on logout")
|
self.log.exception("Failed to kick user from portals on logout")
|
||||||
await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT)
|
|
||||||
if self.tgid:
|
if self.tgid:
|
||||||
try:
|
try:
|
||||||
del self.by_tgid[self.tgid]
|
del self.by_tgid[self.tgid]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
self.tgid = None
|
ok = False
|
||||||
ok = await self.client.log_out()
|
if self.client is not None:
|
||||||
sess = self.client.session
|
sess = self.client.session
|
||||||
await self.stop()
|
# Try to send a logout request. If it succeeds, this also disconnects the client and
|
||||||
await sess.delete()
|
# deletes the session, but we do those again later just to be safe.
|
||||||
await self.delete()
|
if do_logout:
|
||||||
self.by_mxid.pop(self.mxid, None)
|
ok = await self.client.log_out()
|
||||||
|
# Force-disconnect the client and set it to None
|
||||||
|
await self.stop()
|
||||||
|
await sess.delete()
|
||||||
|
|
||||||
|
# TODO send a management room notice for non-manual logouts?
|
||||||
|
await self.push_bridge_state(state, error=error, message=message)
|
||||||
|
if delete:
|
||||||
|
await self.delete()
|
||||||
|
self.by_mxid.pop(self.mxid, None)
|
||||||
|
self.log.info("User deleted")
|
||||||
|
else:
|
||||||
|
await self.remove_tgid()
|
||||||
|
self.log.info("User telegram ID cleared")
|
||||||
self._track_metric(METRIC_LOGGED_IN, False)
|
self._track_metric(METRIC_LOGGED_IN, False)
|
||||||
return ok
|
return ok
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -5,7 +5,7 @@ aiohttp>=3,<4
|
|||||||
yarl>=1,<2
|
yarl>=1,<2
|
||||||
mautrix>=0.18.8,<0.19
|
mautrix>=0.18.8,<0.19
|
||||||
#telethon>=1.25.4,<1.27
|
#telethon>=1.25.4,<1.27
|
||||||
tulir-telethon==1.27.0a4
|
tulir-telethon==1.27.0a5
|
||||||
asyncpg>=0.20,<0.28
|
asyncpg>=0.20,<0.28
|
||||||
mako>=1,<2
|
mako>=1,<2
|
||||||
setuptools
|
setuptools
|
||||||
|
|||||||
Reference in New Issue
Block a user