Infinite backfill with MSC2716 (#817)
Disabled by default, with non-infinite fallback mode as the default behavior
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from .backfill_queue import Backfill
|
||||
from .bot_chat import BotChat
|
||||
from .disappearing_message import DisappearingMessage
|
||||
from .message import Message
|
||||
@@ -38,6 +39,7 @@ def init(db: Database) -> None:
|
||||
BotChat,
|
||||
PgSession,
|
||||
DisappearingMessage,
|
||||
Backfill,
|
||||
):
|
||||
table.db = db
|
||||
|
||||
@@ -54,4 +56,5 @@ __all__ = [
|
||||
"BotChat",
|
||||
"PgSession",
|
||||
"DisappearingMessage",
|
||||
"Backfill",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
# mautrix-telegram - A Matrix-Telegram puppeting bridge
|
||||
# Copyright (C) 2022 Tulir Asokan, Sumner Evans
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import RoomID, UserID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Backfill:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
queue_id: int | None
|
||||
user_mxid: UserID
|
||||
priority: int
|
||||
portal_tgid: int
|
||||
portal_tg_receiver: int
|
||||
messages_per_batch: int
|
||||
post_batch_delay: int
|
||||
max_batches: int
|
||||
dispatch_time: datetime | None
|
||||
completed_at: datetime | None
|
||||
cooldown_timeout: datetime | None
|
||||
|
||||
@staticmethod
|
||||
def new(
|
||||
user_mxid: UserID,
|
||||
priority: int,
|
||||
portal_tgid: int,
|
||||
portal_tg_receiver: int,
|
||||
messages_per_batch: int,
|
||||
post_batch_delay: int = 0,
|
||||
max_batches: int = -1,
|
||||
) -> "Backfill":
|
||||
return Backfill(
|
||||
queue_id=None,
|
||||
user_mxid=user_mxid,
|
||||
priority=priority,
|
||||
portal_tgid=portal_tgid,
|
||||
portal_tg_receiver=portal_tg_receiver,
|
||||
messages_per_batch=messages_per_batch,
|
||||
post_batch_delay=post_batch_delay,
|
||||
max_batches=max_batches,
|
||||
dispatch_time=None,
|
||||
completed_at=None,
|
||||
cooldown_timeout=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record | None) -> Backfill | None:
|
||||
if row is None:
|
||||
return None
|
||||
return cls(**row)
|
||||
|
||||
columns = [
|
||||
"user_mxid",
|
||||
"priority",
|
||||
"portal_tgid",
|
||||
"portal_tg_receiver",
|
||||
"messages_per_batch",
|
||||
"post_batch_delay",
|
||||
"max_batches",
|
||||
"dispatch_time",
|
||||
"completed_at",
|
||||
"cooldown_timeout",
|
||||
]
|
||||
columns_str = ",".join(columns)
|
||||
|
||||
@classmethod
|
||||
async def get_next(cls, user_mxid: UserID) -> Backfill | None:
|
||||
q = f"""
|
||||
SELECT queue_id, {cls.columns_str}
|
||||
FROM backfill_queue
|
||||
WHERE user_mxid=$1
|
||||
AND (
|
||||
dispatch_time IS NULL
|
||||
OR (
|
||||
dispatch_time < $2
|
||||
AND completed_at IS NULL
|
||||
)
|
||||
)
|
||||
AND (
|
||||
cooldown_timeout IS NULL
|
||||
OR cooldown_timeout < current_timestamp
|
||||
)
|
||||
ORDER BY priority, queue_id
|
||||
LIMIT 1
|
||||
"""
|
||||
return cls._from_row(
|
||||
await cls.db.fetchrow(q, user_mxid, datetime.now() - timedelta(minutes=15))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls,
|
||||
user_mxid: UserID,
|
||||
portal_tgid: int,
|
||||
portal_tg_receiver: int,
|
||||
) -> Backfill | None:
|
||||
q = f"""
|
||||
SELECT queue_id, {cls.columns_str}
|
||||
FROM backfill_queue
|
||||
WHERE user_mxid=$1
|
||||
AND portal_tgid=$2
|
||||
AND portal_tg_receiver=$3
|
||||
ORDER BY priority, queue_id
|
||||
LIMIT 1
|
||||
"""
|
||||
return cls._from_row(await cls.db.fetchrow(q, user_mxid, portal_tgid, portal_tg_receiver))
|
||||
|
||||
@classmethod
|
||||
async def delete_all(cls, user_mxid: UserID) -> None:
|
||||
await cls.db.execute("DELETE FROM backfill_queue WHERE user_mxid=$1", user_mxid)
|
||||
|
||||
@classmethod
|
||||
async def delete_for_portal(cls, tgid: int, tg_receiver: int) -> None:
|
||||
q = "DELETE FROM backfill_queue WHERE portal_tgid=$1 AND portal_tg_receiver=$2"
|
||||
await cls.db.execute(q, tgid, tg_receiver)
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = f"""
|
||||
INSERT INTO backfill_queue ({self.columns_str})
|
||||
VALUES ({','.join(f'${i+1}' for i in range(len(self.columns)))})
|
||||
RETURNING queue_id
|
||||
"""
|
||||
row = await self.db.fetchrow(
|
||||
q,
|
||||
self.user_mxid,
|
||||
self.priority,
|
||||
self.portal_tgid,
|
||||
self.portal_tg_receiver,
|
||||
self.messages_per_batch,
|
||||
self.post_batch_delay,
|
||||
self.max_batches,
|
||||
self.dispatch_time,
|
||||
self.completed_at,
|
||||
self.cooldown_timeout,
|
||||
)
|
||||
self.queue_id = row["queue_id"]
|
||||
|
||||
async def mark_dispatched(self) -> None:
|
||||
q = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"
|
||||
await self.db.execute(q, datetime.now(), self.queue_id)
|
||||
|
||||
async def mark_done(self) -> None:
|
||||
q = "UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2"
|
||||
await self.db.execute(q, datetime.now(), self.queue_id)
|
||||
|
||||
async def set_cooldown_timeout(self, timeout) -> None:
|
||||
"""
|
||||
Set the backfill request to cooldown for ``timeout`` seconds.
|
||||
"""
|
||||
q = "UPDATE backfill_queue SET cooldown_timeout=$1 WHERE queue_id=$2"
|
||||
await self.db.execute(q, datetime.now() + timedelta(seconds=timeout), self.queue_id)
|
||||
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
import attr
|
||||
|
||||
from mautrix.types import EventID, RoomID, UserID
|
||||
from mautrix.util.async_db import Database, Scheme
|
||||
@@ -122,6 +123,14 @@ class Message:
|
||||
)
|
||||
return cls._from_row(await cls.db.fetchrow(q, mx_room, tg_space))
|
||||
|
||||
@classmethod
|
||||
async def find_first(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | None:
|
||||
q = (
|
||||
f"SELECT {cls.columns} FROM message WHERE mx_room=$1 AND tg_space=$2 "
|
||||
f"ORDER BY tgid ASC LIMIT 1"
|
||||
)
|
||||
return cls._from_row(await cls.db.fetchrow(q, mx_room, tg_space))
|
||||
|
||||
@classmethod
|
||||
async def delete_all(cls, mx_room: RoomID) -> None:
|
||||
await cls.db.execute("DELETE FROM message WHERE mx_room=$1", mx_room)
|
||||
@@ -173,6 +182,23 @@ class Message:
|
||||
q = "DELETE FROM message WHERE mxid=$1 AND mx_room=$2"
|
||||
await cls.db.execute(q, temp_mxid, mx_room)
|
||||
|
||||
@classmethod
|
||||
async def bulk_insert(cls, messages: list[Message]) -> None:
|
||||
columns = cls.columns.split(", ")
|
||||
records = [attr.astuple(message) for message in messages]
|
||||
async with cls.db.acquire() as conn, conn.transaction():
|
||||
if cls.db.scheme == Scheme.POSTGRES:
|
||||
await conn.copy_records_to_table("message", records=records, columns=columns)
|
||||
else:
|
||||
await conn.executemany(cls._insert_query, records)
|
||||
|
||||
_insert_query: ClassVar[
|
||||
str
|
||||
] = """
|
||||
INSERT INTO message (mxid, mx_room, tgid, tg_space, edit_index, redacted, content_hash, sender_mxid, sender)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
"""
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return (
|
||||
@@ -188,13 +214,7 @@ class Message:
|
||||
)
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = """
|
||||
INSERT INTO message (
|
||||
mxid, mx_room, tgid, tg_space, edit_index, redacted, content_hash,
|
||||
sender_mxid, sender
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
"""
|
||||
await self.db.execute(q, *self._values)
|
||||
await self.db.execute(self._insert_query, *self._values)
|
||||
|
||||
async def delete(self) -> None:
|
||||
q = "DELETE FROM message WHERE mxid=$1 AND mx_room=$2 AND tg_space=$3"
|
||||
|
||||
Reference in New Issue
Block a user