Add missing type hints and fix most type errors except for Optionals.
This commit is contained in:
+26
-22
@@ -14,7 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Awaitable, Dict, List, Match, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Coroutine, Dict, List, Match, Optional, Tuple, cast, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
@@ -28,6 +28,7 @@ from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
|
||||
from telethon.tl.functions.account import UpdateStatusRequest
|
||||
from mautrix_appservice import MatrixRequestError
|
||||
|
||||
from .types import MatrixUserId, TelegramId
|
||||
from .db import User as DBUser, Contact as DBContact, Portal as DBPortal
|
||||
from .abstract_user import AbstractUser
|
||||
from . import portal as po, puppet as pu
|
||||
@@ -46,23 +47,23 @@ class User(AbstractUser):
|
||||
by_mxid = {} # type: Dict[str, User]
|
||||
by_tgid = {} # type: Dict[int, User]
|
||||
|
||||
def __init__(self, mxid: str, tgid: Optional[int] = None, username: Optional[str] = None,
|
||||
db_contacts: Optional[List[DBContact]] = None, saved_contacts: int = 0,
|
||||
is_bot: bool = False, db_portals: Optional[List[DBPortal]] = None,
|
||||
def __init__(self, mxid: MatrixUserId, tgid: Optional[TelegramId] = None,
|
||||
username: Optional[str] = None, db_contacts: Optional[List[DBContact]] = None,
|
||||
saved_contacts: int = 0, is_bot: bool = False, db_portals: List[DBPortal] = [],
|
||||
db_instance: Optional[DBUser] = None) -> None:
|
||||
super().__init__()
|
||||
self.mxid = mxid # type: str
|
||||
self.tgid = tgid # type: int
|
||||
self.mxid = mxid # type: MatrixUserId
|
||||
self.tgid = tgid # type: TelegramId
|
||||
self.is_bot = is_bot # type: bool
|
||||
self.username = username # type: str
|
||||
self.contacts = [] # type: List[pu.Puppet]
|
||||
self.saved_contacts = saved_contacts # type: int
|
||||
self.db_contacts = db_contacts # type: List[DBContact]
|
||||
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
|
||||
self.db_portals = db_portals # type: List[DBPortal]
|
||||
self._db_instance = db_instance # type: DBUser
|
||||
self.db_portals = db_portals or [] # type: List[DBPortal]
|
||||
self._db_instance = db_instance # type: Optional[DBUser]
|
||||
|
||||
self.command_status = None # type: dict
|
||||
self.command_status = None # type: Dict
|
||||
|
||||
(self.relaybot_whitelisted,
|
||||
self.whitelisted,
|
||||
@@ -169,9 +170,9 @@ class User(AbstractUser):
|
||||
except Exception:
|
||||
self.log.exception("Failed to run post-login functions for %s", self.mxid)
|
||||
|
||||
async def update(self, update: TypeUpdate) -> None:
|
||||
async def update(self, update: TypeUpdate) -> bool:
|
||||
if not self.is_bot:
|
||||
return
|
||||
return False
|
||||
|
||||
if isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
|
||||
message = update.message
|
||||
@@ -185,19 +186,22 @@ class User(AbstractUser):
|
||||
elif isinstance(update, UpdateShortMessage):
|
||||
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
|
||||
else:
|
||||
return
|
||||
return False
|
||||
|
||||
self.register_portal(portal)
|
||||
if portal:
|
||||
self.register_portal(portal)
|
||||
|
||||
return True
|
||||
|
||||
# endregion
|
||||
# region Telegram actions that need custom methods
|
||||
|
||||
def ensure_started(self, even_if_no_session: bool = False) -> "Awaitable[User]":
|
||||
return super().ensure_started(even_if_no_session)
|
||||
def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
|
||||
return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
|
||||
|
||||
def set_presence(self, online: bool = True) -> None:
|
||||
def set_presence(self, online: bool = True) -> bool:
|
||||
if self.is_bot:
|
||||
return
|
||||
return False
|
||||
return self.client(UpdateStatusRequest(offline=not online))
|
||||
|
||||
async def update_info(self, info: TLUser = None) -> None:
|
||||
@@ -215,7 +219,7 @@ class User(AbstractUser):
|
||||
if changed:
|
||||
self.save()
|
||||
|
||||
async def log_out(self) -> None:
|
||||
async def log_out(self) -> bool:
|
||||
puppet = pu.Puppet.get(self.tgid)
|
||||
if puppet.is_real_user:
|
||||
await puppet.switch_mxid(None, None)
|
||||
@@ -328,7 +332,7 @@ class User(AbstractUser):
|
||||
# region Class instance lookup
|
||||
|
||||
@classmethod
|
||||
def get_by_mxid(cls, mxid: str, create: bool=True) -> "Optional[User]":
|
||||
def get_by_mxid(cls, mxid: MatrixUserId, create: bool = True) -> Optional['User']:
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
@@ -351,7 +355,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: int) -> "Optional[User]":
|
||||
def get_by_tgid(cls, tgid: int) -> Optional['User']:
|
||||
try:
|
||||
return cls.by_tgid[tgid]
|
||||
except KeyError:
|
||||
@@ -365,7 +369,7 @@ class User(AbstractUser):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_username(cls, username: str) -> "Optional[User]":
|
||||
def find_by_username(cls, username: str) -> Optional['User']:
|
||||
if not username:
|
||||
return None
|
||||
|
||||
@@ -381,7 +385,7 @@ class User(AbstractUser):
|
||||
# endregion
|
||||
|
||||
|
||||
def init(context: "Context") -> List[Awaitable[User]]:
|
||||
def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
|
||||
global config
|
||||
config = context.config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user