Finish moving portals and users to SQLAlchemy Core

This commit is contained in:
Tulir Asokan
2019-02-12 14:42:03 +02:00
parent 53489e7356
commit cf847d3b8e
3 changed files with 118 additions and 40 deletions
+93 -15
View File
@@ -20,7 +20,7 @@ from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstrai
from sqlalchemy.engine.result import RowProxy from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query from sqlalchemy.orm import relationship, Query
from typing import Dict, Optional, List, Iterable from typing import Dict, Optional, List, Iterable, Tuple
import json import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
@@ -48,14 +48,18 @@ class Portal(Base):
about = Column(String, nullable=True) about = Column(String, nullable=True)
photo_id = Column(String, nullable=True) photo_id = Column(String, nullable=True)
@classmethod
def scan(cls, row) -> Optional['Portal']:
(tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about,
photo_id) = row
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup,
mxid=mxid, config=config, username=username, title=title, about=about,
photo_id=photo_id)
@classmethod @classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']: def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
try: try:
(tgid, tg_receiver, peer_type, megagroup, mxid, config, return cls.scan(next(rows))
username, title, about, photo_id) = next(rows)
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
megagroup=megagroup, mxid=mxid, config=config, username=username,
title=title, about=about, photo_id=photo_id)
except StopIteration: except StopIteration:
return None return None
@@ -155,10 +159,6 @@ class User(Base):
tg_username = Column(String, nullable=True) tg_username = Column(String, nullable=True)
tg_phone = Column(String, nullable=True) tg_phone = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0, nullable=False) saved_contacts = Column(Integer, default=0, nullable=False)
contacts = relationship("Contact", uselist=True,
cascade="save-update, merge, delete, delete-orphan"
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal")
@classmethod @classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['User']: def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
@@ -170,7 +170,7 @@ class User(Base):
return None return None
@classmethod @classmethod
def get_all(cls) -> Iterable['User']: def all(cls) -> Iterable['User']:
rows = cls.db.execute(cls.t.select()) rows = cls.db.execute(cls.t.select())
for row in rows: for row in rows:
mxid, tgid, tg_username, tg_phone, saved_contacts = row mxid, tgid, tg_username, tg_phone, saved_contacts = row
@@ -198,6 +198,36 @@ class User(Base):
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone, mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
saved_contacts=self.saved_contacts)) saved_contacts=self.saved_contacts))
@property
def contacts(self) -> Iterable[TelegramID]:
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
for row in rows:
user, contact = row
yield contact
@contacts.setter
def contacts(self, puppets: Iterable[TelegramID]) -> None:
self.db.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
self.db.execute(Contact.t.insert(), [{"user": self.tgid, "contact": tgid}
for tgid in puppets])
@property
def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid))
for row in rows:
user, portal, portal_receiver = row
yield (portal, portal_receiver)
@portals.setter
def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
self.db.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
self.db.execute(UserPortal.t.insert(),
[{
"user": self.tgid,
"portal": tgid,
"portal_receiver": tg_receiver
} for tgid, tg_receiver in portals])
class UserPortal(Base): class UserPortal(Base):
__tablename__ = "user_portal" __tablename__ = "user_portal"
@@ -302,7 +332,6 @@ class UserProfile(Base):
class Puppet(Base): class Puppet(Base):
query = None # type: Query
__tablename__ = "puppet" __tablename__ = "puppet"
id = Column(Integer, primary_key=True) # type: TelegramID id = Column(Integer, primary_key=True) # type: TelegramID
@@ -315,6 +344,55 @@ class Puppet(Base):
is_bot = Column(Boolean, nullable=True) is_bot = Column(Boolean, nullable=True)
matrix_registered = Column(Boolean, nullable=False, server_default=expression.false()) matrix_registered = Column(Boolean, nullable=False, server_default=expression.false())
@classmethod
def scan(cls, row) -> Optional['Puppet']:
(id, custom_mxid, access_token, displayname, displayname_source, username, photo_id,
is_bot, matrix_registered) = row
return cls(id=id, custom_mxid=custom_mxid, access_token=access_token,
displayname=displayname, displayname_source=displayname_source,
username=username, photo_id=photo_id, is_bot=is_bot,
matrix_registered=matrix_registered)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Puppet']:
try:
return cls.scan(next(rows))
except StopIteration:
return None
@classmethod
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None))
for row in rows:
yield cls.scan(row)
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.id == tgid)
@classmethod
def get_by_custom_mxid(cls, mxid: MatrixRoomID) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.custom_mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.username == username)
@classmethod
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.displayname == displayname)
@property
def _edit_identity(self):
return self.c.id == self.id
def insert(self) -> None:
self.db.execute(self.t.insert().values(
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
displayname=self.displayname, displayname_source=self.displayname_source,
username=self.username, photo_id=self.photo_id, is_bot=self.is_bot,
matrix_registered=self.matrix_registered))
# Fucking Telegram not telling bots what chats they are in 3:< # Fucking Telegram not telling bots what chats they are in 3:<
class BotChat(Base): class BotChat(Base):
@@ -359,9 +437,9 @@ class TelegramFile(Base):
def init(db_session, db_engine) -> None: def init(db_session, db_engine) -> None:
query = db_session.query_property() BotChat.query = db_session.query_property()
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState): for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile,
table.query = query RoomState):
table.db = db_engine table.db = db_engine
table.t = table.__table__ table.t = table.__table__
table.c = table.t.c table.c = table.t.c
+10 -9
View File
@@ -14,7 +14,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Coroutine, Dict, List, Optional, Pattern, Union, TYPE_CHECKING from typing import (Awaitable, Coroutine, Dict, List, Iterable, Optional, Pattern, Union,
TYPE_CHECKING)
from difflib import SequenceMatcher from difflib import SequenceMatcher
from enum import Enum from enum import Enum
from aiohttp import ServerDisconnectedError from aiohttp import ServerDisconnectedError
@@ -396,7 +397,7 @@ class Puppet:
except KeyError: except KeyError:
pass pass
puppet = DBPuppet.query.get(tgid) puppet = DBPuppet.get_by_tgid(tgid)
if puppet: if puppet:
return cls.from_db(puppet) return cls.from_db(puppet)
@@ -426,7 +427,7 @@ class Puppet:
except KeyError: except KeyError:
pass pass
puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none() puppet = DBPuppet.get_by_custom_mxid(mxid)
if puppet: if puppet:
puppet = cls.from_db(puppet) puppet = cls.from_db(puppet)
return puppet return puppet
@@ -434,11 +435,11 @@ class Puppet:
return None return None
@classmethod @classmethod
def get_all_with_custom_mxid(cls) -> List['Puppet']: def all_with_custom_mxid(cls) -> Iterable['Puppet']:
return [cls.by_custom_mxid[puppet.mxid] return (cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet) else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()] for puppet in DBPuppet.all_with_custom_mxid())
@classmethod @classmethod
def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]: def get_id_from_mxid(cls, mxid: MatrixUserID) -> Optional[TelegramID]:
@@ -460,7 +461,7 @@ class Puppet:
if puppet.username and puppet.username.lower() == username.lower(): if puppet.username and puppet.username.lower() == username.lower():
return puppet return puppet
dbpuppet = DBPuppet.query.filter(DBPuppet.username == username).one_or_none() dbpuppet = DBPuppet.get_by_username(username)
if dbpuppet: if dbpuppet:
return cls.from_db(dbpuppet) return cls.from_db(dbpuppet)
@@ -475,7 +476,7 @@ class Puppet:
if puppet.displayname and puppet.displayname == displayname: if puppet.displayname and puppet.displayname == displayname:
return puppet return puppet
dbpuppet = DBPuppet.query.filter(DBPuppet.displayname == displayname).one_or_none() dbpuppet = DBPuppet.get_by_displayname(displayname)
if dbpuppet: if dbpuppet:
return cls.from_db(dbpuppet) return cls.from_db(dbpuppet)
@@ -491,4 +492,4 @@ def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError]
Puppet.hs_domain = config["homeserver"]["domain"] Puppet.hs_domain = config["homeserver"]["domain"]
Puppet.mxid_regex = re.compile( Puppet.mxid_regex = re.compile(
f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}") f"@{Puppet.username_template.format(userid='([0-9]+)')}:{Puppet.hs_domain}")
return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()] return [puppet.init_custom_mxid() for puppet in Puppet.all_with_custom_mxid()]
+15 -16
View File
@@ -48,9 +48,9 @@ class User(AbstractUser):
def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None, def __init__(self, mxid: MatrixUserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, phone: Optional[str] = None, username: Optional[str] = None, phone: Optional[str] = None,
db_contacts: Optional[List[DBContact]] = None, db_contacts: Optional[Iterable[TelegramID]] = None,
saved_contacts: int = 0, is_bot: bool = False, saved_contacts: int = 0, is_bot: bool = False,
db_portals: Optional[List[DBPortal]] = None, db_portals: Optional[Iterable[Tuple[TelegramID, TelegramID]]] = None,
db_instance: Optional[DBUser] = None) -> None: db_instance: Optional[DBUser] = None) -> None:
super().__init__() super().__init__()
self.mxid = mxid # type: MatrixUserID self.mxid = mxid # type: MatrixUserID
@@ -60,9 +60,9 @@ class User(AbstractUser):
self.phone = phone # type: str self.phone = phone # type: str
self.contacts = [] # type: List[pu.Puppet] self.contacts = [] # type: List[pu.Puppet]
self.saved_contacts = saved_contacts # type: int self.saved_contacts = saved_contacts # type: int
self.db_contacts = db_contacts # type: List[DBContact] self.db_contacts = db_contacts
self.portals = {} # type: Dict[Tuple[int, int], po.Portal] self.portals = {} # type: Dict[Tuple[TelegramID, TelegramID], po.Portal]
self.db_portals = db_portals or [] # type: List[DBPortal] self.db_portals = db_portals or []
self._db_instance = db_instance # type: Optional[DBUser] self._db_instance = db_instance # type: Optional[DBUser]
self.command_status = None # type: Dict self.command_status = None # type: Dict
@@ -101,23 +101,22 @@ class User(AbstractUser):
return self.displayname return self.displayname
@property @property
def db_contacts(self) -> Iterable[DBContact]: def db_contacts(self) -> Iterable[TelegramID]:
return (DBContact(user=self.tgid, contact=puppet.id) for puppet in self.contacts) return (puppet.id for puppet in self.contacts)
@db_contacts.setter @db_contacts.setter
def db_contacts(self, contacts: Iterable[DBContact]) -> None: def db_contacts(self, contacts: Iterable[TelegramID]) -> None:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts] if contacts else [] self.contacts = [pu.Puppet.get(entry) for entry in contacts] if contacts else []
@property @property
def db_portals(self) -> Iterable[DBPortal]: def db_portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
return (portal.db_instance for portal in self.portals.values() if not portal.deleted) return (portal.tgid_full for portal in self.portals.values() if not portal.deleted)
@db_portals.setter @db_portals.setter
def db_portals(self, portals: Iterable[DBPortal]) -> None: def db_portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
self.portals = { self.portals = {
(portal.tgid, portal.tg_receiver): po.Portal.get_by_tgid(portal.tgid, tgid_full: po.Portal.get_by_tgid(*tgid_full)
portal.tg_receiver) for tgid_full in portals
for portal in portals
} if portals else {} } if portals else {}
# region Database conversion # region Database conversion
@@ -398,5 +397,5 @@ def init(context: 'Context') -> List[Awaitable['User']]:
global config global config
config = context.config config = context.config
users = [User.from_db(user) for user in DBUser.get_all()] users = [User.from_db(user) for user in DBUser.all()]
return [user.ensure_started() for user in users if user.tgid] return [user.ensure_started() for user in users if user.tgid]