Add more type hints

This commit is contained in:
Tulir Asokan
2018-07-25 10:40:31 -04:00
parent ae334b9a04
commit dbfb980bde
20 changed files with 751 additions and 595 deletions
+22 -21
View File
@@ -14,6 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Any, Optional
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import random
@@ -24,28 +25,28 @@ yaml.indent(4)
class DictWithRecursion:
def __init__(self, data=None):
self._data = data or CommentedMap()
def __init__(self, data: CommentedMap = None):
self._data = data or CommentedMap() # type: CommentedMap
def _recursive_get(self, data, key, default_value):
def _recursive_get(self, data: CommentedMap, key: str, default_value: Any) -> Any:
if '.' in key:
key, next_key = key.split('.', 1)
next_data = data.get(key, CommentedMap())
return self._recursive_get(next_data, next_key, default_value)
return data.get(key, default_value)
def get(self, key, default_value, allow_recursion=True):
def get(self, key: str, default_value: Any, allow_recursion: bool = True) -> Any:
if allow_recursion and '.' in key:
return self._recursive_get(self._data, key, default_value)
return self._data.get(key, default_value)
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
return self.get(key, None)
def __contains__(self, key):
def __contains__(self, key: str) -> bool:
return self[key] is not None
def _recursive_set(self, data, key, value):
def _recursive_set(self, data: CommentedMap, key: str, value: Any):
if '.' in key:
key, next_key = key.split('.', 1)
if key not in data:
@@ -55,16 +56,16 @@ class DictWithRecursion:
return
data[key] = value
def set(self, key, value, allow_recursion=True):
def set(self, key: str, value: Any, allow_recursion: bool = True):
if allow_recursion and '.' in key:
self._recursive_set(self._data, key, value)
return
self._data[key] = value
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any):
self.set(key, value)
def _recursive_del(self, data, key):
def _recursive_del(self, data: CommentedMap, key: str):
if '.' in key:
key, next_key = key.split('.', 1)
if key not in data:
@@ -78,7 +79,7 @@ class DictWithRecursion:
except KeyError:
pass
def delete(self, key, allow_recursion=True):
def delete(self, key: str, allow_recursion: bool = True):
if allow_recursion and '.' in key:
self._recursive_del(self._data, key)
return
@@ -88,23 +89,23 @@ class DictWithRecursion:
except KeyError:
pass
def __delitem__(self, key):
def __delitem__(self, key: str):
self.delete(key)
class Config(DictWithRecursion):
def __init__(self, path, registration_path, base_path):
def __init__(self, path: str, registration_path: str, base_path: str):
super().__init__()
self.path = path
self.registration_path = registration_path
self.base_path = base_path
self._registration = None
self.path = path # type: str
self.registration_path = registration_path # type: str
self.base_path = base_path # type: str
self._registration = None # type: dict
def load(self):
with open(self.path, 'r') as stream:
self._data = yaml.load(stream)
def load_base(self):
def load_base(self) -> Optional[DictWithRecursion]:
try:
with open(self.base_path, 'r') as stream:
return DictWithRecursion(yaml.load(stream))
@@ -120,7 +121,7 @@ class Config(DictWithRecursion):
yaml.dump(self._registration, stream)
@staticmethod
def _new_token():
def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def update(self):
@@ -246,7 +247,7 @@ class Config(DictWithRecursion):
self._data = base._data
self.save()
def _get_permissions(self, key):
def _get_permissions(self, key: str) -> Tuple[bool, bool, bool, bool, bool]:
level = self["bridge.permissions"].get(key, "")
admin = level == "admin"
puppeting = level == "full" or admin
@@ -254,7 +255,7 @@ class Config(DictWithRecursion):
relaybot = level == "relaybot" or user
return relaybot, user, puppeting, admin, level
def get_permissions(self, mxid):
def get_permissions(self, mxid: str) -> Tuple[bool, bool, bool, bool, bool]:
permissions = self["bridge.permissions"] or {}
if mxid in permissions:
return self._get_permissions(mxid)