Add missing type hints and fix most type errors except for Optionals.

This commit is contained in:
Kai A. Hiller
2018-08-09 02:19:55 +02:00
parent 01e153662e
commit 0f8009b1e9
26 changed files with 505 additions and 384 deletions
+26 -22
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Match, Optional, Tuple, TYPE_CHECKING
from typing import Coroutine, Dict, List, Match, Optional, Tuple, cast, TYPE_CHECKING
import logging
import asyncio
import re
@@ -28,6 +28,7 @@ from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from telethon.tl.functions.account import UpdateStatusRequest
from mautrix_appservice import MatrixRequestError
from .types import MatrixUserId, TelegramId
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
from .abstract_user import AbstractUser
from . import portal as po, puppet as pu
@@ -46,23 +47,23 @@ class User(AbstractUser):
by_mxid = {} # type: Dict[str, User]
by_tgid = {} # type: Dict[int, User]
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
def __init__(self, mxid: MatrixUserId, tgid: Optional[TelegramId] = None,
username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
db_instance: Optional[DBUser] = None) -> None:
super().__init__()
self.mxid = mxid # type: str
self.tgid = tgid # type: int
self.mxid = mxid # type: MatrixUserId
self.tgid = tgid # type: TelegramId
self.is_bot = is_bot # type: bool
self.username = username # type: str
self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals # type: List[DBPortal]
self._db_instance = db_instance # type: DBUser
self.db_portals = db_portals or [] # type: List[DBPortal]
self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: dict
self.command_status = None # type: Dict
(self.relaybot_whitelisted,
self.whitelisted,
@@ -169,9 +170,9 @@ class User(AbstractUser):
except Exception:
self.log.exception("Failed to run post-login functions for %s", self.mxid)
async def update(self, update: TypeUpdate) -> None:
async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot:
return
return False
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
message = update.message
@@ -185,19 +186,22 @@ class User(AbstractUser):
elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else:
return
return False
self.register_portal(portal)
if portal:
self.register_portal(portal)
return True
# endregion
# region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session)
def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
def set_presence(self, online: bool = True) -> None:
def set_presence(self, online: bool = True) -> bool:
if self.is_bot:
return
return False
return self.client(UpdateStatusRequest(offline=not online))
async def update_info(self, info: TLUser = None) -> None:
@@ -215,7 +219,7 @@ class User(AbstractUser):
if changed:
self.save()
async def log_out(self) -> None:
async def log_out(self) -> bool:
puppet = pu.Puppet.get(self.tgid)
if puppet.is_real_user:
await puppet.switch_mxid(None, None)
@@ -328,7 +332,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['User']:
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -351,7 +355,7 @@ class User(AbstractUser):
return None
@classmethod
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
def get_by_tgid(cls, tgid: int) -> Optional['User']:
try:
return cls.by_tgid[tgid]
except KeyError:
@@ -365,7 +369,7 @@ class User(AbstractUser):
return None
@classmethod
def find_by_username(cls, username: str) -> "Optional[User]":
def find_by_username(cls, username: str) -> Optional['User']:
if not username:
return None
@@ -381,7 +385,7 @@ class User(AbstractUser):
# endregion
def init(context: "Context") -> List[Awaitable[User]]:
def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
global config
config = context.config