Finish moving portals and users to SQLAlchemy Core
This commit is contained in:
+93
-15
@@ -20,7 +20,7 @@ from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstrai
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.orm import relationship, Query
|
||||
from typing import Dict, Optional, List, Iterable
|
||||
from typing import Dict, Optional, List, Iterable, Tuple
|
||||
import json
|
||||
|
||||
from mautrix_telegram.types import MatrixUserID, MatrixRoomID, MatrixEventID
|
||||
@@ -48,14 +48,18 @@ class Portal(Base):
|
||||
about = Column(String, nullable=True)
|
||||
photo_id = Column(String, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row) -> Optional['Portal']:
|
||||
(tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about,
|
||||
photo_id) = row
|
||||
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup,
|
||||
mxid=mxid, config=config, username=username, title=title, about=about,
|
||||
photo_id=photo_id)
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
|
||||
try:
|
||||
(tgid, tg_receiver, peer_type, megagroup, mxid, config,
|
||||
username, title, about, photo_id) = next(rows)
|
||||
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type,
|
||||
megagroup=megagroup, mxid=mxid, config=config, username=username,
|
||||
title=title, about=about, photo_id=photo_id)
|
||||
return cls.scan(next(rows))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@@ -155,10 +159,6 @@ class User(Base):
|
||||
tg_username = Column(String, nullable=True)
|
||||
tg_phone = Column(String, nullable=True)
|
||||
saved_contacts = Column(Integer, default=0, nullable=False)
|
||||
contacts = relationship("Contact", uselist=True,
|
||||
cascade="save-update, merge, delete, delete-orphan"
|
||||
) # type: List[Contact]
|
||||
portals = relationship("Portal", secondary="user_portal")
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
|
||||
@@ -170,7 +170,7 @@ class User(Base):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all(cls) -> Iterable['User']:
|
||||
def all(cls) -> Iterable['User']:
|
||||
rows = cls.db.execute(cls.t.select())
|
||||
for row in rows:
|
||||
mxid, tgid, tg_username, tg_phone, saved_contacts = row
|
||||
@@ -198,6 +198,36 @@ class User(Base):
|
||||
mxid=self.mxid, tgid=self.tgid, tg_username=self.tg_username, tg_phone=self.tg_phone,
|
||||
saved_contacts=self.saved_contacts))
|
||||
|
||||
@property
|
||||
def contacts(self) -> Iterable[TelegramID]:
|
||||
rows = self.db.execute(Contact.t.select().where(Contact.c.user == self.tgid))
|
||||
for row in rows:
|
||||
user, contact = row
|
||||
yield contact
|
||||
|
||||
@contacts.setter
|
||||
def contacts(self, puppets: Iterable[TelegramID]) -> None:
|
||||
self.db.execute(Contact.t.delete().where(Contact.c.user == self.tgid))
|
||||
self.db.execute(Contact.t.insert(), [{"user": self.tgid, "contact": tgid}
|
||||
for tgid in puppets])
|
||||
|
||||
@property
|
||||
def portals(self) -> Iterable[Tuple[TelegramID, TelegramID]]:
|
||||
rows = self.db.execute(UserPortal.t.select().where(UserPortal.c.user == self.tgid))
|
||||
for row in rows:
|
||||
user, portal, portal_receiver = row
|
||||
yield (portal, portal_receiver)
|
||||
|
||||
@portals.setter
|
||||
def portals(self, portals: Iterable[Tuple[TelegramID, TelegramID]]) -> None:
|
||||
self.db.execute(UserPortal.t.delete().where(UserPortal.c.user == self.tgid))
|
||||
self.db.execute(UserPortal.t.insert(),
|
||||
[{
|
||||
"user": self.tgid,
|
||||
"portal": tgid,
|
||||
"portal_receiver": tg_receiver
|
||||
} for tgid, tg_receiver in portals])
|
||||
|
||||
|
||||
class UserPortal(Base):
|
||||
__tablename__ = "user_portal"
|
||||
@@ -302,7 +332,6 @@ class UserProfile(Base):
|
||||
|
||||
|
||||
class Puppet(Base):
|
||||
query = None # type: Query
|
||||
__tablename__ = "puppet"
|
||||
|
||||
id = Column(Integer, primary_key=True) # type: TelegramID
|
||||
@@ -315,6 +344,55 @@ class Puppet(Base):
|
||||
is_bot = Column(Boolean, nullable=True)
|
||||
matrix_registered = Column(Boolean, nullable=False, server_default=expression.false())
|
||||
|
||||
@classmethod
|
||||
def scan(cls, row) -> Optional['Puppet']:
|
||||
(id, custom_mxid, access_token, displayname, displayname_source, username, photo_id,
|
||||
is_bot, matrix_registered) = row
|
||||
return cls(id=id, custom_mxid=custom_mxid, access_token=access_token,
|
||||
displayname=displayname, displayname_source=displayname_source,
|
||||
username=username, photo_id=photo_id, is_bot=is_bot,
|
||||
matrix_registered=matrix_registered)
|
||||
|
||||
@classmethod
|
||||
def _one_or_none(cls, rows: RowProxy) -> Optional['Puppet']:
|
||||
try:
|
||||
return cls.scan(next(rows))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
|
||||
rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None))
|
||||
for row in rows:
|
||||
yield cls.scan(row)
|
||||
|
||||
@classmethod
|
||||
def get_by_tgid(cls, tgid: TelegramID) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.id == tgid)
|
||||
|
||||
@classmethod
|
||||
def get_by_custom_mxid(cls, mxid: MatrixRoomID) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.custom_mxid == mxid)
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.username == username)
|
||||
|
||||
@classmethod
|
||||
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
|
||||
return cls._select_one_or_none(cls.c.displayname == displayname)
|
||||
|
||||
@property
|
||||
def _edit_identity(self):
|
||||
return self.c.id == self.id
|
||||
|
||||
def insert(self) -> None:
|
||||
self.db.execute(self.t.insert().values(
|
||||
id=self.id, custom_mxid=self.custom_mxid, access_token=self.access_token,
|
||||
displayname=self.displayname, displayname_source=self.displayname_source,
|
||||
username=self.username, photo_id=self.photo_id, is_bot=self.is_bot,
|
||||
matrix_registered=self.matrix_registered))
|
||||
|
||||
|
||||
# Fucking Telegram not telling bots what chats they are in 3:<
|
||||
class BotChat(Base):
|
||||
@@ -359,9 +437,9 @@ class TelegramFile(Base):
|
||||
|
||||
|
||||
def init(db_session, db_engine) -> None:
|
||||
query = db_session.query_property()
|
||||
for table in (Portal, Message, User, Puppet, BotChat, TelegramFile, UserProfile, RoomState):
|
||||
table.query = query
|
||||
BotChat.query = db_session.query_property()
|
||||
for table in (Portal, Message, User, Contact, UserPortal, Puppet, TelegramFile, UserProfile,
|
||||
RoomState):
|
||||
table.db = db_engine
|
||||
table.t = table.__table__
|
||||
table.c = table.t.c
|
||||
|
||||
Reference in New Issue
Block a user