Merge pull request #239 from tulir/sqlalchemy-core

Port Message table to SQLAlchemy Core
This commit is contained in:
Tulir Asokan
2018-10-21 00:32:14 +03:00
committed by GitHub
6 changed files with 102 additions and 53 deletions
+1 -1
View File
@@ -113,7 +113,7 @@ if config["appservice.provisioning.enabled"]:
context.provisioning_api = provisioning_api context.provisioning_api = provisioning_api
with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start: with appserv.run(config["appservice.hostname"], config["appservice.port"]) as start:
init_db(db_session) init_db(db_session, db_engine)
init_abstract_user(context) init_abstract_user(context)
context.bot = init_bot(context) context.bot = init_bot(context)
context.mx = MatrixHandler(context) context.mx = MatrixHandler(context)
+6 -7
View File
@@ -230,7 +230,7 @@ class AbstractUser(ABC):
return return
# We check that these are user read receipts, so tg_space is always the user ID. # We check that these are user read receipts, so tg_space is always the user ID.
message = DBMessage.query.get((update.max_id, self.tgid)) message = DBMessage.get_by_tgid(update.max_id, self.tgid)
if not message: if not message:
return return
@@ -323,12 +323,11 @@ class AbstractUser(ABC):
return return
for message in update.messages: for message in update.messages:
message = DBMessage.query.get((message, self.tgid)) message = DBMessage.get_by_tgid(TelegramID(message), self.tgid)
if not message: if not message:
continue continue
self.db.delete(message) message.delete()
number_left = DBMessage.query.filter(DBMessage.mxid == message.mxid, number_left = DBMessage.count_spaces_by_mxid(message.mxid, message.mx_room)
DBMessage.mx_room == message.mx_room).count()
if number_left == 0: if number_left == 0:
portal = po.Portal.get_by_mxid(message.mx_room) portal = po.Portal.get_by_mxid(message.mx_room)
await self._try_redact(portal, message) await self._try_redact(portal, message)
@@ -343,10 +342,10 @@ class AbstractUser(ABC):
return return
for message in update.messages: for message in update.messages:
message = DBMessage.query.get((message, portal.tgid)) message = DBMessage.get_by_tgid(TelegramID(message), portal.tgid)
if not message: if not message:
continue continue
self.db.delete(message) message.delete()
await self._try_redact(portal, message) await self._try_redact(portal, message)
self.db.commit() self.db.commit()
+72 -4
View File
@@ -15,9 +15,12 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer, from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text) BigInteger, String, Boolean, Text, Table,
and_, func, select)
from sqlalchemy.engine import Engine, RowProxy
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship, Query from sqlalchemy.orm import relationship, Query
from sqlalchemy.sql.base import ImmutableColumnCollection
from typing import Dict, Optional, List from typing import Dict, Optional, List
import json import json
@@ -49,7 +52,9 @@ class Portal(Base):
class Message(Base): class Message(Base):
query = None # type: Query db = None # type: Engine
t = None # type: Table
c = None # type: ImmutableColumnCollection
__tablename__ = "message" __tablename__ = "message"
mxid = Column(String) # type: MatrixEventID mxid = Column(String) # type: MatrixEventID
@@ -59,6 +64,67 @@ class Message(Base):
__table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),) __table_args__ = (UniqueConstraint("mxid", "mx_room", "tg_space", name="_mx_id_room"),)
@staticmethod
def _one_or_none(rows: RowProxy) -> Optional['Message']:
try:
mxid, mx_room, tgid, tg_space = next(rows)
return Message(mxid=mxid, mx_room=mx_room, tgid=tgid, tg_space=tg_space)
except StopIteration:
return None
@staticmethod
def _all(rows: RowProxy) -> List['Message']:
return [Message(mxid=row[0], mx_room=row[1], tgid=row[2], tg_space=row[3])
for row in rows]
@classmethod
def get_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID) -> Optional['Message']:
rows = cls.db.execute(cls.t.select()
.where(and_(cls.c.tgid == tgid, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod
def count_spaces_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID) -> int:
rows = cls.db.execute(select([func.count(cls.c.tg_space)])
.where(and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room)))
try:
count, = next(rows)
return count
except StopIteration:
return 0
@classmethod
def get_by_mxid(cls, mxid: MatrixEventID, mx_room: MatrixRoomID, tg_space: TelegramID
) -> Optional['Message']:
rows = cls.db.execute(cls.t.select().where(
and_(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space)))
return cls._one_or_none(rows)
@classmethod
def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.tgid == s_tgid, cls.c.tg_space == s_tg_space))
.values(**values))
@classmethod
def update_by_mxid(cls, s_mxid: MatrixEventID, s_mx_room: MatrixRoomID, **values) -> None:
cls.db.execute(cls.t.update()
.where(and_(cls.c.mxid == s_mxid, cls.c.mx_room == s_mx_room))
.values(**values))
def update(self, **values) -> None:
for key, value in values.items():
setattr(self, key, value)
self.update_by_tgid(self.tgid, self.tg_space, **values)
def delete(self) -> None:
self.db.execute(self.t.delete().where(
and_(self.c.tgid == self.tgid, self.c.tg_space == self.tg_space)))
def insert(self) -> None:
self.db.execute(self.t.insert().values(mxid=self.mxid, mx_room=self.mx_room, tgid=self.tgid,
tg_space=self.tg_space))
class UserPortal(Base): class UserPortal(Base):
query = None # type: Query query = None # type: Query
@@ -178,9 +244,11 @@ class TelegramFile(Base):
thumbnail = relationship("TelegramFile", uselist=False) thumbnail = relationship("TelegramFile", uselist=False)
def init(db_session) -> None: def init(db_session, db_engine) -> None:
Portal.query = db_session.query_property() Portal.query = db_session.query_property()
Message.query = db_session.query_property() Message.db = db_engine
Message.t = Message.__table__
Message.c = Message.t.c
UserPortal.query = db_session.query_property() UserPortal.query = db_session.query_property()
User.query = db_session.query_property() User.query = db_session.query_property()
Puppet.query = db_session.query_property() Puppet.query = db_session.query_property()
@@ -105,9 +105,7 @@ def matrix_reply_to_telegram(content: Dict[str, Any], tg_space: TelegramID,
pass pass
content["body"] = trim_reply_fallback_text(content["body"]) content["body"] = trim_reply_fallback_text(content["body"])
message = DBMessage.query.filter(DBMessage.mxid == event_id, message = DBMessage.get_by_mxid(event_id, room_id, tg_space)
DBMessage.tg_space == tg_space,
DBMessage.mx_room == room_id).one_or_none()
if message: if message:
return message.tgid return message.tgid
except KeyError: except KeyError:
+3 -3
View File
@@ -54,7 +54,7 @@ def telegram_reply_to_matrix(evt: Message, source: 'AbstractUser') -> Dict:
space = (evt.to_id.channel_id space = (evt.to_id.channel_id
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid) else source.tgid)
msg = DBMessage.query.get((evt.reply_to_msg_id, space)) msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
if msg: if msg:
return { return {
"m.in_reply_to": { "m.in_reply_to": {
@@ -124,7 +124,7 @@ async def _add_reply_header(source: "AbstractUser", text: str, html: str, evt: M
if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel) if isinstance(evt, Message) and isinstance(evt.to_id, PeerChannel)
else source.tgid) else source.tgid)
msg = DBMessage.query.get((evt.reply_to_msg_id, space)) msg = DBMessage.get_by_tgid(evt.reply_to_msg_id, space)
if not msg: if not msg:
return text, html return text, html
@@ -325,7 +325,7 @@ def _parse_url(html: List[str], entity_text: str, url: str) -> bool:
portal = po.Portal.find_by_username(group) portal = po.Portal.find_by_username(group)
if portal: if portal:
message = DBMessage.query.get((msgid, portal.tgid)) message = DBMessage.get_by_tgid(TelegramID(msgid), portal.tgid)
if message: if message:
url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}" url = f"https://matrix.to/#/{portal.mxid}/{message.mxid}"
+19 -35
View File
@@ -772,9 +772,7 @@ class Portal:
if user.is_bot: if user.is_bot:
return return
space = self.tgid if self.peer_type == "channel" else user.tgid space = self.tgid if self.peer_type == "channel" else user.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id, message = DBMessage.get_by_mxid(event_id, self.mxid, space)
DBMessage.mx_room == self.mxid,
DBMessage.tg_space == space).one_or_none()
if not message: if not message:
return return
if self.peer_type == "channel": if self.peer_type == "channel":
@@ -959,12 +957,11 @@ class Portal:
response: TypeMessage) -> None: response: TypeMessage) -> None:
self.log.debug("Handled Matrix message: %s", response) self.log.debug("Handled Matrix message: %s", response)
self.is_duplicate(response, (event_id, space)) self.is_duplicate(response, (event_id, space))
self.db.add(DBMessage( DBMessage(
tgid=response.id, tgid=response.id,
tg_space=space, tg_space=space,
mx_room=self.mxid, mx_room=self.mxid,
mxid=event_id)) mxid=event_id).insert()
self.db.commit()
async def handle_matrix_message(self, sender: 'u.User', message: Dict[str, Any], async def handle_matrix_message(self, sender: 'u.User', message: Dict[str, Any],
event_id: MatrixEventID) -> None: event_id: MatrixEventID) -> None:
@@ -1009,9 +1006,10 @@ class Portal:
if not pinned_message: if not pinned_message:
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=0)) await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=0))
else: else:
message = DBMessage.query.filter(DBMessage.mxid == pinned_message, message = DBMessage.get_by_mxid(pinned_message, self.mxid, self.tgid)
DBMessage.tg_space == self.tgid, if message is None:
DBMessage.mx_room == self.mxid).one_or_none() self.log.warning(f"Could not find pinned {pinned_message} in {self.mxid}")
return
await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=message.tgid)) await sender.client(UpdatePinnedMessageRequest(channel=self.peer, id=message.tgid))
except ChatNotModifiedError: except ChatNotModifiedError:
pass pass
@@ -1019,9 +1017,7 @@ class Portal:
async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None: async def handle_matrix_deletion(self, deleter: 'u.User', event_id: MatrixEventID) -> None:
real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot real_deleter = deleter if not await deleter.needs_relaybot(self) else self.bot
space = self.tgid if self.peer_type == "channel" else real_deleter.tgid space = self.tgid if self.peer_type == "channel" else real_deleter.tgid
message = DBMessage.query.filter(DBMessage.mxid == event_id, message = DBMessage.get_by_mxid(event_id, self.mxid, space)
DBMessage.tg_space == space,
DBMessage.mx_room == self.mxid).one_or_none()
if not message: if not message:
return return
await real_deleter.client.delete_messages(self.peer, [message.tgid]) await real_deleter.client.delete_messages(self.peer, [message.tgid])
@@ -1413,10 +1409,9 @@ class Portal:
if duplicate_found: if duplicate_found:
mxid, other_tg_space = duplicate_found mxid, other_tg_space = duplicate_found
if tg_space != other_tg_space: if tg_space != other_tg_space:
msg = DBMessage.query.get((evt.id, tg_space)) DBMessage.update_by_tgid(evt.id, tg_space,
msg.mxid = mxid mxid=mxid,
msg.mx_room = self.mxid mx_room=self.mxid)
self.db.commit()
return return
evt.reply_to_msg_id = evt.id evt.reply_to_msg_id = evt.id
@@ -1429,19 +1424,14 @@ class Portal:
mxid = response["event_id"] mxid = response["event_id"]
msg = DBMessage.query.get((evt.id, tg_space)) msg = DBMessage.get_by_tgid(evt.id, tg_space)
if not msg: if not msg:
self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) " self.log.info(f"Didn't find edited message {evt.id}@{tg_space} (src {source.tgid}) "
"in database.") "in database.")
# Oh crap # Oh crap
return return
msg.mxid = mxid msg.update(mxid=mxid, mx_room=self.mxid)
msg.mx_room = self.mxid DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid)
DBMessage.query \
.filter(DBMessage.mx_room == self.mxid,
DBMessage.mxid == temporary_identifier) \
.update({"mxid": mxid})
self.db.commit()
async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet, async def handle_telegram_message(self, source: "AbstractUser", sender: p.Puppet,
evt: Message) -> None: evt: Message) -> None:
@@ -1463,13 +1453,11 @@ class Portal:
self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) " self.log.debug(f"Ignoring message {evt.id}@{tg_space} (src {source.tgid}) "
f"as it was already handled (in space {other_tg_space})") f"as it was already handled (in space {other_tg_space})")
if tg_space != other_tg_space: if tg_space != other_tg_space:
self.db.add( DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert()
DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space))
self.db.commit()
return return
if self.dedup_pre_db_check and self.peer_type == "channel": if self.dedup_pre_db_check and self.peer_type == "channel":
msg = DBMessage.query.get((evt.id, tg_space)) msg = DBMessage.get_by_tgid(evt.id, tg_space)
if msg: if msg:
self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already" self.log.debug(f"Ignoring message {evt.id} (src {source.tgid}) as it was already"
f"handled into {msg.mxid}. This duplicate was catched in the db " f"handled into {msg.mxid}. This duplicate was catched in the db "
@@ -1523,12 +1511,8 @@ class Portal:
self.log.debug("Handled Telegram message: %s", evt) self.log.debug("Handled Telegram message: %s", evt)
try: try:
self.db.add(DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space)) DBMessage(tgid=evt.id, mx_room=self.mxid, mxid=mxid, tg_space=tg_space).insert()
self.db.commit() DBMessage.update_by_mxid(temporary_identifier, self.mxid, mxid=mxid)
DBMessage.query \
.filter(DBMessage.mx_room == self.mxid,
DBMessage.mxid == temporary_identifier) \
.update({"mxid": mxid})
except FlushError as e: except FlushError as e:
self.log.exception(f"{e.__class__.__name__} while saving message mapping. " self.log.exception(f"{e.__class__.__name__} while saving message mapping. "
"This might mean that an update was handled after it left the " "This might mean that an update was handled after it left the "
@@ -1610,7 +1594,7 @@ class Portal:
self._temp_pinned_message_id = None self._temp_pinned_message_id = None
self._temp_pinned_message_sender = None self._temp_pinned_message_sender = None
message = DBMessage.query.get((msg_id, self.tgid)) message = DBMessage.get_by_tgid(msg_id, self.tgid)
if message: if message:
await intent.set_pinned_messages(self.mxid, [message.mxid]) await intent.set_pinned_messages(self.mxid, [message.mxid])
else: else: