Finish moving portals and users to SQLAlchemy Core

This commit is contained in:
Tulir Asokan
2019-02-12 14:42:03 +02:00
parent 53489e7356
commit cf847d3b8e
3 changed files with 118 additions and 40 deletions
+93 -15
View File
@@ -20,7 +20,7 @@ from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstrai
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query
from typing import Dict, Optional, List, Iterable
from typing import Dict, Optional, List, Iterable, Tuple
import json
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
@@ -48,14 +48,18 @@ class Portal(Base):
about = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
@classmethod
def scan(cls, row) -> Optional['Portal']:
(tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about,
photo_id) = row
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup,
mxid=mxid, config=config, username=username, title=title, about=about,
photo_id=photo_id)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
try:
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
username, title, about, photo_id) = next(rows)
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
megagroup=megagroup, mxid=mxid, config=config, username=username,
title=title, about=about, photo_id=photo_id)
return cls.scan(next(rows))
except StopIteration:
return None
@@ -155,10 +159,6 @@ class User(Base):
tg_username = Column(String, nullable=True)
tg_phone = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0, nullable=False)
contacts = relationship("Contact", uselist=True,
cascade="save-update, merge, delete, delete-orphan"
) # type: List[Contact]
portals = relationship("Portal", secondary="user_portal")
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
@@ -170,7 +170,7 @@ class User(Base):
return None
@classmethod
def get_all(cls) -> Iterable['User']:
def all(cls) -> Iterable['User']:
rows = cls.db.execute(cls.t.select())
for row in rows:
mxid, tgid, tg_username, tg_phone, saved_contacts = row
@@ -198,6 +198,36 @@ class User(Base):
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
saved_contacts=self.saved_contacts))
@property
def contacts(self) -> Iterable[TelegramID]:
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
for row in rows:
user, contact = row
yield contact
@contacts.setter
def contacts(self, puppets: Iterable[TelegramID]) -> None:
self.db.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
self.db.execute(Contact.t.insert(), [{"user": self.tgid, "contact": tgid}
for tgid in puppets])
@property
def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid))
for row in rows:
user, portal, portal_receiver = row
yield (portal, portal_receiver)
@portals.setter
def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
self.db.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
self.db.execute(UserPortal.t.insert(),
[{
"user": self.tgid,
"portal": tgid,
"portal_receiver": tg_receiver
} for tgid, tg_receiver in portals])
class UserPortal(Base):
__tablename__ = "user_portal"
@@ -302,7 +332,6 @@ class UserProfile(Base):
class Puppet(Base):
query = None # type: Query
__tablename__ = "puppet"
id = Column(Integer, primary_key=True) # type: TelegramID
@@ -315,6 +344,55 @@ class Puppet(Base):
is_bot = Column(Boolean, nullable=True)
matrix_registered = Column(Boolean, nullable=False, server_default=expression.false())
@classmethod
def scan(cls, row) -> Optional['Puppet']:
(id, custom_mxid, access_token, displayname, displayname_source, username, photo_id,
is_bot, matrix_registered) = row
return cls(id=id, custom_mxid=custom_mxid, access_token=access_token,
displayname=displayname, displayname_source=displayname_source,
username=username, photo_id=photo_id, is_bot=is_bot,
matrix_registered=matrix_registered)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Puppet']:
try:
return cls.scan(next(rows))
except StopIteration:
return None
@classmethod
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None))
for row in rows:
yield cls.scan(row)
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.id == tgid)
@classmethod
def get_by_custom_mxid(cls, mxid: MatrixRoomID) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.custom_mxid == mxid)
@classmethod
def get_by_username(cls, username: str) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.username == username)
@classmethod
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.displayname == displayname)
@property
def _edit_identity(self):
return self.c.id == self.id
def insert(self) -> None:
self.db.execute(self.t.insert().values(
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
displayname=self.displayname, displayname_source=self.displayname_source,
username=self.username, photo_id=self.photo_id, is_bot=self.is_bot,
matrix_registered=self.matrix_registered))
# Fucking Telegram not telling bots what chats they are in 3:<
class BotChat(Base):
@@ -359,9 +437,9 @@ class TelegramFile(Base):
def init(db_session, db_engine) -> None:
query = db_session.query_property()
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
table.query = query
BotChat.query = db_session.query_property()
for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile,
RoomState):
table.db = db_engine
table.t = table.__table__
table.c = table.t.c