Add missing type hints and fix most type errors except for Optionals.
This commit is contained in:
+106
-80
@@ -14,17 +14,19 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
|
||||
from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
|
||||
from difflib import SequenceMatcher
|
||||
import re
|
||||
import logging
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import orm
|
||||
|
||||
from telethon.tl.types import UserProfilePhoto
|
||||
from telethon.tl.types import UserProfilePhoto, User, FileLocation
|
||||
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
|
||||
|
||||
from .types import MatrixUserId, TelegramId
|
||||
from .db import Puppet as DBPuppet
|
||||
from . import util
|
||||
|
||||
@@ -32,6 +34,11 @@ if TYPE_CHECKING:
|
||||
from .matrix import MatrixHandler
|
||||
from .config import Config
|
||||
from .context import Context
|
||||
from . import user as u
|
||||
from .abstract_user import AbstractUser
|
||||
|
||||
|
||||
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
|
||||
|
||||
config = None # type: Config
|
||||
|
||||
@@ -45,85 +52,98 @@ class Puppet:
|
||||
mxid_regex = None # type: Pattern
|
||||
username_template = None # type: str
|
||||
hs_domain = None # type: str
|
||||
cache = {} # type: Dict[str, Puppet]
|
||||
cache = {} # type: Dict[TelegramId, Puppet]
|
||||
by_custom_mxid = {} # type: Dict[str, Puppet]
|
||||
|
||||
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
||||
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
||||
is_registered=False, db_instance=None) -> None:
|
||||
self.id = id
|
||||
self.access_token = access_token
|
||||
self.custom_mxid = custom_mxid
|
||||
self.is_real_user = self.custom_mxid and self.access_token
|
||||
self.default_mxid = self.get_mxid_from_id(self.id)
|
||||
self.mxid = self.custom_mxid or self.default_mxid
|
||||
def __init__(self,
|
||||
id: TelegramId,
|
||||
access_token: Optional[str] = None,
|
||||
custom_mxid: Optional[MatrixUserId] = None,
|
||||
username: Optional[str] = None,
|
||||
displayname: Optional[str] = None,
|
||||
displayname_source: Optional[TelegramId] = None,
|
||||
photo_id: Optional[str] = None,
|
||||
is_bot: bool = False,
|
||||
is_registered: bool = False,
|
||||
db_instance: Optional[DBPuppet] = None) -> None:
|
||||
self.id = id # type: TelegramId
|
||||
self.access_token = access_token # type: Optional[str]
|
||||
self.custom_mxid = custom_mxid # type: Optional[MatrixUserId]
|
||||
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserId
|
||||
|
||||
self.username = username
|
||||
self.displayname = displayname
|
||||
self.displayname_source = displayname_source
|
||||
self.photo_id = photo_id
|
||||
self.is_bot = is_bot
|
||||
self.is_registered = is_registered
|
||||
self._db_instance = db_instance
|
||||
self.username = username # type: Optional[str]
|
||||
self.displayname = displayname # type: Optional[str]
|
||||
self.displayname_source = displayname_source # type: Optional[TelegramId]
|
||||
self.photo_id = photo_id # type: Optional[str]
|
||||
self.is_bot = is_bot # type: bool
|
||||
self.is_registered = is_registered # type: bool
|
||||
self._db_instance = db_instance # type: Optional[DBPuppet]
|
||||
|
||||
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
|
||||
self.intent = None # type: IntentAPI
|
||||
self.refresh_intents()
|
||||
self.intent = self._fresh_intent() # type: IntentAPI
|
||||
|
||||
self.cache[id] = self
|
||||
if self.custom_mxid:
|
||||
self.by_custom_mxid[self.custom_mxid] = self
|
||||
|
||||
@property
|
||||
def tgid(self) -> None:
|
||||
def mxid(self):
|
||||
return self.custom_mxid or self.default_mxid
|
||||
|
||||
@property
|
||||
def tgid(self) -> TelegramId:
|
||||
return self.id
|
||||
|
||||
@property
|
||||
def is_real_user(self) -> bool:
|
||||
""" Is True when the puppet is a real Matrix user. """
|
||||
return bool(self.custom_mxid and self.access_token)
|
||||
|
||||
@staticmethod
|
||||
async def is_logged_in() -> None:
|
||||
async def is_logged_in() -> bool:
|
||||
""" Is True if the puppet is logged in. """
|
||||
return True
|
||||
|
||||
# region Custom puppet management
|
||||
def refresh_intents(self) -> None:
|
||||
self.is_real_user = self.custom_mxid and self.access_token
|
||||
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
|
||||
if self.is_real_user else self.default_mxid_intent)
|
||||
def _fresh_intent(self) -> IntentAPI:
|
||||
return (self.az.intent.user(self.custom_mxid, self.access_token)
|
||||
if self.is_real_user else self.default_mxid_intent)
|
||||
|
||||
async def switch_mxid(self, access_token, mxid) -> None:
|
||||
async def switch_mxid(self, access_token: str, mxid: MatrixUserId) -> PuppetError:
|
||||
prev_mxid = self.custom_mxid
|
||||
self.custom_mxid = mxid
|
||||
self.access_token = access_token
|
||||
self.refresh_intents()
|
||||
self.intent = self._fresh_intent()
|
||||
|
||||
err = await self.init_custom_mxid()
|
||||
if err != 0:
|
||||
if err != PuppetError.Success:
|
||||
return err
|
||||
|
||||
try:
|
||||
del self.by_custom_mxid[prev_mxid]
|
||||
del self.by_custom_mxid[prev_mxid] # type: ignore
|
||||
except KeyError:
|
||||
pass
|
||||
self.mxid = self.custom_mxid or self.default_mxid
|
||||
if self.mxid != self.default_mxid:
|
||||
self.by_custom_mxid[self.mxid] = self
|
||||
await self.leave_rooms_with_default_user()
|
||||
self.save()
|
||||
return 0
|
||||
return PuppetError.Success
|
||||
|
||||
async def init_custom_mxid(self) -> None:
|
||||
async def init_custom_mxid(self) -> PuppetError:
|
||||
if not self.is_real_user:
|
||||
return 0
|
||||
return PuppetError.Success
|
||||
|
||||
mxid = await self.intent.whoami()
|
||||
if not mxid or mxid != self.custom_mxid:
|
||||
self.custom_mxid = None
|
||||
self.access_token = None
|
||||
self.refresh_intents()
|
||||
self.intent = self._fresh_intent()
|
||||
if mxid != self.custom_mxid:
|
||||
return 2
|
||||
return 1
|
||||
return PuppetError.OnlyLoginSelf
|
||||
return PuppetError.InvalidAccessToken
|
||||
if config["bridge.sync_with_custom_puppets"]:
|
||||
asyncio.ensure_future(self.sync(), loop=self.loop)
|
||||
return 0
|
||||
return PuppetError.Success
|
||||
|
||||
async def leave_rooms_with_default_user(self) -> None:
|
||||
for room_id in await self.default_mxid_intent.get_joined_rooms():
|
||||
@@ -159,7 +179,7 @@ class Puppet:
|
||||
},
|
||||
})
|
||||
|
||||
def filter_events(self, events) -> None:
|
||||
def filter_events(self, events: List[Dict]) -> List:
|
||||
new_events = []
|
||||
for event in events:
|
||||
evt_type = event.get("type", None)
|
||||
@@ -186,18 +206,18 @@ class Puppet:
|
||||
new_events.append(event)
|
||||
return new_events
|
||||
|
||||
def handle_sync(self, presence, ephemeral) -> None:
|
||||
presence = [self.mx.try_handle_event(event) for event in presence]
|
||||
def handle_sync(self, presence: List, ephemeral: Dict) -> None:
|
||||
presence_events = [self.mx.try_handle_event(event) for event in presence]
|
||||
|
||||
for room_id, events in ephemeral.items():
|
||||
for event in events:
|
||||
event["room_id"] = room_id
|
||||
|
||||
ephemeral = [self.mx.try_handle_event(event)
|
||||
for events in ephemeral.values()
|
||||
for event in self.filter_events(events)]
|
||||
ephemeral_events = [self.mx.try_handle_event(event)
|
||||
for events in ephemeral.values()
|
||||
for event in self.filter_events(events)]
|
||||
|
||||
events = ephemeral + presence
|
||||
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
|
||||
coro = asyncio.gather(*events, loop=self.loop)
|
||||
asyncio.ensure_future(coro, loop=self.loop)
|
||||
|
||||
@@ -220,13 +240,14 @@ class Puppet:
|
||||
while access_token_at_start == self.access_token:
|
||||
try:
|
||||
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
|
||||
set_presence="offline")
|
||||
set_presence="offline") # type: Dict
|
||||
errors = 0
|
||||
if next_batch is not None:
|
||||
presence = sync_resp.get("presence", {}).get("events", [])
|
||||
presence = sync_resp.get("presence", {}).get("events", []) # type: List
|
||||
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
|
||||
for room, data
|
||||
in sync_resp.get("rooms", {}).get("join", {}).items()}
|
||||
in sync_resp.get("rooms", {}).get("join", {}).items()
|
||||
} # type: Dict
|
||||
self.handle_sync(presence, ephemeral)
|
||||
next_batch = sync_resp.get("next_batch", None)
|
||||
except MatrixRequestError as e:
|
||||
@@ -241,19 +262,19 @@ class Puppet:
|
||||
# region DB conversion
|
||||
|
||||
@property
|
||||
def db_instance(self) -> None:
|
||||
def db_instance(self) -> DBPuppet:
|
||||
if not self._db_instance:
|
||||
self._db_instance = self.new_db_instance()
|
||||
return self._db_instance
|
||||
|
||||
def new_db_instance(self) -> None:
|
||||
def new_db_instance(self) -> DBPuppet:
|
||||
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
|
||||
username=self.username, displayname=self.displayname,
|
||||
displayname_source=self.displayname_source, photo_id=self.photo_id,
|
||||
is_bot=self.is_bot, matrix_registered=self.is_registered)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_puppet) -> None:
|
||||
def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
|
||||
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
|
||||
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
|
||||
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
|
||||
@@ -272,16 +293,16 @@ class Puppet:
|
||||
|
||||
# endregion
|
||||
# region Info updating
|
||||
def similarity(self, query) -> None:
|
||||
def similarity(self, query: str) -> int:
|
||||
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
|
||||
if self.username else 0)
|
||||
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
|
||||
if self.displayname else 0)
|
||||
similarity = max(username_similarity, displayname_similarity)
|
||||
return round(similarity * 1000) / 10
|
||||
return int(round(similarity * 1000) / 10)
|
||||
|
||||
@staticmethod
|
||||
def get_displayname(info, enable_format=True) -> None:
|
||||
def get_displayname(info: User, enable_format: bool = True) -> str:
|
||||
data = {
|
||||
"phone number": info.phone if hasattr(info, "phone") else None,
|
||||
"username": info.username,
|
||||
@@ -308,7 +329,7 @@ class Puppet:
|
||||
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
|
||||
displayname=name)
|
||||
|
||||
async def update_info(self, source, info) -> None:
|
||||
async def update_info(self, source: 'AbstractUser', info: User) -> None:
|
||||
changed = False
|
||||
if self.username != info.username:
|
||||
self.username = info.username
|
||||
@@ -323,24 +344,26 @@ class Puppet:
|
||||
if changed:
|
||||
self.save()
|
||||
|
||||
async def update_displayname(self, source, info) -> None:
|
||||
async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
|
||||
ignore_source = (not source.is_relaybot
|
||||
and self.displayname_source is not None
|
||||
and self.displayname_source != source.tgid)
|
||||
if ignore_source:
|
||||
return
|
||||
return False
|
||||
|
||||
displayname = self.get_displayname(info)
|
||||
if displayname != self.displayname:
|
||||
await self.default_mxid_intent.set_display_name(displayname)
|
||||
self.displayname = displayname
|
||||
self.displayname_source = source.tgid
|
||||
self.displayname_source = TelegramId(source.tgid)
|
||||
return True
|
||||
elif source.is_relaybot or self.displayname_source is None:
|
||||
self.displayname_source = source.tgid
|
||||
self.displayname_source = TelegramId(source.tgid)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def update_avatar(self, source, photo) -> None:
|
||||
async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
|
||||
photo_id = f"{photo.volume_id}-{photo.local_id}"
|
||||
if self.photo_id != photo_id:
|
||||
file = await util.transfer_file_to_matrix(self.db, source.client,
|
||||
@@ -355,7 +378,7 @@ class Puppet:
|
||||
# region Getters
|
||||
|
||||
@classmethod
|
||||
def get(cls, tgid, create=True) -> "Optional[Puppet]":
|
||||
def get(cls, tgid: TelegramId, create: bool = True) -> Optional['Puppet']:
|
||||
try:
|
||||
return cls.cache[tgid]
|
||||
except KeyError:
|
||||
@@ -374,12 +397,15 @@ class Puppet:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]":
|
||||
def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['Puppet']:
|
||||
tgid = cls.get_id_from_mxid(mxid)
|
||||
return cls.get(tgid, create) if tgid else None
|
||||
if tgid:
|
||||
return cls.get(tgid, create)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_custom_mxid(cls, mxid) -> None:
|
||||
def get_by_custom_mxid(cls, mxid: MatrixUserId) -> Optional['Puppet']:
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
@@ -396,25 +422,25 @@ class Puppet:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all_with_custom_mxid(cls) -> None:
|
||||
def get_all_with_custom_mxid(cls) -> List['Puppet']:
|
||||
return [cls.by_custom_mxid[puppet.mxid]
|
||||
if puppet.custom_mxid in cls.by_custom_mxid
|
||||
else cls.from_db(puppet)
|
||||
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
|
||||
|
||||
@classmethod
|
||||
def get_id_from_mxid(cls, mxid) -> None:
|
||||
def get_id_from_mxid(cls, mxid: MatrixUserId) -> Optional[TelegramId]:
|
||||
match = cls.mxid_regex.match(mxid)
|
||||
if match:
|
||||
return int(match.group(1))
|
||||
return TelegramId(int(match.group(1)))
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_mxid_from_id(cls, tgid) -> None:
|
||||
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
|
||||
def get_mxid_from_id(cls, tgid: TelegramId) -> MatrixUserId:
|
||||
return MatrixUserId(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username) -> "Optional[Puppet]":
|
||||
def find_by_username(cls, username: str) -> Optional['Puppet']:
|
||||
if not username:
|
||||
return None
|
||||
|
||||
@@ -422,14 +448,14 @@ class Puppet:
|
||||
if puppet.username and puppet.username.lower() == username.lower():
|
||||
return puppet
|
||||
|
||||
puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
|
||||
if dbpuppet:
|
||||
return cls.from_db(dbpuppet)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_displayname(cls, displayname) -> "Optional[Puppet]":
|
||||
def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
|
||||
if not displayname:
|
||||
return None
|
||||
|
||||
@@ -437,17 +463,17 @@ class Puppet:
|
||||
if puppet.displayname and puppet.displayname == displayname:
|
||||
return puppet
|
||||
|
||||
puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
|
||||
if puppet:
|
||||
return cls.from_db(puppet)
|
||||
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
|
||||
if dbpuppet:
|
||||
return cls.from_db(dbpuppet)
|
||||
|
||||
return None
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context: "Context") -> List[Awaitable[int]]:
|
||||
def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
|
||||
global config
|
||||
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
|
||||
Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
|
||||
Puppet.mx = context.mx
|
||||
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
|
||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||
|
||||
Reference in New Issue
Block a user