Commit ef064299 authored by mathieui's avatar mathieui

slixmpp.util: type things

Fix a bug in the SASL implementation as well. (some special chars would
make things crash instead of being escaped)
parent b1411d8e
# Slixmpp: The Slick XMPP Library # Slixmpp: The Slick XMPP Library
# Copyright (C) 2018 Emmanuel Gil Peyrot # Copyright (C) 2018 Emmanuel Gil Peyrot
# This file is part of Slixmpp. # This file is part of Slixmpp.
...@@ -6,8 +5,11 @@ ...@@ -6,8 +5,11 @@
import os import os
import logging import logging
from typing import Callable, Optional, Any
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class Cache: class Cache:
def retrieve(self, key): def retrieve(self, key):
raise NotImplementedError raise NotImplementedError
...@@ -16,7 +18,8 @@ class Cache: ...@@ -16,7 +18,8 @@ class Cache:
raise NotImplementedError raise NotImplementedError
def remove(self, key): def remove(self, key):
raise NotImplemented raise NotImplementedError
class PerJidCache: class PerJidCache:
def retrieve_by_jid(self, jid, key): def retrieve_by_jid(self, jid, key):
...@@ -28,6 +31,7 @@ class PerJidCache: ...@@ -28,6 +31,7 @@ class PerJidCache:
def remove_by_jid(self, jid, key): def remove_by_jid(self, jid, key):
raise NotImplementedError raise NotImplementedError
class MemoryCache(Cache): class MemoryCache(Cache):
def __init__(self): def __init__(self):
self.cache = {} self.cache = {}
...@@ -44,6 +48,7 @@ class MemoryCache(Cache): ...@@ -44,6 +48,7 @@ class MemoryCache(Cache):
del self.cache[key] del self.cache[key]
return True return True
class MemoryPerJidCache(PerJidCache): class MemoryPerJidCache(PerJidCache):
def __init__(self): def __init__(self):
self.cache = {} self.cache = {}
...@@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache): ...@@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache):
del cache[key] del cache[key]
return True return True
class FileSystemStorage: class FileSystemStorage:
def __init__(self, encode, decode, binary): def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool):
self.encode = encode if encode is not None else lambda x: x self.encode = encode if encode is not None else lambda x: x
self.decode = decode if decode is not None else lambda x: x self.decode = decode if decode is not None else lambda x: x
self.read = 'rb' if binary else 'r' self.read = 'rb' if binary else 'r'
self.write = 'wb' if binary else 'w' self.write = 'wb' if binary else 'w'
def _retrieve(self, directory, key): def _retrieve(self, directory: str, key: str):
filename = os.path.join(directory, key.replace('/', '_')) filename = os.path.join(directory, key.replace('/', '_'))
try: try:
with open(filename, self.read) as cache_file: with open(filename, self.read) as cache_file:
...@@ -86,7 +92,7 @@ class FileSystemStorage: ...@@ -86,7 +92,7 @@ class FileSystemStorage:
log.debug('Removing %s entry', key) log.debug('Removing %s entry', key)
self._remove(directory, key) self._remove(directory, key)
def _store(self, directory, key, value): def _store(self, directory: str, key: str, value):
filename = os.path.join(directory, key.replace('/', '_')) filename = os.path.join(directory, key.replace('/', '_'))
try: try:
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
...@@ -99,7 +105,7 @@ class FileSystemStorage: ...@@ -99,7 +105,7 @@ class FileSystemStorage:
except Exception: except Exception:
log.debug('Failed to encode %s to cache:', key, exc_info=True) log.debug('Failed to encode %s to cache:', key, exc_info=True)
def _remove(self, directory, key): def _remove(self, directory: str, key: str):
filename = os.path.join(directory, key.replace('/', '_')) filename = os.path.join(directory, key.replace('/', '_'))
try: try:
os.remove(filename) os.remove(filename)
...@@ -108,8 +114,9 @@ class FileSystemStorage: ...@@ -108,8 +114,9 @@ class FileSystemStorage:
return False return False
return True return True
class FileSystemCache(Cache, FileSystemStorage): class FileSystemCache(Cache, FileSystemStorage):
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False): def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
FileSystemStorage.__init__(self, encode, decode, binary) FileSystemStorage.__init__(self, encode, decode, binary)
self.base_dir = os.path.join(directory, cache_type) self.base_dir = os.path.join(directory, cache_type)
...@@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage): ...@@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage):
def remove(self, key): def remove(self, key):
return self._remove(self.base_dir, key) return self._remove(self.base_dir, key)
class FileSystemPerJidCache(PerJidCache, FileSystemStorage): class FileSystemPerJidCache(PerJidCache, FileSystemStorage):
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False): def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
FileSystemStorage.__init__(self, encode, decode, binary) FileSystemStorage.__init__(self, encode, decode, binary)
self.base_dir = os.path.join(directory, cache_type) self.base_dir = os.path.join(directory, cache_type)
......
...@@ -2,15 +2,19 @@ import builtins ...@@ -2,15 +2,19 @@ import builtins
import sys import sys
import hashlib import hashlib
from typing import Optional, Union, Callable, List
def unicode(text): bytes_ = builtins.bytes # alias the stdlib type but ew
def unicode(text: Union[bytes_, str]) -> str:
if not isinstance(text, str): if not isinstance(text, str):
return text.decode('utf-8') return text.decode('utf-8')
else: else:
return text return text
def bytes(text): def bytes(text: Optional[Union[str, bytes_]]) -> bytes_:
""" """
Convert Unicode text to UTF-8 encoded bytes. Convert Unicode text to UTF-8 encoded bytes.
...@@ -34,7 +38,7 @@ def bytes(text): ...@@ -34,7 +38,7 @@ def bytes(text):
return builtins.bytes(text, encoding='utf-8') return builtins.bytes(text, encoding='utf-8')
def quote(text): def quote(text: Union[str, bytes_]) -> bytes_:
""" """
Enclose in quotes and escape internal slashes and double quotes. Enclose in quotes and escape internal slashes and double quotes.
...@@ -44,7 +48,7 @@ def quote(text): ...@@ -44,7 +48,7 @@ def quote(text):
return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"' return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"'
def num_to_bytes(num): def num_to_bytes(num: int) -> bytes_:
""" """
Convert an integer into a four byte sequence. Convert an integer into a four byte sequence.
...@@ -58,21 +62,21 @@ def num_to_bytes(num): ...@@ -58,21 +62,21 @@ def num_to_bytes(num):
return bval return bval
def bytes_to_num(bval): def bytes_to_num(bval: bytes_) -> int:
""" """
Convert a four byte sequence to an integer. Convert a four byte sequence to an integer.
:param bytes bval: A four byte sequence to turn into an integer. :param bytes bval: A four byte sequence to turn into an integer.
""" """
num = 0 num = 0
num += ord(bval[0] << 24) num += (bval[0] << 24)
num += ord(bval[1] << 16) num += (bval[1] << 16)
num += ord(bval[2] << 8) num += (bval[2] << 8)
num += ord(bval[3]) num += (bval[3])
return num return num
def XOR(x, y): def XOR(x: bytes_, y: bytes_) -> bytes_:
""" """
Return the results of an XOR operation on two equal length byte strings. Return the results of an XOR operation on two equal length byte strings.
...@@ -85,7 +89,7 @@ def XOR(x, y): ...@@ -85,7 +89,7 @@ def XOR(x, y):
return builtins.bytes([a ^ b for a, b in zip(x, y)]) return builtins.bytes([a ^ b for a, b in zip(x, y)])
def hash(name): def hash(name: str) -> Optional[Callable]:
""" """
Return a hash function implementing the given algorithm. Return a hash function implementing the given algorithm.
...@@ -102,7 +106,7 @@ def hash(name): ...@@ -102,7 +106,7 @@ def hash(name):
return None return None
def hashes(): def hashes() -> List[str]:
""" """
Return a list of available hashing algorithms. Return a list of available hashing algorithms.
...@@ -115,28 +119,3 @@ def hashes(): ...@@ -115,28 +119,3 @@ def hashes():
t += ['MD2'] t += ['MD2']
hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')] hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')]
return t + hashes return t + hashes
def setdefaultencoding(encoding):
"""
Set the current default string encoding used by the Unicode implementation.
Actually calls sys.setdefaultencoding under the hood - see the docs for that
for more details. This method exists only as a way to call find/call it
even after it has been 'deleted' when the site module is executed.
:param string encoding: An encoding name, compatible with sys.setdefaultencoding
"""
func = getattr(sys, 'setdefaultencoding', None)
if func is None:
import gc
import types
for obj in gc.get_objects():
if (isinstance(obj, types.BuiltinFunctionType)
and obj.__name__ == 'setdefaultencoding'):
func = obj
break
if func is None:
raise RuntimeError("Could not find setdefaultencoding")
sys.setdefaultencoding = func
return func(encoding)
# slixmpp.util.sasl.client # slixmpp.util.sasl.client
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~
# This module was originally based on Dave Cridland's Suelta library. # This module was originally based on Dave Cridland's Suelta library.
...@@ -6,9 +5,11 @@ ...@@ -6,9 +5,11 @@
# :copryight: (c) 2004-2013 David Alan Cridland # :copryight: (c) 2004-2013 David Alan Cridland
# :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout # :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout
# :license: MIT, see LICENSE for more details # :license: MIT, see LICENSE for more details
from __future__ import annotations
import logging import logging
import stringprep import stringprep
from typing import Iterable, Set, Callable, Dict, Any, Optional, Type
from slixmpp.util import hashes, bytes, stringprep_profiles from slixmpp.util import hashes, bytes, stringprep_profiles
...@@ -16,11 +17,11 @@ log = logging.getLogger(__name__) ...@@ -16,11 +17,11 @@ log = logging.getLogger(__name__)
#: Global registry mapping mechanism names to implementation classes. #: Global registry mapping mechanism names to implementation classes.
MECHANISMS = {} MECHANISMS: Dict[str, Type[Mech]] = {}
#: Global registry mapping mechanism names to security scores. #: Global registry mapping mechanism names to security scores.
MECH_SEC_SCORES = {} MECH_SEC_SCORES: Dict[str, int] = {}
#: The SASLprep profile of stringprep used to validate simple username #: The SASLprep profile of stringprep used to validate simple username
...@@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create( ...@@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create(
unassigned=[stringprep.in_table_a1]) unassigned=[stringprep.in_table_a1])
def sasl_mech(score): def sasl_mech(score: int):
sec_score = score sec_score = score
def register(mech):
def register(mech: Type[Mech]):
n = 0 n = 0
mech.score = sec_score mech.score = sec_score
if mech.use_hashes: if mech.use_hashes:
...@@ -99,9 +101,9 @@ class Mech(object): ...@@ -99,9 +101,9 @@ class Mech(object):
score = -1 score = -1
use_hashes = False use_hashes = False
channel_binding = False channel_binding = False
required_credentials = set() required_credentials: Set[str] = set()
optional_credentials = set() optional_credentials: Set[str] = set()
security = set() security: Set[str] = set()
def __init__(self, name, credentials, security_settings): def __init__(self, name, credentials, security_settings):
self.credentials = credentials self.credentials = credentials
...@@ -118,7 +120,14 @@ class Mech(object): ...@@ -118,7 +120,14 @@ class Mech(object):
return b'' return b''
def choose(mech_list, credentials, security_settings, limit=None, min_mech=None): CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]]
SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]]
def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback,
security_settings: SecurityCallback,
limit: Optional[Iterable[Type[Mech]]] = None,
min_mech: Optional[str] = None) -> Mech:
available_mechs = set(MECHANISMS.keys()) available_mechs = set(MECHANISMS.keys())
if limit is None: if limit is None:
limit = set(mech_list) limit = set(mech_list)
...@@ -130,6 +139,9 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None) ...@@ -130,6 +139,9 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None)
mech_list = mech_list.intersection(limit) mech_list = mech_list.intersection(limit)
available_mechs = available_mechs.intersection(mech_list) available_mechs = available_mechs.intersection(mech_list)
if min_mech is None:
best_score = -1
else:
best_score = MECH_SEC_SCORES.get(min_mech, -1) best_score = MECH_SEC_SCORES.get(min_mech, -1)
best_mech = None best_mech = None
for name in available_mechs: for name in available_mechs:
......
...@@ -11,6 +11,9 @@ import hmac ...@@ -11,6 +11,9 @@ import hmac
import random import random
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
from typing import List, Dict, Optional
bytes_ = bytes
from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes
from slixmpp.util.sasl.client import sasl_mech, Mech, \ from slixmpp.util.sasl.client import sasl_mech, Mech, \
...@@ -63,7 +66,7 @@ class PLAIN(Mech): ...@@ -63,7 +66,7 @@ class PLAIN(Mech):
if not self.security_settings['encrypted_plain']: if not self.security_settings['encrypted_plain']:
raise SASLCancelled('PLAIN with encryption') raise SASLCancelled('PLAIN with encryption')
def process(self, challenge=b''): def process(self, challenge: bytes_ = b'') -> bytes_:
authzid = self.credentials['authzid'] authzid = self.credentials['authzid']
authcid = self.credentials['username'] authcid = self.credentials['username']
password = self.credentials['password'] password = self.credentials['password']
...@@ -148,7 +151,7 @@ class CRAM(Mech): ...@@ -148,7 +151,7 @@ class CRAM(Mech):
required_credentials = {'username', 'password'} required_credentials = {'username', 'password'}
security = {'encrypted', 'unencrypted_cram'} security = {'encrypted', 'unencrypted_cram'}
def setup(self, name): def setup(self, name: str):
self.hash_name = name[5:] self.hash_name = name[5:]
self.hash = hash(self.hash_name) self.hash = hash(self.hash_name)
if self.hash is None: if self.hash is None:
...@@ -157,14 +160,14 @@ class CRAM(Mech): ...@@ -157,14 +160,14 @@ class CRAM(Mech):
if not self.security_settings['unencrypted_cram']: if not self.security_settings['unencrypted_cram']:
raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name) raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name)
def process(self, challenge=b''): def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge: if not challenge:
return None return None
username = self.credentials['username'] username = self.credentials['username']
password = self.credentials['password'] password = self.credentials['password']
mac = hmac.HMAC(key=password, digestmod=self.hash) mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore
mac.update(challenge) mac.update(challenge)
return username + b' ' + bytes(mac.hexdigest()) return username + b' ' + bytes(mac.hexdigest())
...@@ -201,43 +204,42 @@ class SCRAM(Mech): ...@@ -201,43 +204,42 @@ class SCRAM(Mech):
def HMAC(self, key, msg): def HMAC(self, key, msg):
return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest() return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest()
def Hi(self, text, salt, iterations): def Hi(self, text: str, salt: bytes_, iterations: int):
text = bytes(text) text_enc = bytes(text)
ui1 = self.HMAC(text, salt + b'\0\0\0\01') ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01')
ui = ui1 ui = ui1
for i in range(iterations - 1): for i in range(iterations - 1):
ui1 = self.HMAC(text, ui1) ui1 = self.HMAC(text_enc, ui1)
ui = XOR(ui, ui1) ui = XOR(ui, ui1)
return ui return ui
def H(self, text): def H(self, text: str) -> bytes_:
return self.hash(text).digest() return self.hash(text).digest()
def saslname(self, value): def saslname(self, value_b: bytes_) -> bytes_:
value = value.decode("utf-8") value = value_b.decode("utf-8")
escaped = [] escaped: List[str] = []
for char in value: for char in value:
if char == ',': if char == ',':
escaped += b'=2C' escaped.append('=2C')
elif char == '=': elif char == '=':
escaped += b'=3D' escaped.append('=3D')
else: else:
escaped += char escaped.append(char)
return "".join(escaped).encode("utf-8") return "".join(escaped).encode("utf-8")
def parse(self, challenge): def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]:
items = {} items = {}
for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]: for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]:
items[key] = value items[key] = value
return items return items
def process(self, challenge=b''): def process(self, challenge: bytes_ = b''):
steps = [self.process_1, self.process_2, self.process_3] steps = [self.process_1, self.process_2, self.process_3]
return steps[self.step](challenge) return steps[self.step](challenge)
def process_1(self, challenge): def process_1(self, challenge: bytes_) -> bytes_:
self.step = 1 self.step = 1
data = {}
self.cnonce = bytes(('%s' % random.random())[2:]) self.cnonce = bytes(('%s' % random.random())[2:])
...@@ -263,7 +265,7 @@ class SCRAM(Mech): ...@@ -263,7 +265,7 @@ class SCRAM(Mech):
return self.client_first_message return self.client_first_message
def process_2(self, challenge): def process_2(self, challenge: bytes_) -> bytes_:
self.step = 2 self.step = 2
data = self.parse(challenge) data = self.parse(challenge)
...@@ -304,7 +306,7 @@ class SCRAM(Mech): ...@@ -304,7 +306,7 @@ class SCRAM(Mech):
return client_final_message return client_final_message
def process_3(self, challenge): def process_3(self, challenge: bytes_) -> bytes_:
data = self.parse(challenge) data = self.parse(challenge)
verifier = data.get(b'v', None) verifier = data.get(b'v', None)
error = data.get(b'e', 'Unknown error') error = data.get(b'e', 'Unknown error')
...@@ -345,17 +347,16 @@ class DIGEST(Mech): ...@@ -345,17 +347,16 @@ class DIGEST(Mech):
self.cnonce = b'' self.cnonce = b''
self.nonce_count = 1 self.nonce_count = 1
def parse(self, challenge=b''): def parse(self, challenge: bytes_ = b''):
data = {} data: Dict[str, bytes_] = {}
var_name = b'' var_name = b''
var_value = b'' var_value = b''
# States: var, new_var, end, quote, escaped_quote # States: var, new_var, end, quote, escaped_quote
state = 'var' state = 'var'
for char_int in challenge:
for char in challenge: char = bytes_([char_int])
char = bytes([char])
if state == 'var': if state == 'var':
if char.isspace(): if char.isspace():
...@@ -401,14 +402,14 @@ class DIGEST(Mech): ...@@ -401,14 +402,14 @@ class DIGEST(Mech):
state = 'var' state = 'var'
return data return data
def MAC(self, key, seq, msg): def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_:
mac = hmac.HMAC(key=key, digestmod=self.hash) mac = hmac.HMAC(key=key, digestmod=self.hash)