Start moving portals and users to SQLAlchemy Core

This commit is contained in:
Tulir Asokan
2019-02-12 01:19:12 +02:00
parent c028e1befc
commit 53489e7356
5 changed files with 216 additions and 127 deletions
+40 -1
View File
@@ -1,2 +1,41 @@
from abc import abstractmethod
from sqlalchemy import Table
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() # type: declarative_base
class BaseBase:
db = None # type: Engine
t = None # type: Table
__table__ = None # type: Table
c = None # type: ImmutableColumnCollection
@classmethod
@abstractmethod
def _one_or_none(cls, rows: RowProxy):
pass
@classmethod
def _select_one_or_none(cls, *args):
return cls._one_or_none(cls.db.execute(cls.t.select().where(*args)))
@property
@abstractmethod
def _edit_identity(self):
pass
def update(self, **values) -> None:
self.db.execute(self.t.update()
.where(self._edit_identity)
.values(**values))
for key, value in values.items():
setattr(self, key, value)
def delete(self) -> None:
self.db.execute(self.t.delete().where(self._edit_identity))
Base = declarative_base(cls=BaseBase)
+26
View File
@@ -0,0 +1,26 @@
from abc import abstractmethod
from sqlalchemy import Table
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.ext.declarative import declarative_base
class Base(declarative_base):
db: Engine
t: Table
__table__: Table
c: ImmutableColumnCollection
@classmethod
@abstractmethod
def _one_or_none(cls, rows: RowProxy): ...
@classmethod
def _select_one_or_none(cls, *args): ...
def _edit_identity(self): ...
def update(self, **values) -> None: ...
def delete(self) -> None: ...
+120 -77
View File
@@ -15,13 +15,12 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text, Table, BigInteger, String, Boolean, Text,
and_, func, select) and_, func, select)
from sqlalchemy.engine import Engine, RowProxy from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query from sqlalchemy.orm import relationship, Query
from sqlalchemy.sql.base import ImmutableColumnCollection from typing import Dict, Optional, List, Iterable
from typing import Dict, Optional, List
import json import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
@@ -30,7 +29,6 @@ from .base import Base
class Portal(Base): class Portal(Base):
query = None # type: Query
__tablename__ = "portal" __tablename__ = "portal"
# Telegram chat information # Telegram chat information
@@ -50,11 +48,41 @@ class Portal(Base):
about = Column(String, nullable=True) about = Column(String, nullable=True)
photo_id = Column(String, nullable=True) photo_id = Column(String, nullable=True)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
try:
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
username, title, about, photo_id) = next(rows)
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
megagroup=megagroup, mxid=mxid, config=config, username=username,
title=title, about=about, photo_id=photo_id)
except StopIteration:
return None
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
@classmethod
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.username == username)
@property
def _edit_identity(self):
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
def insert(self) -> None:
self.db.execute(self.t.insert().values(
tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
megagroup=self.megagroup, mxid=self.mxid, config=self.config, username=self.username,
title=self.title, about=self.about, photo_id=self.photo_id))
class Message(Base): class Message(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "message" __tablename__ = "message"
mxid = Column(String) # type: MatrixEventID mxid = Column(String) # type: MatrixEventID
@@ -64,11 +92,11 @@ class Message(Base):
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),) __table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@staticmethod @classmethod
def _one_or_none(rows: RowProxy) -> Optional['Message']: def _one_or_none(cls, rows: RowProxy) -> Optional['Message']:
try: try:
mxid, mx_room, tgid, tg_space = next(rows) mxid, mx_room, tgid, tg_space = next(rows)
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space) return cls(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
except StopIteration: except StopIteration:
return None return None
@@ -79,9 +107,7 @@ class Message(Base):
@classmethod @classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']: def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
rows = cls.db.execute(cls.t.select() return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space))
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod @classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int: def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
@@ -96,9 +122,9 @@ class Message(Base):
@classmethod @classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
) -> Optional['Message']: ) -> Optional['Message']:
rows = cls.db.execute(cls.t.select().where( return cls._select_one_or_none(and_(cls.c.mxid == mxid,
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space))) cls.c.mx_room == mx_room,
return cls._one_or_none(rows) cls.c.tg_space == tg_space))
@classmethod @classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None: def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
@@ -112,36 +138,16 @@ class Message(Base):
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room)) .where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values)) .values(**values))
def update(self, **values) -> None: @property
for key, value in values.items(): def _edit_identity(self):
setattr(self, key, value) return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)
self.update_by_tgid(self.tgid, self.tg_space, **values)
def delete(self) -> None:
self.db.execute(self.t.delete().where(
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
def insert(self) -> None: def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid, self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space)) tg_space=self.tg_space))
class UserPortal(Base):
query = None # type: Query
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
onupdate="CASCADE", ondelete="CASCADE"),)
class User(Base): class User(Base):
query = None # type: Query
__tablename__ = "user" __tablename__ = "user"
mxid = Column(String, primary_key=True) # type: MatrixUserID mxid = Column(String, primary_key=True) # type: MatrixUserID
@@ -154,11 +160,66 @@ class User(Base):
) # type: List[Contact] ) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal") portals = relationship("Portal", secondary="user_portal")
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
try:
mxid, tgid, tg_username, tg_phone, saved_contacts = next(rows)
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
except StopIteration:
return None
@classmethod
def get_all(cls) -> Iterable['User']:
rows = cls.db.execute(cls.t.select())
for row in rows:
mxid, tgid, tg_username, tg_phone, saved_contacts = row
yield cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
return cls._select_one_or_none(cls.c.tgid == tgid)
@classmethod
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['User']:
return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['User']:
return cls._select_one_or_none(cls.c.username == username)
@property
def _edit_identity(self):
return self.c.mxid == self.mxid
def insert(self) -> None:
self.db.execute(self.t.insert().values(
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
saved_contacts=self.saved_contacts))
class UserPortal(Base):
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
onupdate="CASCADE", ondelete="CASCADE"),)
class Contact(Base):
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class RoomState(Base): class RoomState(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "mx_room_state" __tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True) # type: MatrixRoomID room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -177,18 +238,17 @@ class RoomState(Base):
rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id)) rows = cls.db.execute(cls.t.select().where(cls.c.room_id == room_id))
try: try:
room_id, power_levels_text = next(rows) room_id, power_levels_text = next(rows)
return RoomState(room_id=room_id, power_levels=(json.loads(power_levels_text) return cls(room_id=room_id, power_levels=(json.loads(power_levels_text)
if power_levels_text else None)) if power_levels_text else None))
except StopIteration: except StopIteration:
return None return None
def update(self) -> None: def update(self) -> None:
self.db.execute(self.t.update() return super().update(power_levels=self._power_levels_text)
.where(self.c.room_id == self.room_id)
.values(power_levels=self._power_levels_text))
def delete(self) -> None: @property
self.db.execute(self.t.delete().where(self.c.room_id == self.room_id)) def _edit_identity(self):
return self.c.room_id == self.room_id
def insert(self) -> None: def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id, self.db.execute(self.t.insert().values(room_id=self.room_id,
@@ -196,9 +256,6 @@ class RoomState(Base):
class UserProfile(Base): class UserProfile(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "mx_user_profile" __tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True) # type: MatrixRoomID room_id = Column(String, primary_key=True) # type: MatrixRoomID
@@ -220,8 +277,8 @@ class UserProfile(Base):
cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id))) cls.t.select().where(and_(cls.c.room_id == room_id, cls.c.user_id == user_id)))
try: try:
room_id, user_id, membership, displayname, avatar_url = next(rows) room_id, user_id, membership, displayname, avatar_url = next(rows)
return UserProfile(room_id=room_id, user_id=user_id, membership=membership, return cls(room_id=room_id, user_id=user_id, membership=membership,
displayname=displayname, avatar_url=avatar_url) displayname=displayname, avatar_url=avatar_url)
except StopIteration: except StopIteration:
return None return None
@@ -230,14 +287,12 @@ class UserProfile(Base):
cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id)) cls.db.execute(cls.t.delete().where(cls.c.room_id == room_id))
def update(self) -> None: def update(self) -> None:
self.db.execute(self.t.update() super().update(membership=self.membership, displayname=self.displayname,
.where(and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)) avatar_url=self.avatar_url)
.values(membership=self.membership, displayname=self.displayname,
avatar_url=self.avatar_url))
def delete(self) -> None: @property
self.db.execute(self.t.delete().where(and_(self.c.room_id == self.room_id, def _edit_identity(self):
self.c.user_id == self.user_id))) return and_(self.c.room_id == self.room_id, self.c.user_id == self.user_id)
def insert(self) -> None: def insert(self) -> None:
self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id, self.db.execute(self.t.insert().values(room_id=self.room_id, user_id=self.user_id,
@@ -246,14 +301,6 @@ class UserProfile(Base):
avatar_url=self.avatar_url)) avatar_url=self.avatar_url))
class Contact(Base):
query = None # type: Query
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
class Puppet(Base): class Puppet(Base):
query = None # type: Query query = None # type: Query
__tablename__ = "puppet" __tablename__ = "puppet"
@@ -278,9 +325,6 @@ class BotChat(Base):
class TelegramFile(Base): class TelegramFile(Base):
db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "telegram_file" __tablename__ = "telegram_file"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
@@ -302,8 +346,8 @@ class TelegramFile(Base):
thumb = None thumb = None
if thumb_id: if thumb_id:
thumb = cls.get(thumb_id) thumb = cls.get(thumb_id)
return TelegramFile(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts, return cls(id=id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb) size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
except StopIteration: except StopIteration:
return None return None
@@ -316,8 +360,7 @@ class TelegramFile(Base):
def init(db_session, db_engine) -> None: def init(db_session, db_engine) -> None:
query = db_session.query_property() query = db_session.query_property()
for table in (Portal, Message, UserPortal, User, Puppet, BotChat, TelegramFile, UserProfile, for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
RoomState):
table.query = query table.query = query
table.db = db_engine table.db = db_engine
table.t = table.__table__ table.t = table.__table__
+15 -26
View File
@@ -31,7 +31,6 @@ import json
import re import re
import magic import magic
from sqlalchemy import orm
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from telethon.tl.functions.messages import ( from telethon.tl.functions.messages import (
@@ -89,7 +88,6 @@ InviteList = Union[MatrixUserID, List[MatrixUserID]]
class Portal: class Portal:
log = logging.getLogger("mau.portal") # type: logging.Logger log = logging.getLogger("mau.portal") # type: logging.Logger
db = None # type: orm.Session
az = None # type: AppService az = None # type: AppService
bot = None # type: Bot bot = None # type: Bot
loop = None # type: asyncio.AbstractEventLoop loop = None # type: asyncio.AbstractEventLoop
@@ -1255,8 +1253,7 @@ class Portal:
self.tg_receiver = self.tgid self.tg_receiver = self.tgid
self.by_tgid[self.tgid_full] = self self.by_tgid[self.tgid_full] = self
await self.update_info(source, entity) await self.update_info(source, entity)
self.db.add(self.db_instance) self.db_instance.insert()
self.save()
if self.bot and self.bot.tgid in invites: if self.bot and self.bot.tgid in invites:
self.bot.add_chat(self.tgid, self.peer_type) self.bot.add_chat(self.tgid, self.peer_type)
@@ -1842,15 +1839,13 @@ class Portal:
del self.by_tgid[self.tgid_full] del self.by_tgid[self.tgid_full]
except KeyError: except KeyError:
pass pass
self.tgid = new_id existing = self.by_tgid[(new_id, new_id)]
self.tg_receiver = new_id
existing = self.by_tgid[self.tgid_full]
if existing: if existing:
existing.delete() existing.delete()
self.db_instance.update(tgid=new_id, tg_receiver=new_id)
self.tgid = new_id
self.tg_receiver = new_id
self.by_tgid[self.tgid_full] = self self.by_tgid[self.tgid_full] = self
self.db_instance.tgid = self.tgid
self.db_instance.tg_receiver = self.tg_receiver
self.save()
def migrate_and_save_matrix(self, new_id: MatrixRoomID) -> None: def migrate_and_save_matrix(self, new_id: MatrixRoomID) -> None:
try: try:
@@ -1858,17 +1853,13 @@ class Portal:
except KeyError: except KeyError:
pass pass
self.mxid = new_id self.mxid = new_id
self.db_instance.update(mxid=self.mxid)
self.by_mxid[self.mxid] = self self.by_mxid[self.mxid] = self
self.save()
def save(self) -> None: def save(self) -> None:
self.db_instance.mxid = self.mxid self.db_instance.update(mxid=self.mxid, username=self.username, title=self.title,
self.db_instance.username = self.username about=self.about, photo_id=self.photo_id,
self.db_instance.title = self.title config=json.dumps(self.local_config))
self.db_instance.about = self.about
self.db_instance.photo_id = self.photo_id
self.db_instance.config = json.dumps(self.local_config)
self.db.commit()
def delete(self) -> None: def delete(self) -> None:
try: try:
@@ -1880,8 +1871,7 @@ class Portal:
except KeyError: except KeyError:
pass pass
if self._db_instance: if self._db_instance:
self.db.delete(self._db_instance) self._db_instance.delete()
self.db.commit()
self.deleted = True self.deleted = True
@classmethod @classmethod
@@ -1902,7 +1892,7 @@ class Portal:
except KeyError: except KeyError:
pass pass
portal = DBPortal.query.filter(DBPortal.mxid == mxid).one_or_none() portal = DBPortal.get_by_mxid(mxid)
if portal: if portal:
return cls.from_db(portal) return cls.from_db(portal)
@@ -1924,7 +1914,7 @@ class Portal:
if portal.username and portal.username.lower() == username.lower(): if portal.username and portal.username.lower() == username.lower():
return portal return portal
dbportal = DBPortal.query.filter(DBPortal.username == username).one_or_none() dbportal = DBPortal.get_by_username(username)
if dbportal: if dbportal:
return cls.from_db(dbportal) return cls.from_db(dbportal)
@@ -1940,14 +1930,13 @@ class Portal:
except KeyError: except KeyError:
pass pass
portal = DBPortal.query.get(tgid_full) portal = DBPortal.get_by_tgid(tgid, tg_receiver)
if portal: if portal:
return cls.from_db(portal) return cls.from_db(portal)
if peer_type: if peer_type:
portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver) portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver)
cls.db.add(portal.db_instance) portal.db_instance.insert()
cls.db.commit()
return portal return portal
return None return None
@@ -1987,7 +1976,7 @@ class Portal:
def init(context: Context) -> None: def init(context: Context) -> None:
global config global config
Portal.az, Portal.db, config, Portal.loop, Portal.bot = context.core Portal.az, _, config, Portal.loop, Portal.bot = context.core
Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"] Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"]
Portal.sync_channel_members = config["bridge.sync_channel_members"] Portal.sync_channel_members = config["bridge.sync_channel_members"]
Portal.sync_matrix_state = config["bridge.sync_matrix_state"] Portal.sync_matrix_state = config["bridge.sync_matrix_state"]
+15 -23
View File
@@ -14,7 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Dict, List, Match, NewType, Optional, Tuple, TYPE_CHECKING from typing import Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, TYPE_CHECKING
import logging import logging
import asyncio import asyncio
import re import re
@@ -101,20 +101,19 @@ class User(AbstractUser):
return self.displayname return self.displayname
@property @property
def db_contacts(self) -> List[DBContact]: def db_contacts(self) -> Iterable[DBContact]:
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id)) return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts)
for puppet in self.contacts]
@db_contacts.setter @db_contacts.setter
def db_contacts(self, contacts: List[DBContact]) -> None: def db_contacts(self, contacts: Iterable[DBContact]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else [] self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else []
@property @property
def db_portals(self) -> List[DBPortal]: def db_portals(self) -> Iterable[DBPortal]:
return [portal.db_instance for portal in self.portals.values() if not portal.deleted] return (portal.db_instance for portal in self.portals.values() if not portal.deleted)
@db_portals.setter @db_portals.setter
def db_portals(self, portals: List[DBPortal]) -> None: def db_portals(self, portals: Iterable[DBPortal]) -> None:
self.portals = { self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid, (portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid,
portal.tg_receiver) portal.tg_receiver)
@@ -135,13 +134,8 @@ class User(AbstractUser):
portals=self.db_portals) portals=self.db_portals)
def save(self) -> None: def save(self) -> None:
self.db_instance.tgid = self.tgid self.db_instance.update(tgid=self.tgid, tg_username=self.username, tg_phone=self.phone,
self.db_instance.tg_username = self.username saved_contacts=self.saved_contacts)
self.db_instance.tg_phone = self.phone
self.db_instance.contacts = self.db_contacts
self.db_instance.saved_contacts = self.saved_contacts
self.db_instance.portals = self.db_portals
self.db.commit()
def delete(self) -> None: def delete(self) -> None:
try: try:
@@ -150,8 +144,7 @@ class User(AbstractUser):
except KeyError: except KeyError:
pass pass
if self._db_instance: if self._db_instance:
self.db.delete(self._db_instance) self._db_instance.delete()
self.db.commit()
@classmethod @classmethod
def from_db(cls, db_user: DBUser) -> 'User': def from_db(cls, db_user: DBUser) -> 'User':
@@ -358,15 +351,14 @@ class User(AbstractUser):
except KeyError: except KeyError:
pass pass
user = DBUser.query.get(mxid) user = DBUser.get_by_mxid(mxid)
if user: if user:
user = cls.from_db(user) user = cls.from_db(user)
return user return user
if create: if create:
user = cls(mxid) user = cls(mxid)
cls.db.add(user.db_instance) user.db_instance.insert()
cls.db.commit()
return user return user
return None return None
@@ -378,7 +370,7 @@ class User(AbstractUser):
except KeyError: except KeyError:
pass pass
user = DBUser.query.filter(DBUser.tgid == tgid).one_or_none() user = DBUser.get_by_tgid(tgid)
if user: if user:
user = cls.from_db(user) user = cls.from_db(user)
return user return user
@@ -394,7 +386,7 @@ class User(AbstractUser):
if user.username and user.username.lower() == username.lower(): if user.username and user.username.lower() == username.lower():
return user return user
puppet = DBUser.query.filter(DBUser.tg_username == username).one_or_none() puppet = DBUser.get_by_username(username)
if puppet: if puppet:
return cls.from_db(puppet) return cls.from_db(puppet)
@@ -406,5 +398,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
global config global config
config = context.config config = context.config
users = [User.from_db(user) for user in DBUser.query.all()] users = [User.from_db(user) for user in DBUser.get_all()]
return [user.ensure_started() for user in users if user.tgid] return [user.ensure_started() for user in users if user.tgid]