Add missing type hints and fix most type errors except for Optionals.

This commit is contained in:
Kai A. Hiller
2018-08-09 02:19:55 +02:00
parent 01e153662e
commit 0f8009b1e9
26 changed files with 505 additions and 384 deletions
+106 -80
View File
@@ -14,17 +14,19 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Awaitable, Pattern, Dict, List, TYPE_CHECKING
from typing import Awaitable, Coroutine, Dict, List, NewType, Optional, Pattern, TYPE_CHECKING
from difflib import SequenceMatcher
import re
import logging
import asyncio
from enum import Enum
from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto
from telethon.tl.types import UserProfilePhoto, User, FileLocation
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
from .types import MatrixUserId, TelegramId
from .db import Puppet as DBPuppet
from . import util
@@ -32,6 +34,11 @@ if TYPE_CHECKING:
from .matrix import MatrixHandler
from .config import Config
from .context import Context
from . import user as u
from .abstract_user import AbstractUser
PuppetError = Enum('PuppetError', 'Success OnlyLoginSelf InvalidAccessToken')
config = None # type: Config
@@ -45,85 +52,98 @@ class Puppet:
mxid_regex = None # type: Pattern
username_template = None # type: str
hs_domain = None # type: str
cache = {} # type: Dict[str, Puppet]
cache = {} # type: Dict[TelegramId, Puppet]
by_custom_mxid = {} # type: Dict[str, Puppet]
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
is_registered=False, db_instance=None) -> None:
self.id = id
self.access_token = access_token
self.custom_mxid = custom_mxid
self.is_real_user = self.custom_mxid and self.access_token
self.default_mxid = self.get_mxid_from_id(self.id)
self.mxid = self.custom_mxid or self.default_mxid
def __init__(self,
id: TelegramId,
access_token: Optional[str] = None,
custom_mxid: Optional[MatrixUserId] = None,
username: Optional[str] = None,
displayname: Optional[str] = None,
displayname_source: Optional[TelegramId] = None,
photo_id: Optional[str] = None,
is_bot: bool = False,
is_registered: bool = False,
db_instance: Optional[DBPuppet] = None) -> None:
self.id = id # type: TelegramId
self.access_token = access_token # type: Optional[str]
self.custom_mxid = custom_mxid # type: Optional[MatrixUserId]
self.default_mxid = self.get_mxid_from_id(self.id) # type: MatrixUserId
self.username = username
self.displayname = displayname
self.displayname_source = displayname_source
self.photo_id = photo_id
self.is_bot = is_bot
self.is_registered = is_registered
self._db_instance = db_instance
self.username = username # type: Optional[str]
self.displayname = displayname # type: Optional[str]
self.displayname_source = displayname_source # type: Optional[TelegramId]
self.photo_id = photo_id # type: Optional[str]
self.is_bot = is_bot # type: bool
self.is_registered = is_registered # type: bool
self._db_instance = db_instance # type: Optional[DBPuppet]
self.default_mxid_intent = self.az.intent.user(self.default_mxid)
self.intent = None # type: IntentAPI
self.refresh_intents()
self.intent = self._fresh_intent() # type: IntentAPI
self.cache[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
@property
def tgid(self) -> None:
def mxid(self):
return self.custom_mxid or self.default_mxid
@property
def tgid(self) -> TelegramId:
return self.id
@property
def is_real_user(self) -> bool:
""" Is True when the puppet is a real Matrix user. """
return bool(self.custom_mxid and self.access_token)
@staticmethod
async def is_logged_in() -> None:
async def is_logged_in() -> bool:
""" Is True if the puppet is logged in. """
return True
# region Custom puppet management
def refresh_intents(self) -> None:
self.is_real_user = self.custom_mxid and self.access_token
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
def _fresh_intent(self) -> IntentAPI:
return (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
async def switch_mxid(self, access_token, mxid) -> None:
async def switch_mxid(self, access_token: str, mxid: MatrixUserId) -> PuppetError:
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
self.refresh_intents()
self.intent = self._fresh_intent()
err = await self.init_custom_mxid()
if err != 0:
if err != PuppetError.Success:
return err
try:
del self.by_custom_mxid[prev_mxid]
del self.by_custom_mxid[prev_mxid] # type: ignore
except KeyError:
pass
self.mxid = self.custom_mxid or self.default_mxid
if self.mxid != self.default_mxid:
self.by_custom_mxid[self.mxid] = self
await self.leave_rooms_with_default_user()
self.save()
return 0
return PuppetError.Success
async def init_custom_mxid(self) -> None:
async def init_custom_mxid(self) -> PuppetError:
if not self.is_real_user:
return 0
return PuppetError.Success
mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None
self.access_token = None
self.refresh_intents()
self.intent = self._fresh_intent()
if mxid != self.custom_mxid:
return 2
return 1
return PuppetError.OnlyLoginSelf
return PuppetError.InvalidAccessToken
if config["bridge.sync_with_custom_puppets"]:
asyncio.ensure_future(self.sync(), loop=self.loop)
return 0
return PuppetError.Success
async def leave_rooms_with_default_user(self) -> None:
for room_id in await self.default_mxid_intent.get_joined_rooms():
@@ -159,7 +179,7 @@ class Puppet:
},
})
def filter_events(self, events) -> None:
def filter_events(self, events: List[Dict]) -> List:
new_events = []
for event in events:
evt_type = event.get("type", None)
@@ -186,18 +206,18 @@ class Puppet:
new_events.append(event)
return new_events
def handle_sync(self, presence, ephemeral) -> None:
presence = [self.mx.try_handle_event(event) for event in presence]
def handle_sync(self, presence: List, ephemeral: Dict) -> None:
presence_events = [self.mx.try_handle_event(event) for event in presence]
for room_id, events in ephemeral.items():
for event in events:
event["room_id"] = room_id
ephemeral = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
ephemeral_events = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in self.filter_events(events)]
events = ephemeral + presence
events = ephemeral_events + presence_events # List[Callable[[int], Awaitable[None]]]
coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop)
@@ -220,13 +240,14 @@ class Puppet:
while access_token_at_start == self.access_token:
try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch,
set_presence="offline")
set_presence="offline") # type: Dict
errors = 0
if next_batch is not None:
presence = sync_resp.get("presence", {}).get("events", [])
presence = sync_resp.get("presence", {}).get("events", []) # type: List
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data
in sync_resp.get("rooms", {}).get("join", {}).items()}
in sync_resp.get("rooms", {}).get("join", {}).items()
} # type: Dict
self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None)
except MatrixRequestError as e:
@@ -241,19 +262,19 @@ class Puppet:
# region DB conversion
@property
def db_instance(self) -> None:
def db_instance(self) -> DBPuppet:
if not self._db_instance:
self._db_instance = self.new_db_instance()
return self._db_instance
def new_db_instance(self) -> None:
def new_db_instance(self) -> DBPuppet:
return DBPuppet(id=self.id, access_token=self.access_token, custom_mxid=self.custom_mxid,
username=self.username, displayname=self.displayname,
displayname_source=self.displayname_source, photo_id=self.photo_id,
is_bot=self.is_bot, matrix_registered=self.is_registered)
@classmethod
def from_db(cls, db_puppet) -> None:
def from_db(cls, db_puppet: DBPuppet) -> 'Puppet':
return Puppet(db_puppet.id, db_puppet.access_token, db_puppet.custom_mxid,
db_puppet.username, db_puppet.displayname, db_puppet.displayname_source,
db_puppet.photo_id, db_puppet.is_bot, db_puppet.matrix_registered,
@@ -272,16 +293,16 @@ class Puppet:
# endregion
# region Info updating
def similarity(self, query) -> None:
def similarity(self, query: str) -> int:
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
if self.displayname else 0)
similarity = max(username_similarity, displayname_similarity)
return round(similarity * 1000) / 10
return int(round(similarity * 1000) / 10)
@staticmethod
def get_displayname(info, enable_format=True) -> None:
def get_displayname(info: User, enable_format: bool = True) -> str:
data = {
"phone number": info.phone if hasattr(info, "phone") else None,
"username": info.username,
@@ -308,7 +329,7 @@ class Puppet:
return config.get("bridge.displayname_template", "{displayname} (Telegram)").format(
displayname=name)
async def update_info(self, source, info) -> None:
async def update_info(self, source: 'AbstractUser', info: User) -> None:
changed = False
if self.username != info.username:
self.username = info.username
@@ -323,24 +344,26 @@ class Puppet:
if changed:
self.save()
async def update_displayname(self, source, info) -> None:
async def update_displayname(self, source: 'AbstractUser', info: User) -> bool:
ignore_source = (not source.is_relaybot
and self.displayname_source is not None
and self.displayname_source != source.tgid)
if ignore_source:
return
return False
displayname = self.get_displayname(info)
if displayname != self.displayname:
await self.default_mxid_intent.set_display_name(displayname)
self.displayname = displayname
self.displayname_source = source.tgid
self.displayname_source = TelegramId(source.tgid)
return True
elif source.is_relaybot or self.displayname_source is None:
self.displayname_source = source.tgid
self.displayname_source = TelegramId(source.tgid)
return True
else:
return False
async def update_avatar(self, source, photo) -> None:
async def update_avatar(self, source: 'AbstractUser', photo: FileLocation) -> bool:
photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id:
file = await util.transfer_file_to_matrix(self.db, source.client,
@@ -355,7 +378,7 @@ class Puppet:
# region Getters
@classmethod
def get(cls, tgid, create=True) -> "Optional[Puppet]":
def get(cls, tgid: TelegramId, create: bool = True) -> Optional['Puppet']:
try:
return cls.cache[tgid]
except KeyError:
@@ -374,12 +397,15 @@ class Puppet:
return None
@classmethod
def get_by_mxid(cls, mxid, create=True) -> "Optional[Puppet]":
def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['Puppet']:
tgid = cls.get_id_from_mxid(mxid)
return cls.get(tgid, create) if tgid else None
if tgid:
return cls.get(tgid, create)
return None
@classmethod
def get_by_custom_mxid(cls, mxid) -> None:
def get_by_custom_mxid(cls, mxid: MatrixUserId) -> Optional['Puppet']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -396,25 +422,25 @@ class Puppet:
return None
@classmethod
def get_all_with_custom_mxid(cls) -> None:
def get_all_with_custom_mxid(cls) -> List['Puppet']:
return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod
def get_id_from_mxid(cls, mxid) -> None:
def get_id_from_mxid(cls, mxid: MatrixUserId) -> Optional[TelegramId]:
match = cls.mxid_regex.match(mxid)
if match:
return int(match.group(1))
return TelegramId(int(match.group(1)))
return None
@classmethod
def get_mxid_from_id(cls, tgid) -> None:
return f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}"
def get_mxid_from_id(cls, tgid: TelegramId) -> MatrixUserId:
return MatrixUserId(f"@{cls.username_template.format(userid=tgid)}:{cls.hs_domain}")
@classmethod
def find_by_username(cls, username) -> "Optional[Puppet]":
def find_by_username(cls, username: str) -> Optional['Puppet']:
if not username:
return None
@@ -422,14 +448,14 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower():
return puppet
puppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if puppet:
return cls.from_db(puppet)
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none()
if dbpuppet:
return cls.from_db(dbpuppet)
return None
@classmethod
def find_by_displayname(cls, displayname) -> "Optional[Puppet]":
def find_by_displayname(cls, displayname: str) -> Optional['Puppet']:
if not displayname:
return None
@@ -437,17 +463,17 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname:
return puppet
puppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if puppet:
return cls.from_db(puppet)
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none()
if dbpuppet:
return cls.from_db(dbpuppet)
return None
# endregion
def init(context: "Context") -> List[Awaitable[int]]:
def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context
Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"]