Finish moving portals and users to SQLAlchemy Core

This commit is contained in:
Tulir Asokan
2019-02-12 14:42:03 +02:00
parent 53489e7356
commit cf847d3b8e
3 changed files with 118 additions and 40 deletions
+15 -16
View File
@@ -48,9 +48,9 @@ class User(AbstractUser):
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, phone: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None,
db_contacts: Optional[Iterable[TelegramID]] = None,
saved_contacts: int = 0, is_bot: bool = False,
db_portals: Optional[List[DBPortal]] = None,
db_portals: Optional[Iterable[Tuple[TelegramID, TelegramID]]] = None,
db_instance: Optional[DBUser] = None) -> None:
super().__init__()
self.mxid = mxid # type: MatrixUserID
@@ -60,9 +60,9 @@ class User(AbstractUser):
self.phone = phone # type: str
self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact]
self.portals = {} # type: Dict[Tuple[int, int], po.Portal]
self.db_portals = db_portals or [] # type: List[DBPortal]
self.db_contacts = db_contacts
self.portals = {} # type: Dict[Tuple[TelegramID, TelegramID], po.Portal]
self.db_portals = db_portals or []
self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: Dict
@@ -101,23 +101,22 @@ class User(AbstractUser):
return self.displayname
@property
def db_contacts(self) -> Iterable[DBContact]:
return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
def db_contacts(self) -> Iterable[TelegramID]:
return (puppet.id for puppet in self.contacts)
@db_contacts.setter
def db_contacts(self, contacts: Iterable[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
def db_contacts(self, contacts: Iterable[TelegramID]) -> None:
self.contacts = [pu.Puppet.get(entry) for entry in contacts] if contacts else []
@property
def db_portals(self) -> Iterable[DBPortal]:
return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
def db_portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
return (portal.tgid_full for portal in self.portals.values() if not portal.deleted)
@db_portals.setter
def db_portals(self, portals: Iterable[DBPortal]) -> None:
def db_portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
portal.tg_receiver)
for portal in portals
tgid_full: po.Portal.get_by_tgid(*tgid_full)
for tgid_full in portals
} if portals else {}
# region Database conversion
@@ -398,5 +397,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
global config
config = context.config
users = [User.from_db(user) for user in DBUser.get_all()]
users = [User.from_db(user) for user in DBUser.all()]
return [user.ensure_started() for user in users if user.tgid]