Even even more migrations to mautrix-python

This commit is contained in:
Tulir Asokan
2019-08-04 01:41:10 +03:00
parent d521bbc0fa
commit d8653961af
34 changed files with 475 additions and 946 deletions
+2 -1
View File
@@ -13,7 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from .base import Base
from mautrix.bridge.db import UserProfile, RoomState
from .bot_chat import BotChat
from .message import Message
from .portal import Portal
-58
View File
@@ -1,58 +0,0 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from abc import abstractmethod
from sqlalchemy import Table
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.ext.declarative import declarative_base
class BaseBase:
db: Engine = None
t: Table = None
__table__: Table = None
c: ImmutableColumnCollection = None
@classmethod
@abstractmethod
def _one_or_none(cls, rows: RowProxy):
pass
@classmethod
def _select_one_or_none(cls, *args):
return cls._one_or_none(cls.db.execute(cls.t.select().where(*args)))
@property
@abstractmethod
def _edit_identity(self):
pass
def update(self, **values) -> None:
with self.db.begin() as conn:
conn.execute(self.t.update()
.where(self._edit_identity)
.values(**values))
for key, value in values.items():
setattr(self, key, value)
def delete(self) -> None:
with self.db.begin() as conn:
conn.execute(self.t.delete().where(self._edit_identity))
Base = declarative_base(cls=BaseBase)
-26
View File
@@ -1,26 +0,0 @@
from abc import abstractmethod
from sqlalchemy import Table
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.ext.declarative import declarative_base
class Base(declarative_base):
db: Engine
t: Table
__table__: Table
c: ImmutableColumnCollection
@classmethod
@abstractmethod
def _one_or_none(cls, rows: RowProxy): ...
@classmethod
def _select_one_or_none(cls, *args): ...
def _edit_identity(self): ...
def update(self, **values) -> None: ...
def delete(self) -> None: ...
+11 -8
View File
@@ -16,28 +16,31 @@
from typing import Iterable
from sqlalchemy import Column, Integer, String
from sqlalchemy.engine.result import RowProxy
from mautrix.bridge.db import Base
from ..types import TelegramID
from .base import Base
# Fucking Telegram not telling bots what chats they are in 3:<
class BotChat(Base):
__tablename__ = "bot_chat"
id = Column(Integer, primary_key=True) # type: TelegramID
type = Column(String, nullable=False)
id: TelegramID = Column(Integer, primary_key=True)
type: str = Column(String, nullable=False)
@classmethod
def delete(cls, chat_id: TelegramID) -> None:
def delete_by_id(cls, chat_id: TelegramID) -> None:
with cls.db.begin() as conn:
conn.execute(cls.t.delete().where(cls.c.id == chat_id))
@classmethod
def scan(cls, row: RowProxy) -> 'BotChat':
return cls(id=row[0], type=row[1])
@classmethod
def all(cls) -> Iterable['BotChat']:
rows = cls.db.execute(cls.t.select())
for row in rows:
chat_id, chat_type = row
yield cls(id=chat_id, type=chat_type)
return cls._select_all()
def insert(self) -> None:
with self.db.begin() as conn:
+19 -26
View File
@@ -13,42 +13,35 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Iterator
from sqlalchemy import Column, UniqueConstraint, Integer, String, and_, func, desc, select
from sqlalchemy.engine.result import RowProxy
from typing import Optional, List
from sqlalchemy.sql.expression import ClauseElement
from ..types import MatrixRoomID, MatrixEventID, TelegramID
from .base import Base
from mautrix.types import RoomID, EventID
from mautrix.bridge.db import Base
from ..types import TelegramID
class Message(Base):
__tablename__ = "message"
mxid = Column(String) # type: MatrixEventID
mx_room = Column(String) # type: MatrixRoomID
tgid = Column(Integer, primary_key=True) # type: TelegramID
tg_space = Column(Integer, primary_key=True) # type: TelegramID
edit_index = Column(Integer, primary_key=True) # type: int
mxid: EventID = Column(String)
mx_room: RoomID = Column(String)
tgid: TelegramID = Column(Integer, primary_key=True)
tg_space: TelegramID = Column(Integer, primary_key=True)
edit_index: int = Column(Integer, primary_key=True)
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Message']:
try:
mxid, mx_room, tgid, tg_space, edit_index = next(rows)
return cls(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space,
edit_index=edit_index)
except StopIteration:
return None
@staticmethod
def _all(rows: RowProxy) -> List['Message']:
return [Message(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3],
edit_index=row[4])
for row in rows]
def scan(cls, row: RowProxy) -> 'Message':
return cls(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3], edit_index=row[4])
@classmethod
def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> List['Message']:
def get_all_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Iterator['Message']:
return cls._all(cls.db.execute(cls.t.select().where(and_(cls.c.tgid == tgid,
cls.c.tg_space == tg_space))))
@@ -68,7 +61,7 @@ class Message(Base):
return cls._one_or_none(cls.db.execute(query))
@classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int:
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
.where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room)))
try:
@@ -78,7 +71,7 @@ class Message(Base):
return 0
@classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
def get_by_mxid(cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID
) -> Optional['Message']:
return cls._select_one_or_none(and_(cls.c.mxid == mxid,
cls.c.mx_room == mx_room,
@@ -94,14 +87,14 @@ class Message(Base):
.values(**values))
@classmethod
def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None:
def update_by_mxid(cls, s_mxid: EventID, s_mx_room: RoomID, **values) -> None:
with cls.db.begin() as conn:
conn.execute(cls.t.update()
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
@property
def _edit_identity(self):
def _edit_identity(self) -> ClauseElement:
return and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space,
self.c.edit_index == self.edit_index)
+21 -24
View File
@@ -13,55 +13,52 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, Integer, String, Boolean, Text, and_
from sqlalchemy.engine.result import RowProxy
from typing import Optional
from ..types import MatrixRoomID, TelegramID
from .base import Base
from sqlalchemy import Column, Integer, String, Boolean, Text, and_
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.expression import ClauseElement
from mautrix.types import RoomID
from mautrix.bridge.db import Base
from ..types import TelegramID
class Portal(Base):
__tablename__ = "portal"
# Telegram chat information
tgid = Column(Integer, primary_key=True) # type: TelegramID
tg_receiver = Column(Integer, primary_key=True) # type: TelegramID
peer_type = Column(String, nullable=False)
megagroup = Column(Boolean)
tgid: TelegramID = Column(Integer, primary_key=True)
tg_receiver: TelegramID = Column(Integer, primary_key=True)
peer_type: str = Column(String, nullable=False)
megagroup: bool = Column(Boolean)
# Matrix portal information
mxid = Column(String, unique=True, nullable=True) # type: Optional[MatrixRoomID]
mxid: RoomID = Column(String, unique=True, nullable=True)
config = Column(Text, nullable=True)
config: str = Column(Text, nullable=True)
# Telegram chat metadata
username = Column(String, nullable=True)
title = Column(String, nullable=True)
about = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
username: str = Column(String, nullable=True)
title: str = Column(String, nullable=True)
about: str = Column(String, nullable=True)
photo_id: str = Column(String, nullable=True)
@classmethod
def scan(cls, row) -> Optional['Portal']:
def scan(cls, row: RowProxy) -> Optional['Portal']:
(tgid, tg_receiver, peer_type, megagroup, mxid, config, username, title, about,
photo_id) = row
return cls(tgid=tgid, tg_receiver=tg_receiver, peer_type=peer_type, megagroup=megagroup,
mxid=mxid, config=config, username=username, title=title, about=about,
photo_id=photo_id)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Portal']:
try:
return cls.scan(next(rows))
except StopIteration:
return None
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_receiver: TelegramID) -> Optional['Portal']:
return cls._select_one_or_none(and_(cls.c.tgid == tgid, cls.c.tg_receiver == tg_receiver))
@classmethod
def get_by_mxid(cls, mxid: MatrixRoomID) -> Optional['Portal']:
def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
@@ -69,7 +66,7 @@ class Portal(Base):
return cls._select_one_or_none(cls.c.username == username)
@property
def _edit_identity(self):
def _edit_identity(self) -> ClauseElement:
return and_(self.c.tgid == self.tgid, self.c.tg_receiver == self.tg_receiver)
def insert(self) -> None:
+22 -25
View File
@@ -13,31 +13,35 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import expression
from typing import Optional, Iterable
from ..types import MatrixUserID, MatrixRoomID, TelegramID
from .base import Base
from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy.sql import expression
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.expression import ClauseElement
from mautrix.types import UserID
from mautrix.bridge.db import Base
from ..types import TelegramID
class Puppet(Base):
__tablename__ = "puppet"
id = Column(Integer, primary_key=True) # type: TelegramID
custom_mxid = Column(String, nullable=True) # type: Optional[MatrixUserID]
access_token = Column(String, nullable=True)
displayname = Column(String, nullable=True)
displayname_source = Column(Integer, nullable=True) # type: Optional[TelegramID]
username = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
is_bot = Column(Boolean, nullable=True)
matrix_registered = Column(Boolean, nullable=False, server_default=expression.false())
disable_updates = Column(Boolean, nullable=False, server_default=expression.false())
id: TelegramID = Column(Integer, primary_key=True)
custom_mxid: UserID = Column(String, nullable=True)
access_token: str = Column(String, nullable=True)
displayname: str = Column(String, nullable=True)
displayname_source: TelegramID = Column(Integer, nullable=True)
username: str = Column(String, nullable=True)
photo_id: str = Column(String, nullable=True)
is_bot: bool = Column(Boolean, nullable=True)
matrix_registered: bool = Column(Boolean, nullable=False, server_default=expression.false())
disable_updates: bool = Column(Boolean, nullable=False, server_default=expression.false())
@classmethod
def scan(cls, row) -> Optional['Puppet']:
def scan(cls, row: RowProxy) -> Optional['Puppet']:
(id, custom_mxid, access_token, displayname, displayname_source, username, photo_id,
is_bot, matrix_registered, disable_updates) = row
return cls(id=id, custom_mxid=custom_mxid, access_token=access_token,
@@ -45,13 +49,6 @@ class Puppet(Base):
username=username, photo_id=photo_id, is_bot=is_bot,
matrix_registered=matrix_registered, disable_updates=disable_updates)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['Puppet']:
try:
return cls.scan(next(rows))
except StopIteration:
return None
@classmethod
def all_with_custom_mxid(cls) -> Iterable['Puppet']:
rows = cls.db.execute(cls.t.select().where(cls.c.custom_mxid != None))
@@ -63,7 +60,7 @@ class Puppet(Base):
return cls._select_one_or_none(cls.c.id == tgid)
@classmethod
def get_by_custom_mxid(cls, mxid: MatrixUserID) -> Optional['Puppet']:
def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.custom_mxid == mxid)
@classmethod
@@ -75,7 +72,7 @@ class Puppet(Base):
return cls._select_one_or_none(cls.c.displayname == displayname)
@property
def _edit_identity(self):
def _edit_identity(self) -> ClauseElement:
return self.c.id == self.id
def insert(self) -> None:
+22 -22
View File
@@ -13,40 +13,40 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, ForeignKey, Integer, BigInteger, String, Boolean
from typing import Optional
from mautrix.types import ContentURI
from sqlalchemy import Column, ForeignKey, Integer, BigInteger, String, Boolean
from sqlalchemy.engine.result import RowProxy
from .base import Base
from mautrix.types import ContentURI
from mautrix.bridge.db import Base
class TelegramFile(Base):
__tablename__ = "telegram_file"
id = Column(String, primary_key=True)
id: str = Column(String, primary_key=True)
mxc: ContentURI = Column(String)
mime_type = Column(String)
was_converted = Column(Boolean)
timestamp = Column(BigInteger)
size = Column(Integer, nullable=True)
width = Column(Integer, nullable=True)
height = Column(Integer, nullable=True)
thumbnail_id = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True)
thumbnail = None # type: Optional[TelegramFile]
mime_type: str = Column(String)
was_converted: bool = Column(Boolean)
timestamp: int = Column(BigInteger)
size: int = Column(Integer, nullable=True)
width: int = Column(Integer, nullable=True)
height: int = Column(Integer, nullable=True)
thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True)
thumbnail: Optional['TelegramFile'] = None
def scan(cls, row: RowProxy) -> 'TelegramFile':
loc_id, mxc, mime, conv, ts, s, w, h, thumb_id = row
thumb = None
if thumb_id:
thumb = cls.get(thumb_id)
return cls(id=loc_id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
@classmethod
def get(cls, loc_id: str) -> Optional['TelegramFile']:
rows = cls.db.execute(cls.t.select().where(cls.c.id == loc_id))
try:
loc_id, mxc, mime, conv, ts, s, w, h, thumb_id = next(rows)
thumb = None
if thumb_id:
thumb = cls.get(thumb_id)
return cls(id=loc_id, mxc=mxc, mime_type=mime, was_converted=conv, timestamp=ts,
size=s, width=w, height=h, thumbnail_id=thumb_id, thumbnail=thumb)
except StopIteration:
return None
return cls._select_one_or_none(cls.c.id == loc_id)
def insert(self) -> None:
with self.db.begin() as conn:
+26 -29
View File
@@ -13,46 +13,43 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, Integer, String
from sqlalchemy.engine.result import RowProxy
from typing import Optional, Iterable, Tuple
from ..types import MatrixUserID, TelegramID
from .base import Base
from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, Integer, String
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql.expression import ClauseElement
from mautrix.types import UserID
from mautrix.bridge.db import Base
from ..types import TelegramID
class User(Base):
__tablename__ = "user"
mxid = Column(String, primary_key=True) # type: MatrixUserID
tgid = Column(Integer, nullable=True, unique=True) # type: Optional[TelegramID]
tg_username = Column(String, nullable=True)
tg_phone = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0, nullable=False)
mxid: UserID = Column(String, primary_key=True)
tgid: Optional[TelegramID] = Column(Integer, nullable=True, unique=True)
tg_username: str = Column(String, nullable=True)
tg_phone: str = Column(String, nullable=True)
saved_contacts: int = Column(Integer, default=0, nullable=False)
@classmethod
def _one_or_none(cls, rows: RowProxy) -> Optional['User']:
try:
mxid, tgid, tg_username, tg_phone, saved_contacts = next(rows)
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
except StopIteration:
return None
def scan(cls, row: RowProxy) -> 'User':
mxid, tgid, tg_username, tg_phone, saved_contacts = row
return cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
@classmethod
def all(cls) -> Iterable['User']:
rows = cls.db.execute(cls.t.select())
for row in rows:
mxid, tgid, tg_username, tg_phone, saved_contacts = row
yield cls(mxid=mxid, tgid=tgid, tg_username=tg_username, tg_phone=tg_phone,
saved_contacts=saved_contacts)
return cls._select_all()
@classmethod
def get_by_tgid(cls, tgid: TelegramID) -> Optional['User']:
return cls._select_one_or_none(cls.c.tgid == tgid)
@classmethod
def get_by_mxid(cls, mxid: MatrixUserID) -> Optional['User']:
def get_by_mxid(cls, mxid: UserID) -> Optional['User']:
return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
@@ -60,7 +57,7 @@ class User(Base):
return cls._select_one_or_none(cls.c.tg_username == username)
@property
def _edit_identity(self):
def _edit_identity(self) -> ClauseElement:
return self.c.mxid == self.mxid
def insert(self) -> None:
@@ -112,10 +109,10 @@ class User(Base):
class UserPortal(Base):
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) # type: TelegramID
portal = Column(Integer, primary_key=True) # type: TelegramID
portal_receiver = Column(Integer, primary_key=True) # type: TelegramID
user: TelegramID = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE",
ondelete="CASCADE"), primary_key=True)
portal: TelegramID = Column(Integer, primary_key=True)
portal_receiver: TelegramID = Column(Integer, primary_key=True)
__table_args__ = (ForeignKeyConstraint(("portal", "portal_receiver"),
("portal.tgid", "portal.tg_receiver"),
@@ -125,5 +122,5 @@ class UserPortal(Base):
class Contact(Base):
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True) # type: TelegramID
contact = Column(Integer, ForeignKey("puppet.id"), primary_key=True) # type: TelegramID
user: TelegramID = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
contact: TelegramID = Column(Integer, ForeignKey("puppet.id"), primary_key=True)