Add locking to client connect calls
This commit is contained in:
@@ -211,7 +211,7 @@ class AbstractUser(ABC):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
|
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
|
||||||
if not self.puppet_whitelisted or self.connected:
|
if self.connected:
|
||||||
return self
|
return self
|
||||||
if even_if_no_session or self.session_container.has_session(self.mxid):
|
if even_if_no_session or self.session_container.has_session(self.mxid):
|
||||||
self.log.debug("Starting client due to ensure_started"
|
self.log.debug("Starting client due to ensure_started"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import (Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, Any,
|
from typing import (Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, Any, cast,
|
||||||
TYPE_CHECKING)
|
TYPE_CHECKING)
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -54,6 +54,7 @@ class User(AbstractUser):
|
|||||||
command_status: Optional[Dict[str, Any]]
|
command_status: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
_db_instance: Optional[DBUser]
|
_db_instance: Optional[DBUser]
|
||||||
|
_ensure_started_lock: asyncio.Lock
|
||||||
|
|
||||||
def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None,
|
def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None,
|
||||||
username: Optional[str] = None, phone: Optional[str] = None,
|
username: Optional[str] = None, phone: Optional[str] = None,
|
||||||
@@ -73,6 +74,7 @@ class User(AbstractUser):
|
|||||||
self.portals = {}
|
self.portals = {}
|
||||||
self.db_portals = db_portals or []
|
self.db_portals = db_portals or []
|
||||||
self._db_instance = db_instance
|
self._db_instance = db_instance
|
||||||
|
self._ensure_started_lock = asyncio.Lock()
|
||||||
|
|
||||||
self.command_status = None
|
self.command_status = None
|
||||||
|
|
||||||
@@ -172,8 +174,11 @@ class User(AbstractUser):
|
|||||||
# endregion
|
# endregion
|
||||||
# region Telegram connection management
|
# region Telegram connection management
|
||||||
|
|
||||||
def ensure_started(self, even_if_no_session=False) -> Awaitable['User']:
|
async def ensure_started(self, even_if_no_session=False) -> 'User':
|
||||||
return super().ensure_started(even_if_no_session)
|
if not self.puppet_whitelisted or self.connected:
|
||||||
|
return self
|
||||||
|
async with self._ensure_started_lock:
|
||||||
|
return cast(User, await super().ensure_started(even_if_no_session))
|
||||||
|
|
||||||
async def start(self, delete_unless_authenticated: bool = False) -> 'User':
|
async def start(self, delete_unless_authenticated: bool = False) -> 'User':
|
||||||
await super().start()
|
await super().start()
|
||||||
|
|||||||
Reference in New Issue
Block a user