Add option for parallel streamed file transfer

This commit is contained in:
Tulir Asokan
2019-10-27 01:12:15 +03:00
parent 6cb8e007aa
commit 574312d7c5
7 changed files with 224 additions and 31 deletions
+4
View File
@@ -163,6 +163,10 @@ bridge:
image_as_file_size: 10 image_as_file_size: 10
# Maximum size of Telegram documents in megabytes to bridge. # Maximum size of Telegram documents in megabytes to bridge.
max_document_size: 100 max_document_size: 100
# Enable experimental parallel file transfer, which makes uploads/downloads much faster by
# streaming from/to Matrix and using many connections for Telegram.
# Note that generating HQ thumbnails for videos is not possible with streamed transfers.
parallel_file_transfer: false
# Whether or not created rooms should have federation enabled. # Whether or not created rooms should have federation enabled.
# If false, created portal rooms will never be federated. # If false, created portal rooms will never be federated.
federate_rooms: true federate_rooms: true
+1
View File
@@ -101,6 +101,7 @@ class Config(BaseBridgeConfig):
copy("bridge.inline_images") copy("bridge.inline_images")
copy("bridge.image_as_file_size") copy("bridge.image_as_file_size")
copy("bridge.max_document_size") copy("bridge.max_document_size")
copy("bridge.parallel_file_transfer")
copy("bridge.federate_rooms") copy("bridge.federate_rooms")
copy("bridge.bot_messages_as_notices") copy("bridge.bot_messages_as_notices")
+3 -3
View File
@@ -30,9 +30,9 @@ class TelegramFile(Base):
mime_type: str = Column(String) mime_type: str = Column(String)
was_converted: bool = Column(Boolean) was_converted: bool = Column(Boolean)
timestamp: int = Column(BigInteger) timestamp: int = Column(BigInteger)
size: int = Column(Integer, nullable=True) size: Optional[int] = Column(Integer, nullable=True)
width: int = Column(Integer, nullable=True) width: Optional[int] = Column(Integer, nullable=True)
height: int = Column(Integer, nullable=True) height: Optional[int] = Column(Integer, nullable=True)
thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True) thumbnail_id: str = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True)
thumbnail: Optional['TelegramFile'] = None thumbnail: Optional['TelegramFile'] = None
+3 -1
View File
@@ -181,8 +181,10 @@ class PortalTelegram(BasePortal, ABC):
self.log.debug(f"Unsupported thumbnail type {type(thumb_size)}") self.log.debug(f"Unsupported thumbnail type {type(thumb_size)}")
thumb_loc = None thumb_loc = None
thumb_size = None thumb_size = None
parallel_id = source.tgid if config["bridge.parallel_file_transfer"] else None
file = await util.transfer_file_to_matrix(source.client, intent, document, thumb_loc, file = await util.transfer_file_to_matrix(source.client, intent, document, thumb_loc,
is_sticker=attrs.is_sticker) is_sticker=attrs.is_sticker, filename=attrs.name,
parallel_id=parallel_id)
if not file: if not file:
return None return None
+37 -26
View File
@@ -33,6 +33,7 @@ from mautrix.appservice import IntentAPI
from ..tgclient import MautrixTelegramClient from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile from ..db import TelegramFile as DBTelegramFile
from ..util import sane_mimetypes from ..util import sane_mimetypes
from .parallel_file_transfer import parallel_transfer_to_matrix
try: try:
from PIL import Image from PIL import Image
@@ -126,7 +127,7 @@ async def transfer_thumbnail_to_matrix(client: MautrixTelegramClient, intent: In
return db_file return db_file
video_ext = sane_mimetypes.guess_extension(mime) video_ext = sane_mimetypes.guess_extension(mime)
if VideoFileClip and video_ext: if VideoFileClip and video_ext and video:
try: try:
file, width, height = _read_video_thumbnail(video, video_ext, frame_ext="png") file, width, height = _read_video_thumbnail(video, video_ext, frame_ext="png")
except OSError: except OSError:
@@ -158,7 +159,8 @@ TypeThumbnail = Optional[Union[TypeLocation, TypePhotoSize]]
async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
location: TypeLocation, thumbnail: TypeThumbnail = None, location: TypeLocation, thumbnail: TypeThumbnail = None,
is_sticker: bool = False) -> Optional[DBTelegramFile]: is_sticker: bool = False, filename: Optional[str] = None,
parallel_id: Optional[int] = None) -> Optional[DBTelegramFile]:
location_id = _location_to_id(location) location_id = _location_to_id(location)
if not location_id: if not location_id:
return None return None
@@ -174,43 +176,52 @@ async def transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentA
transfer_locks[location_id] = lock transfer_locks[location_id] = lock
async with lock: async with lock:
return await _unlocked_transfer_file_to_matrix(client, intent, location_id, location, return await _unlocked_transfer_file_to_matrix(client, intent, location_id, location,
thumbnail, is_sticker) thumbnail, is_sticker, filename,
parallel_id)
async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI, async def _unlocked_transfer_file_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
loc_id: str, location: TypeLocation, loc_id: str, location: TypeLocation,
thumbnail: TypeThumbnail, is_sticker: bool thumbnail: TypeThumbnail, is_sticker: bool,
filename: Optional[str],
parallel_id: Optional[int] = None
) -> Optional[DBTelegramFile]: ) -> Optional[DBTelegramFile]:
db_file = DBTelegramFile.get(loc_id) db_file = DBTelegramFile.get(loc_id)
if db_file: if db_file:
return db_file return db_file
try: if parallel_id and isinstance(location, Document):
file = await client.download_file(location) db_file = await parallel_transfer_to_matrix(client, intent, loc_id, location, filename,
except (LocationInvalidError, FileIdInvalidError): parallel_id)
return None mime_type = location.mime_type
except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e: file = None
log.exception(f"{e.__class__.__name__} while downloading a file.") else:
return None try:
file = await client.download_file(location)
except (LocationInvalidError, FileIdInvalidError):
return None
except (AuthBytesInvalidError, AuthKeyInvalidError, SecurityError) as e:
log.exception(f"{e.__class__.__name__} while downloading a file.")
return None
width, height = None, None width, height = None, None
mime_type = magic.from_buffer(file, mime=True) mime_type = magic.from_buffer(file, mime=True)
image_converted = False image_converted = False
if mime_type == "image/webp": if mime_type == "image/webp":
new_mime_type, file, width, height = convert_image( new_mime_type, file, width, height = convert_image(
file, source_mime="image/webp", target_type="png", file, source_mime="image/webp", target_type="png",
thumbnail_to=(256, 256) if is_sticker else None) thumbnail_to=(256, 256) if is_sticker else None)
image_converted = new_mime_type != mime_type image_converted = new_mime_type != mime_type
mime_type = new_mime_type mime_type = new_mime_type
thumbnail = None thumbnail = None
content_uri = await intent.upload_media(file, mime_type) content_uri = await intent.upload_media(file, mime_type)
db_file = DBTelegramFile(id=loc_id, mxc=content_uri, db_file = DBTelegramFile(id=loc_id, mxc=content_uri,
mime_type=mime_type, was_converted=image_converted, mime_type=mime_type, was_converted=image_converted,
timestamp=int(time.time()), size=len(file), timestamp=int(time.time()), size=len(file),
width=width, height=height) width=width, height=height)
if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"): if thumbnail and (mime_type.startswith("video/") or mime_type == "image/gif"):
if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)): if isinstance(thumbnail, (PhotoSize, PhotoCachedSize)):
thumbnail = thumbnail.location thumbnail = thumbnail.location
@@ -0,0 +1,175 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict
from collections import defaultdict
import asyncio
import logging
import time
import math
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
InputPhotoFileLocation, InputPeerPhotoFileLocation)
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions.upload import GetFileRequest
from telethon.network import MTProtoSender
from telethon.crypto import AuthKey
from telethon import utils
from mautrix.appservice import IntentAPI
from ..tgclient import MautrixTelegramClient
from ..db import TelegramFile as DBTelegramFile
# Module logger; note that it uses the shared "mau.util" logger name.
log: logging.Logger = logging.getLogger("mau.util")
# Telethon types that the downloader classes in this module accept as the file to fetch.
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
InputFileLocation, InputPhotoFileLocation]
class Sender:
    """One connection that fetches every ``stride``-th chunk of a file.

    The object keeps a single ``GetFileRequest`` and reuses it: after each
    successful fetch the request offset is moved forward by ``stride`` bytes,
    and the number of chunks still owed is decremented.
    """
    sender: MTProtoSender
    request: GetFileRequest
    remaining: int
    stride: int

    def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
                 stride: int, count: int) -> None:
        """Prepare the reusable request.

        :param sender: Connection used to send the request.
        :param file: Location of the file to fetch.
        :param offset: Byte offset of the first chunk this sender fetches.
        :param limit: Size in bytes of each chunk.
        :param stride: Number of bytes to advance the offset after each chunk.
        :param count: Number of chunks this sender is responsible for.
        """
        log.debug(f"Creating sender with {offset=} {limit=} {stride=} {count=}")
        self.sender, self.stride, self.remaining = sender, stride, count
        self.request = GetFileRequest(file, offset=offset, limit=limit)

    async def next(self) -> Optional[bytes]:
        """Fetch the next chunk, or return ``None`` once all chunks were fetched."""
        if self.remaining:
            log.debug(f"Sending {self.request!s}")
            response = await self.sender.send(self.request)
            self.remaining -= 1
            # Move the shared request to the next chunk owned by this sender.
            self.request.offset += self.stride
            return response.bytes
        return None

    def disconnect(self) -> Awaitable[None]:
        """Disconnect the underlying connection; the result must be awaited."""
        return self.sender.disconnect()
class ParallelDownloader:
    """Download one Telegram file over several connections at the same time.

    The file is split into fixed-size parts that are dealt out round-robin to
    a set of :class:`Sender` objects, each with its own connection. The parts
    are yielded in file order by :meth:`download`.
    """
    client: MautrixTelegramClient
    loop: asyncio.AbstractEventLoop
    dc_id: int
    exported: bool
    senders: Optional[List[Sender]]
    auth_key: Optional[AuthKey]

    def __init__(self, client: MautrixTelegramClient, dc_id: int) -> None:
        """
        :param client: Logged-in client whose session and connection settings are reused.
        :param dc_id: ID of the datacenter that hosts the file. A falsy value means
                      the session's own auth key is used as-is.
        """
        self.client = client
        self.loop = self.client.loop
        self.dc_id = dc_id
        # When the file lives on another DC than the session, the session key cannot be
        # reused: the key is left unset here and obtained by the first created sender.
        self.exported = bool(dc_id) and self.client.session.dc_id != dc_id
        self.auth_key = self.client.session.auth_key if not self.exported else None
        self.senders = None

    async def _init(self, connections: int, file: TypeLocation, part_count: int, part_size: int
                    ) -> None:
        """Create ``connections`` senders and split ``part_count`` parts between them.

        Sender ``i`` starts at offset ``i * part_size`` and advances by
        ``connections * part_size``; the first ``part_count % connections``
        senders get one part more than the others.
        """
        minimum, remainder = divmod(part_count, connections)
        counts = [minimum + 1 if index < remainder else minimum
                  for index in range(connections)]
        stride = connections * part_size
        # The first sender is created on its own so that, if the authorization has to be
        # exported, it happens once and ``self.auth_key`` is set before the rest connect.
        first = await self._create_sender(file, 0, part_size, stride, counts[0])
        rest = await asyncio.gather(*[
            self._create_sender(file, index, part_size, stride, counts[index])
            for index in range(1, connections)
        ])
        self.senders = [first, *rest]

    async def _cleanup(self) -> None:
        """Disconnect all senders. Does nothing if no senders were created."""
        senders, self.senders = self.senders, None
        if senders:
            await asyncio.gather(*[sender.disconnect() for sender in senders])

    async def _create_sender(self, file: TypeLocation, index: int, part_size: int, stride: int,
                             part_count: int) -> Sender:
        """Open a connection to the file's DC and wrap it in a :class:`Sender`.

        :param file: Location of the file to fetch.
        :param index: Position of this sender; its first offset is ``index * part_size``.
        :param part_size: Size of one part in bytes.
        :param stride: Number of bytes between two consecutive parts of this sender.
        :param part_count: Number of parts this sender has to fetch.
        """
        dc = await self.client._get_dc(self.dc_id)
        sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log)
        await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
                                                     loop=self.loop, loggers=self.client._log,
                                                     proxy=self.client._proxy))
        if not self.auth_key:
            # No usable key for this DC yet: export the authorization from the main
            # connection, import it on the new one and remember the resulting key.
            log.debug(f"Exporting auth to DC {self.dc_id}")
            auth = await self.client(ExportAuthorizationRequest(self.dc_id))
            req = self.client._init_with(ImportAuthorizationRequest(
                id=auth.id, bytes=auth.bytes
            ))
            await sender.send(req)
            self.auth_key = sender.auth_key
        return Sender(sender, file, index * part_size, part_size, stride, part_count)

    @staticmethod
    def _get_connection_count(file_size: int, max_count: int = 20,
                              full_size: int = 100 * 1024 * 1024) -> int:
        """Pick a connection count that grows linearly with the file size.

        Files larger than ``full_size`` bytes get ``max_count`` connections;
        smaller files get a proportional share, rounded up and never less
        than one (a zero count would make the part split divide by zero).
        """
        if file_size > full_size:
            return max_count
        return max(1, math.ceil((file_size / full_size) * max_count))

    async def download(self, file: TypeLocation, file_size: int,
                       part_size_kb: Optional[float] = None,
                       connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
        """Yield the contents of ``file`` part by part, in file order.

        :param file: Location of the file to fetch.
        :param file_size: Size of the file in bytes.
        :param part_size_kb: Size of one part in kilobytes. Chosen from the file size
                             when omitted.
        :param connection_count: Number of connections to use. Chosen from the file
                                 size when omitted.

        The connections are closed when the generator finishes, fails, or is
        closed by the consumer before the end of the file.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        # Offsets and limits must be integers even if a fractional part size was passed.
        part_size = int((part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024)
        part_count = math.ceil(file_size / part_size)
        if part_count <= 0:
            # Empty file: there is nothing to request, so do not open any connection.
            return
        # A connection without at least one part to fetch would be opened for nothing.
        connection_count = max(1, min(connection_count, part_count))
        log.debug("Starting parallel download: "
                  f"{connection_count} {part_size} {part_count} {file!s}")
        await self._init(connection_count, file, part_count, part_size)
        try:
            part = 0
            while part < part_count:
                # Request one part per sender concurrently, then emit them in sender
                # order, which is also file order.
                tasks = [self.loop.create_task(sender.next()) for sender in self.senders]
                try:
                    for task in tasks:
                        data = await task
                        if not data:
                            # Either every part was delivered, or the server returned
                            # less data than expected. Asking again would never make
                            # progress, so stop instead of looping.
                            if part < part_count:
                                log.warning(f"Parallel download ended early after {part} "
                                            f"of {part_count} parts")
                            return
                        yield data
                        part += 1
                        log.debug(f"Part {part} downloaded")
                finally:
                    # Only relevant on early exit: do not leave requests running.
                    for task in tasks:
                        if not task.done():
                            task.cancel()
        finally:
            log.debug("Parallel download finished, cleaning up connections")
            await self._cleanup()
# One lock per ``parallel_id``, created the first time that ID is looked up.
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock)


async def parallel_transfer_to_matrix(client: MautrixTelegramClient, intent: IntentAPI,
                                      loc_id: str, location: TypeLocation, filename: str,
                                      parallel_id: int) -> DBTelegramFile:
    """Copy a Telegram file to Matrix by feeding a parallel download into the upload.

    :param client: Telegram client used for the download.
    :param intent: Matrix intent used for the upload.
    :param loc_id: ID stored on the returned database object.
    :param location: File to transfer; its ``size`` and ``mime_type`` are read.
    :param filename: File name passed on to the upload.
    :param parallel_id: Key of the lock that serializes transfers.
    :return: A new database object describing the uploaded file.
    """
    file_size, mime = location.size, location.mime_type
    dc_id, input_location = utils.get_input_location(location)
    # We lock the transfers because telegram has connection count limits
    async with parallel_transfer_locks[parallel_id]:
        stream = ParallelDownloader(client, dc_id).download(input_location, file_size)
        content_uri = await intent.upload_media(stream, mime_type=mime, filename=filename,
                                                size=file_size)
    return DBTelegramFile(
        id=loc_id,
        mxc=content_uri,
        mime_type=mime,
        was_converted=False,
        timestamp=int(time.time()),
        size=file_size,
        width=None,
        height=None,
    )
+1 -1
View File
@@ -32,7 +32,7 @@ setuptools.setup(
install_requires=[ install_requires=[
"aiohttp>=3.0.1,<4", "aiohttp>=3.0.1,<4",
"mautrix>=0.4.0.dev71,<0.5", "mautrix>=0.4.0.dev74,<0.5",
"SQLAlchemy>=1.2.3,<2", "SQLAlchemy>=1.2.3,<2",
"alembic>=1.0.0,<2", "alembic>=1.0.0,<2",
"commonmark>=0.8.1,<0.10", "commonmark>=0.8.1,<0.10",