Fix incorrectly case sensitive username finding in db. Fixes #384

This commit is contained in:
Tulir Asokan
2019-11-30 15:21:47 +02:00
parent 25d7087d07
commit 91e6a73f33
6 changed files with 15 additions and 9 deletions
+2 -2
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Boolean, Text from sqlalchemy import Column, Integer, String, Boolean, Text, func
from mautrix.types import RoomID from mautrix.types import RoomID
from mautrix.util.db import Base from mautrix.util.db import Base
@@ -53,4 +53,4 @@ class Portal(Base):
@classmethod @classmethod
def get_by_username(cls, username: str) -> Optional['Portal']: def get_by_username(cls, username: str) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.username == username) return cls._select_one_or_none(func.lower(cls.c.username) == username)
+2 -2
View File
@@ -16,7 +16,7 @@
from typing import Optional, Iterable from typing import Optional, Iterable
from sqlalchemy import Column, Integer, String, Boolean from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy.sql import expression from sqlalchemy.sql import expression, func
from mautrix.types import UserID, SyncToken from mautrix.types import UserID, SyncToken
from mautrix.util.db import Base from mautrix.util.db import Base
@@ -53,7 +53,7 @@ class Puppet(Base):
@classmethod @classmethod
def get_by_username(cls, username: str) -> Optional['Puppet']: def get_by_username(cls, username: str) -> Optional['Puppet']:
return cls._select_one_or_none(cls.c.username == username) return cls._select_one_or_none(func.lowercase(cls.c.username) == username)
@classmethod @classmethod
def get_by_displayname(cls, displayname: str) -> Optional['Puppet']: def get_by_displayname(cls, displayname: str) -> Optional['Puppet']:
+2 -2
View File
@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Iterable, Tuple from typing import Optional, Iterable, Tuple
from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, Integer, String from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, Integer, String, func
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.db import Base from mautrix.util.db import Base
@@ -46,7 +46,7 @@ class User(Base):
@classmethod @classmethod
def get_by_username(cls, username: str) -> Optional['User']: def get_by_username(cls, username: str) -> Optional['User']:
return cls._select_one_or_none(cls.c.tg_username == username) return cls._select_one_or_none(func.lower(cls.c.tg_username) == username)
@property @property
def contacts(self) -> Iterable[TelegramID]: def contacts(self) -> Iterable[TelegramID]:
+3 -1
View File
@@ -354,8 +354,10 @@ class BasePortal(ABC):
if not username: if not username:
return None return None
username = username.lower()
for _, portal in cls.by_tgid.items(): for _, portal in cls.by_tgid.items():
if portal.username and portal.username.lower() == username.lower(): if portal.username and portal.username.lower() == username:
return portal return portal
dbportal = DBPortal.get_by_username(username) dbportal = DBPortal.get_by_username(username)
+3 -1
View File
@@ -384,8 +384,10 @@ class Puppet(CustomPuppetMixin):
if not username: if not username:
return None return None
username = username.lower()
for _, puppet in cls.cache.items(): for _, puppet in cls.cache.items():
if puppet.username and puppet.username.lower() == username.lower(): if puppet.username and puppet.username.lower() == username:
return puppet return puppet
dbpuppet = DBPuppet.get_by_username(username) dbpuppet = DBPuppet.get_by_username(username)
+3 -1
View File
@@ -424,8 +424,10 @@ class User(AbstractUser, BaseUser):
if not username: if not username:
return None return None
username = username.lower()
for _, user in cls.by_tgid.items(): for _, user in cls.by_tgid.items():
if user.username and user.username.lower() == username.lower(): if user.username and user.username.lower() == username:
return user return user
puppet = DBUser.get_by_username(username) puppet = DBUser.get_by_username(username)