Remove remaining traces of ORM

This commit is contained in:
Tulir Asokan
2019-03-16 17:05:58 +02:00
parent 7c82580b4b
commit 7c46bf4b9e
10 changed files with 30 additions and 50 deletions
+5 -10
View File
@@ -23,7 +23,6 @@ import sys
import copy import copy
import signal import signal
from sqlalchemy import orm
import sqlalchemy as sql import sqlalchemy as sql
from mautrix_appservice import AppService from mautrix_appservice import AppService
@@ -73,13 +72,10 @@ log = logging.getLogger("mau.init") # type: logging.Logger
log.debug(f"Initializing mautrix-telegram {__version__}") log.debug(f"Initializing mautrix-telegram {__version__}")
db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db") db_engine = sql.create_engine(config["appservice.database"] or "sqlite:///mautrix-telegram.db")
db_factory = orm.sessionmaker(bind=db_engine)
db_session = orm.scoping.scoped_session(db_factory)
Base.metadata.bind = db_engine Base.metadata.bind = db_engine
session_container = AlchemySessionContainer(engine=db_engine, session=db_session, session_container = AlchemySessionContainer(engine=db_engine, table_base=Base, session=False,
table_base=Base, table_prefix="telethon_", table_prefix="telethon_", manage_tables=False)
manage_tables=False)
session_container.core_mode = True session_container.core_mode = True
try: try:
@@ -102,8 +98,9 @@ appserv = AppService(config["homeserver.address"], config["homeserver.domain"],
aiohttp_params={ aiohttp_params={
"client_max_size": config["appservice.max_body_size"] * mebibyte "client_max_size": config["appservice.max_body_size"] * mebibyte
}) })
bot = init_bot(config)
context = Context(appserv, db_session, config, loop, session_container) context = Context(appserv, config, loop, session_container, bot)
context.mx = MatrixHandler(context)
if config["appservice.public.enabled"]: if config["appservice.public.enabled"]:
public_website = PublicBridgeWebsite(loop) public_website = PublicBridgeWebsite(loop)
@@ -120,8 +117,6 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
start_ts = time() start_ts = time()
init_db(db_engine) init_db(db_engine)
init_abstract_user(context) init_abstract_user(context)
context.bot = init_bot(context)
context.mx = MatrixHandler(context)
init_formatter(context) init_formatter(context)
init_portal(context) init_portal(context)
startup_actions = (init_puppet(context) + startup_actions = (init_puppet(context) +
+3 -8
View File
@@ -20,7 +20,6 @@ import asyncio
import logging import logging
import platform import platform
from sqlalchemy import orm
from telethon.tl.patched import MessageService, Message from telethon.tl.patched import MessageService, Message
from telethon.tl.types import ( from telethon.tl.types import (
Channel, ChannelForbidden, Chat, ChatForbidden, MessageActionChannelMigrateFrom, PeerUser, Channel, ChannelForbidden, Chat, ChatForbidden, MessageActionChannelMigrateFrom, PeerUser,
@@ -56,7 +55,6 @@ class AbstractUser(ABC):
session_container = None # type: AlchemySessionContainer session_container = None # type: AlchemySessionContainer
loop = None # type: asyncio.AbstractEventLoop loop = None # type: asyncio.AbstractEventLoop
log = None # type: logging.Logger log = None # type: logging.Logger
db = None # type: orm.Session
az = None # type: AppService az = None # type: AppService
bot = None # type: Bot bot = None # type: Bot
ignore_incoming_bot_events = True # type: bool ignore_incoming_bot_events = True # type: bool
@@ -175,11 +173,8 @@ class AbstractUser(ABC):
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser': async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
if not self.puppet_whitelisted or self.connected: if not self.puppet_whitelisted or self.connected:
return self return self
session_count = self.session_container.Session.query.filter( self.log.debug("ensure_started(%s, even_if_no_session=%s)", self.mxid, even_if_no_session)
self.session_container.Session.session_id == self.mxid).count() if even_if_no_session or self.session_container.has_session(self.mxid):
self.log.debug("ensure_started(%s, even_if_no_session=%s, session_count=%s)",
self.mxid, even_if_no_session, session_count)
if even_if_no_session or session_count > 0:
await self.start(delete_unless_authenticated=not even_if_no_session) await self.start(delete_unless_authenticated=not even_if_no_session)
return self return self
@@ -388,7 +383,7 @@ class AbstractUser(ABC):
def init(context: "Context") -> None: def init(context: "Context") -> None:
global config, MAX_DELETIONS global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core AbstractUser.az, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"] AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10) MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
+4 -3
View File
@@ -56,7 +56,7 @@ class Bot(AbstractUser):
self.username = None # type: str self.username = None # type: str
self.is_relaybot = True # type: bool self.is_relaybot = True # type: bool
self.is_bot = True # type: bool self.is_bot = True # type: bool
self.chats = {chat.id: chat.type for chat in BotChat.all()} # type: Dict[int, str] self.chats = {} # type: Dict[int, str]
self.tg_whitelist = [] # type: List[int] self.tg_whitelist = [] # type: List[int]
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"] self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool or False) # type: bool
@@ -74,6 +74,7 @@ class Bot(AbstractUser):
self.tg_whitelist.append(user_id) self.tg_whitelist.append(user_id)
async def start(self, delete_unless_authenticated: bool = False) -> 'Bot': async def start(self, delete_unless_authenticated: bool = False) -> 'Bot':
self.chats = {chat.id: chat.type for chat in BotChat.all()}
await super().start(delete_unless_authenticated) await super().start(delete_unless_authenticated)
if not await self.is_logged_in(): if not await self.is_logged_in():
await self.client.sign_in(bot_token=self.token) await self.client.sign_in(bot_token=self.token)
@@ -280,9 +281,9 @@ class Bot(AbstractUser):
return "bot" return "bot"
def init(context: 'Context') -> Optional[Bot]: def init(cfg: 'Config') -> Optional[Bot]:
global config global config
config = context.config config = cfg
token = config["telegram.bot_token"] token = config["telegram.bot_token"]
if token and not token.lower().startswith("disable"): if token and not token.lower().startswith("disable"):
return Bot(token) return Bot(token)
+1 -1
View File
@@ -328,7 +328,7 @@ class CommandProcessor:
log = logging.getLogger("mau.commands") log = logging.getLogger("mau.commands")
def __init__(self, context: c.Context) -> None: def __init__(self, context: c.Context) -> None:
self.az, self.db, self.config, self.loop, self.tgbot = context.core self.az, self.config, self.loop, self.tgbot = context.core
self.public_website = context.public_website self.public_website = context.public_website
self.command_prefix = self.config["bridge.command_prefix"] self.command_prefix = self.config["bridge.command_prefix"]
+8 -13
View File
@@ -19,8 +19,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
import asyncio import asyncio
from sqlalchemy.orm import scoped_session
from alchemysession import AlchemySessionContainer from alchemysession import AlchemySessionContainer
from mautrix_appservice import AppService from mautrix_appservice import AppService
@@ -31,20 +29,17 @@ if TYPE_CHECKING:
class Context: class Context:
def __init__(self, az: 'AppService', db: 'scoped_session', config: 'Config', def __init__(self, az: 'AppService', config: 'Config', loop: 'asyncio.AbstractEventLoop',
loop: 'asyncio.AbstractEventLoop', session_container: 'AlchemySessionContainer' session_container: 'AlchemySessionContainer', bot: Optional['Bot']) -> None:
) -> None:
self.az = az # type: AppService self.az = az # type: AppService
self.db = db # type: scoped_session
self.config = config # type: Config self.config = config # type: Config
self.loop = loop # type: asyncio.AbstractEventLoop self.loop = loop # type: asyncio.AbstractEventLoop
self.bot = None # type: Optional[Bot] self.bot = bot # type: Optional[Bot]
self.mx = None # type: MatrixHandler self.mx = None # type: Optional[MatrixHandler]
self.session_container = session_container # type: AlchemySessionContainer self.session_container = session_container # type: AlchemySessionContainer
self.public_website = None # type: PublicBridgeWebsite self.public_website = None # type: Optional[PublicBridgeWebsite]
self.provisioning_api = None # type: ProvisioningAPI self.provisioning_api = None # type: Optional[ProvisioningAPI]
@property @property
def core(self) -> Tuple['AppService', 'scoped_session', 'Config', def core(self) -> Tuple['AppService', 'Config', 'asyncio.AbstractEventLoop', Optional['Bot']]:
'asyncio.AbstractEventLoop', Optional['Bot']]: return self.az, self.config, self.loop, self.bot
return (self.az, self.db, self.config, self.loop, self.bot)
+1 -2
View File
@@ -15,7 +15,6 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, ForeignKey, Integer, BigInteger, String, Boolean from sqlalchemy import Column, ForeignKey, Integer, BigInteger, String, Boolean
from sqlalchemy.orm import relationship
from typing import Optional from typing import Optional
from .base import Base from .base import Base
@@ -33,7 +32,7 @@ class TelegramFile(Base):
width = Column(Integer, nullable=True) width = Column(Integer, nullable=True)
height = Column(Integer, nullable=True) height = Column(Integer, nullable=True)
thumbnail_id = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True) thumbnail_id = Column("thumbnail", String, ForeignKey("telegram_file.id"), nullable=True)
thumbnail = relationship("TelegramFile", uselist=False) thumbnail = None # type: Optional[TelegramFile]
@classmethod @classmethod
def get(cls, id: str) -> Optional['TelegramFile']: def get(cls, id: str) -> Optional['TelegramFile']:
+1 -1
View File
@@ -32,7 +32,7 @@ class MatrixHandler:
log = logging.getLogger("mau.mx") # type: logging.Logger log = logging.getLogger("mau.mx") # type: logging.Logger
def __init__(self, context: 'Context') -> None: def __init__(self, context: 'Context') -> None:
self.az, self.db, self.config, _, self.tgbot = context.core self.az, self.config, _, self.tgbot = context.core
self.commands = com.CommandProcessor(context) # type: com.CommandProcessor self.commands = com.CommandProcessor(context) # type: com.CommandProcessor
self.previously_typing = [] # type: List[MatrixUserID] self.previously_typing = [] # type: List[MatrixUserID]
+1 -1
View File
@@ -2039,7 +2039,7 @@ class Portal:
def init(context: Context) -> None: def init(context: Context) -> None:
global config global config
Portal.az, _, config, Portal.loop, Portal.bot = context.core Portal.az, config, Portal.loop, Portal.bot = context.core
Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"] Portal.max_initial_member_sync = config["bridge.max_initial_member_sync"]
Portal.sync_channel_members = config["bridge.sync_channel_members"] Portal.sync_channel_members = config["bridge.sync_channel_members"]
Portal.sync_matrix_state = config["bridge.sync_matrix_state"] Portal.sync_matrix_state = config["bridge.sync_matrix_state"]
+4 -9
View File
@@ -14,8 +14,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Awaitable, Coroutine, Dict, List, Iterable, Optional, Pattern, Union, from typing import Awaitable, Any, Dict, List, Iterable, Optional, Pattern, Union, TYPE_CHECKING
TYPE_CHECKING)
from difflib import SequenceMatcher from difflib import SequenceMatcher
from enum import Enum from enum import Enum
from aiohttp import ServerDisconnectedError from aiohttp import ServerDisconnectedError
@@ -23,8 +22,6 @@ import asyncio
import logging import logging
import re import re
from sqlalchemy import orm
from telethon.tl.types import UserProfilePhoto, User, FileLocation, UpdateUserName, PeerUser from telethon.tl.types import UserProfilePhoto, User, FileLocation, UpdateUserName, PeerUser
from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError from mautrix_appservice import AppService, IntentAPI, IntentError, MatrixRequestError
@@ -45,7 +42,6 @@ config = None # type: Config
class Puppet: class Puppet:
log = logging.getLogger("mau.puppet") # type: logging.Logger log = logging.getLogger("mau.puppet") # type: logging.Logger
db = None # type: orm.Session
az = None # type: AppService az = None # type: AppService
mx = None # type: MatrixHandler mx = None # type: MatrixHandler
loop = None # type: asyncio.AbstractEventLoop loop = None # type: asyncio.AbstractEventLoop
@@ -400,8 +396,7 @@ class Puppet:
if create: if create:
puppet = cls(tgid) puppet = cls(tgid)
cls.db.add(puppet.db_instance) puppet.db_instance.insert()
cls.db.commit()
return puppet return puppet
return None return None
@@ -481,9 +476,9 @@ class Puppet:
# endregion # endregion
def init(context: 'Context') -> List[Coroutine]: # [None, None, PuppetError] def init(context: 'Context') -> List[Awaitable[Any]]: # [None, None, PuppetError]
global config global config
Puppet.az, Puppet.db, config, Puppet.loop, _ = context.core Puppet.az, config, Puppet.loop, _ = context.core
Puppet.mx = context.mx Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}") Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"] Puppet.hs_domain = config["homeserver"]["domain"]
+2 -2
View File
@@ -39,8 +39,8 @@ setuptools.setup(
"ruamel.yaml>=0.15.35,<0.16", "ruamel.yaml>=0.15.35,<0.16",
"future-fstrings>=0.4.2", "future-fstrings>=0.4.2",
"python-magic>=0.4.15,<0.5", "python-magic>=0.4.15,<0.5",
"telethon>=1.5.5,<1.6", "telethon>=1.5.5,<1.7",
"telethon-session-sqlalchemy>=0.2.9,<0.3", "telethon-session-sqlalchemy>=0.2.11,<0.3",
], ],
extras_require=extras, extras_require=extras,