Add more type hints
This commit is contained in:
+22
-21
@@ -14,6 +14,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Tuple, Any, Optional
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.comments import CommentedMap
|
||||
import random
|
||||
@@ -24,28 +25,28 @@ yaml.indent(4)
|
||||
|
||||
|
||||
class DictWithRecursion:
|
||||
def __init__(self, data=None):
|
||||
self._data = data or CommentedMap()
|
||||
def __init__(self, data: CommentedMap = None):
|
||||
self._data = data or CommentedMap() # type: CommentedMap
|
||||
|
||||
def _recursive_get(self, data, key, default_value):
|
||||
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
|
||||
if '.' in key:
|
||||
key, next_key = key.split('.', 1)
|
||||
next_data = data.get(key, CommentedMap())
|
||||
return self._recursive_get(next_data, next_key, default_value)
|
||||
return data.get(key, default_value)
|
||||
|
||||
def get(self, key, default_value, allow_recursion=True):
|
||||
def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
|
||||
if allow_recursion and '.' in key:
|
||||
return self._recursive_get(self._data, key, default_value)
|
||||
return self._data.get(key, default_value)
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.get(key, None)
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return self[key] is not None
|
||||
|
||||
def _recursive_set(self, data, key, value):
|
||||
def _recursive_set(self, data: CommentedMap, key: str, value: Any):
|
||||
if '.' in key:
|
||||
key, next_key = key.split('.', 1)
|
||||
if key not in data:
|
||||
@@ -55,16 +56,16 @@ class DictWithRecursion:
|
||||
return
|
||||
data[key] = value
|
||||
|
||||
def set(self, key, value, allow_recursion=True):
|
||||
def set(self, key: str, value: Any, allow_recursion: bool = True):
|
||||
if allow_recursion and '.' in key:
|
||||
self._recursive_set(self._data, key, value)
|
||||
return
|
||||
self._data[key] = value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
self.set(key, value)
|
||||
|
||||
def _recursive_del(self, data, key):
|
||||
def _recursive_del(self, data: CommentedMap, key: str):
|
||||
if '.' in key:
|
||||
key, next_key = key.split('.', 1)
|
||||
if key not in data:
|
||||
@@ -78,7 +79,7 @@ class DictWithRecursion:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def delete(self, key, allow_recursion=True):
|
||||
def delete(self, key: str, allow_recursion: bool = True):
|
||||
if allow_recursion and '.' in key:
|
||||
self._recursive_del(self._data, key)
|
||||
return
|
||||
@@ -88,23 +89,23 @@ class DictWithRecursion:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __delitem__(self, key):
|
||||
def __delitem__(self, key: str):
|
||||
self.delete(key)
|
||||
|
||||
|
||||
class Config(DictWithRecursion):
|
||||
def __init__(self, path, registration_path, base_path):
|
||||
def __init__(self, path: str, registration_path: str, base_path: str):
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.registration_path = registration_path
|
||||
self.base_path = base_path
|
||||
self._registration = None
|
||||
self.path = path # type: str
|
||||
self.registration_path = registration_path # type: str
|
||||
self.base_path = base_path # type: str
|
||||
self._registration = None # type: dict
|
||||
|
||||
def load(self):
|
||||
with open(self.path, 'r') as stream:
|
||||
self._data = yaml.load(stream)
|
||||
|
||||
def load_base(self):
|
||||
def load_base(self) -> Optional[DictWithRecursion]:
|
||||
try:
|
||||
with open(self.base_path, 'r') as stream:
|
||||
return DictWithRecursion(yaml.load(stream))
|
||||
@@ -120,7 +121,7 @@ class Config(DictWithRecursion):
|
||||
yaml.dump(self._registration, stream)
|
||||
|
||||
@staticmethod
|
||||
def _new_token():
|
||||
def _new_token() -> str:
|
||||
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
|
||||
|
||||
def update(self):
|
||||
@@ -246,7 +247,7 @@ class Config(DictWithRecursion):
|
||||
self._data = base._data
|
||||
self.save()
|
||||
|
||||
def _get_permissions(self, key):
|
||||
def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool]:
|
||||
level = self["bridge.permissions"].get(key, "")
|
||||
admin = level == "admin"
|
||||
puppeting = level == "full" or admin
|
||||
@@ -254,7 +255,7 @@ class Config(DictWithRecursion):
|
||||
relaybot = level == "relaybot" or user
|
||||
return relaybot, user, puppeting, admin, level
|
||||
|
||||
def get_permissions(self, mxid):
|
||||
def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool]:
|
||||
permissions = self["bridge.permissions"] or {}
|
||||
if mxid in permissions:
|
||||
return self._get_permissions(mxid)
|
||||
|
||||
Reference in New Issue
Block a user