Commit db48c8f4 authored by mathieui's avatar mathieui

xmlstream: add more types

parent c07476e7
......@@ -9,17 +9,24 @@
# :license: MIT, see LICENSE for more details
from typing import (
Any,
Dict,
Awaitable,
Generator,
Coroutine,
Callable,
Iterable,
Iterator,
List,
Optional,
Set,
Union,
Tuple,
TypeVar,
NoReturn,
Type,
cast,
)
import asyncio
import functools
import logging
import socket as Socket
......@@ -27,30 +34,66 @@ import ssl
import weakref
import uuid
import asyncio
from asyncio import iscoroutinefunction, wait, Future
from contextlib import contextmanager
import xml.etree.ElementTree as ET
from asyncio import (
AbstractEventLoop,
BaseTransport,
Future,
Task,
TimerHandle,
Transport,
iscoroutinefunction,
wait,
)
from slixmpp.xmlstream import tostring
from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
from slixmpp.xmlstream.resolver import resolve, default_resolver
from slixmpp.xmlstream.handler.base import BaseHandler
T = TypeVar('T')
#: The time in seconds to wait before timing out waiting for response stanzas.
RESPONSE_TIMEOUT = 30
log = logging.getLogger(__name__)
class ContinueQueue(Exception):
"""
Exception raised in the send queue to "continue" from within an inner loop
"""
class NotConnectedError(Exception):
"""
Raised when we try to send something over the wire but we are not
connected.
"""
_T = TypeVar('_T', str, ElementBase, StanzaBase)
SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]
Filter = Union[
SyncFilter,
AsyncFilter,
]
_FiltersDict = Dict[str, List[Filter]]
Handler = Callable[[Any], Union[
Any,
Coroutine[Any, Any, Any]
]]
class XMLStream(asyncio.BaseProtocol):
"""
An XML stream connection manager and event dispatcher.
......@@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0.
"""
def __init__(self, host='', port=0):
# The asyncio.Transport object provided by the connection_made()
# callback when we are connected
self.transport = None
transport: Optional[Transport]
# The socket that is used internally by the transport object
self.socket = None
# The socket that is used internally by the transport object
socket: Optional[ssl.SSLSocket]
# The backoff of the connect routine (increases exponentially
# after each failure)
_connect_loop_wait: float
parser: Optional[ET.XMLPullParser]
xml_depth: int
xml_root: Optional[ET.Element]
force_starttls: Optional[bool]
disable_starttls: Optional[bool]
waiting_queue: asyncio.Queue[Tuple[Union[StanzaBase, str], bool]]
# A dict of {name: handle}
scheduled_events: Dict[str, TimerHandle]
ssl_context: ssl.SSLContext
# The event to trigger when the create_connection() succeeds. It can
# be "connected" or "tls_success" depending on the step we are at.
event_when_connected: str
#: The list of accepted ciphers, in OpenSSL Format.
#: It might be useful to override it for improved security
#: over the python defaults.
ciphers: Optional[str]
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
#:
#: .. note::
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
ca_certs: Optional[str]
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
#: be a corresponding `:attr:keyfile` value.
certfile: Optional[str]
#: Path to a file containing the private key for the selected
#: client certificate to use for authenticating via SASL EXTERNAL.
keyfile: Optional[str]
# The asyncio event loop
_loop: Optional[AbstractEventLoop]
#: The default port to return when querying DNS records.
default_port: int
#: The domain to try when querying DNS records.
default_domain: str
#: The expected name of the server, for validation.
_expected_server_name: str
_service_name: str
#: The desired, or actual, address of the connected server.
address: Tuple[str, int]
#: Enable connecting to the server directly over SSL, in
#: particular when the service provides two ports: one for
#: non-SSL traffic and another for SSL traffic.
use_ssl: bool
#: If set to ``True``, attempt to use IPv6.
use_ipv6: bool
# The backoff of the connect routine (increases exponentially
# after each failure)
#: If set to ``True``, allow using the ``dnspython`` DNS library
#: if available. If set to ``False``, the builtin DNS resolver
#: will be used, even if ``dnspython`` is installed.
use_aiodns: bool
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
use_cdata: bool
#: The default namespace of the stream content, not of the
#: stream wrapper it
default_ns: str
default_lang: Optional[str]
peer_default_lang: Optional[str]
#: The namespace of the enveloping stream element.
stream_ns: str
#: The default opening tag for the stream element.
stream_header: str
#: The default closing tag for the stream element.
stream_footer: str
#: If ``True``, periodically send a whitespace character over the
#: wire to keep the connection alive. Mainly useful for connections
#: traversing NAT.
whitespace_keepalive: bool
#: The default interval between keepalive signals when
#: :attr:`whitespace_keepalive` is enabled.
whitespace_keepalive_interval: int
#: Flag for controlling if the session can be considered ended
#: if the connection is terminated.
end_session_on_disconnect: bool
#: A mapping of XML namespaces to well-known prefixes.
namespace_map: dict
__root_stanza: List[Type[StanzaBase]]
__handlers: List[BaseHandler]
__event_handlers: Dict[str, List[Tuple[Handler, bool]]]
__filters: _FiltersDict
# Current connection attempt (Future)
_current_connection_attempt: Optional[Future[None]]
#: A list of DNS results that have not yet been tried.
_dns_answers: Optional[Iterator[Tuple[str, str, int]]]
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
dns_service: Optional[str]
#: The reason why we are disconnecting from the server
disconnect_reason: Optional[str]
#: An asyncio Future being done when the stream is disconnected.
disconnected: Future[bool]
# If the session has been started or not
_session_started: bool
# If we want to bypass the send() check (e.g. unit tests)
_always_send_everything: bool
_run_out_filters: Optional[Future]
__slow_tasks: List[Task]
__queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]
def __init__(self, host: str = '', port: int = 0):
self.transport = None
self.socket = None
self._connect_loop_wait = 0
self.parser = None
......@@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
# The event to trigger when the create_connection() succeeds. It can
# be "connected" or "tls_success" depending on the step we are at.
self.event_when_connected = "connected"
#: The list of accepted ciphers, in OpenSSL Format.
#: It might be useful to override it for improved security
#: over the python defaults.
self.ciphers = None
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
#:
#: .. note::
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
self.ca_certs = None
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
#: be a corresponding `:attr:keyfile` value.
self.certfile = None
#: Path to a file containing the private key for the selected
#: client certificate to use for authenticating via SASL EXTERNAL.
self.keyfile = None
self._der_cert = None
# The asyncio event loop
self._loop = None
#: The default port to return when querying DNS records.
self.default_port = int(port)
#: The domain to try when querying DNS records.
self.default_domain = ''
#: The expected name of the server, for validation.
self._expected_server_name = ''
self._service_name = ''
#: The desired, or actual, address of the connected server.
self.address = (host, int(port))
#: Enable connecting to the server directly over SSL, in
#: particular when the service provides two ports: one for
#: non-SSL traffic and another for SSL traffic.
self.use_ssl = False
#: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True
#: If set to ``True``, allow using the ``dnspython`` DNS library
#: if available. If set to ``False``, the builtin DNS resolver
#: will be used, even if ``dnspython`` is installed.
self.use_aiodns = True
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
self.use_cdata = False
#: The default namespace of the stream content, not of the
#: stream wrapper itself.
self.default_ns = ''
self.default_lang = None
self.peer_default_lang = None
#: The namespace of the enveloping stream element.
self.stream_ns = ''
#: The default opening tag for the stream element.
self.stream_header = "<stream>"
#: The default closing tag for the stream element.
self.stream_footer = "</stream>"
#: If ``True``, periodically send a whitespace character over the
#: wire to keep the connection alive. Mainly useful for connections
#: traversing NAT.
self.whitespace_keepalive = True
#: The default interval between keepalive signals when
#: :attr:`whitespace_keepalive` is enabled.
self.whitespace_keepalive_interval = 300
#: Flag for controlling if the session can be considered ended
#: if the connection is terminated.
self.end_session_on_disconnect = True
#: A mapping of XML namespaces to well-known prefixes.
self.namespace_map = {StanzaBase.xml_ns: 'xml'}
self.__root_stanza = []
self.__handlers = []
self.__event_handlers = {}
self.__filters = {'in': [], 'out': [], 'out_sync': []}
self.__filters = {
'in': [], 'out': [], 'out_sync': []
}
# Current connection attempt (Future)
self._current_connection_attempt = None
#: A list of DNS results that have not yet been tried.
self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
self._dns_answers = None
self.dns_service = None
#: The reason why we are disconnecting from the server
self.disconnect_reason = None
#: An asyncio Future being done when the stream is disconnected.
self.disconnected: Future = Future()
# If the session has been started or not
self.disconnected = Future()
self._session_started = False
# If we want to bypass the send() check (e.g. unit tests)
self._always_send_everything = False
self.add_event_handler('disconnected', self._remove_schedules)
......@@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('session_start', self._set_session_start)
self.add_event_handler('session_resumed', self._set_session_start)
self._run_out_filters: Optional[Future] = None
self.__slow_tasks: List[Future] = []
self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = []
self._run_out_filters = None
self.__slow_tasks = []
self.__queued_stanzas = []
@property
def loop(self):
def loop(self) -> AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_event_loop()
return self._loop
@loop.setter
def loop(self, value):
def loop(self, value: AbstractEventLoop) -> None:
self._loop = value
def new_id(self):
def new_id(self) -> str:
"""Generate and return a new stream ID in hexadecimal form.
Many stanzas, handlers, or matchers may require unique
......@@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return uuid.uuid4().hex
def _set_session_start(self, event):
def _set_session_start(self, event: Any) -> None:
"""
On session start, queue all pending stanzas to be sent.
"""
......@@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol):
self.waiting_queue.put_nowait(stanza)
self.__queued_stanzas = []
def _set_disconnected(self, event):
def _set_disconnected(self, event: Any) -> None:
self._session_started = False
def _set_disconnected_future(self):
def _set_disconnected_future(self) -> None:
"""Set the self.disconnected future on disconnect"""
if not self.disconnected.done():
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
def connect(self, host='', port=0, use_ssl=False,
force_starttls=True, disable_starttls=False):
def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
"""Create a new socket and connect to the server.
:param host: The name of the desired server for the connection.
......@@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
async def _connect_routine(self):
async def _connect_routine(self) -> None:
self.event_when_connected = "connected"
if self._connect_loop_wait > 0:
......@@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
# and try (host, port) as a last resort
self._dns_answers = None
ssl_context: Optional[ssl.SSLContext]
if self.use_ssl:
ssl_context = self.get_ssl_context()
else:
......@@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
def process(self, *, forever=True, timeout=None):
def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the
socket(s), calling various registered callbacks, calling expired
timers, handling signal events, etc). If timeout is None, this
......@@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.loop.run_until_complete(self.disconnected)
else:
tasks = [asyncio.sleep(timeout, loop=self.loop)]
tasks: List[Future[bool]] = [asyncio.sleep(timeout, loop=self.loop)]
if not forever:
tasks.append(self.disconnected)
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
def init_parser(self):
def init_parser(self) -> None:
"""init the XML parser. The parser must always be reset for each new
connexion
"""
......@@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol):
self.xml_root = None
self.parser = ET.XMLPullParser(("start", "end"))
def connection_made(self, transport):
def connection_made(self, transport: BaseTransport) -> None:
"""Called when the TCP connection has been established with the server
"""
self.event(self.event_when_connected)
self.transport = transport
self.transport = cast(Transport, transport)
if self.transport is None:
raise ValueError("Transport cannot be none")
self.socket = self.transport.get_extra_info(
"ssl_object",
default=self.transport.get_extra_info("socket")
......@@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol):
self.send_raw(self.stream_header)
self._dns_answers = None
def data_received(self, data):
def data_received(self, data: bytes) -> None:
"""Called when incoming data is received on the socket.
We feed that data to the parser and the see if this produced any XML
......@@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol):
self.send(error)
self.disconnect()
def is_connecting(self):
def is_connecting(self) -> bool:
return self._current_connection_attempt is not None
def is_connected(self):
def is_connected(self) -> bool:
return self.transport is not None
def eof_received(self):
def eof_received(self) -> None:
"""When the TCP connection is properly closed by the remote end
"""
self.event("eof_received")
def connection_lost(self, exception):
def connection_lost(self, exception: Optional[BaseException]) -> None:
"""On any kind of disconnection, initiated by us or not. This signals the
closure of the TCP connection
"""
......@@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol):
self._reset_sendq()
self.event('session_end')
self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
self.event("disconnected", self.disconnect_reason or exception)
def cancel_connection_attempt(self):
def cancel_connection_attempt(self) -> None:
"""
Immediately cancel the current create_connection() Future.
This is useful when a client using slixmpp tries to connect
......@@ -506,7 +626,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt.cancel()
self._current_connection_attempt = None
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future[None]:
"""Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has
passed without a response from the server, or when the server
......@@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol):
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
# schedule call below. It would fortunately be converted to `1` later
# down the call chain. Praise the implicit casts lord.
if wait == True:
if wait is True:
wait = 2.0
if self.transport:
......@@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self._set_disconnected_future()
self.event("disconnected", reason)
future = Future()
future: Future[None] = Future()
future.set_result(None)
return future
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
"""Wait until the send queue is empty before disconnecting"""
try:
await asyncio.wait_for(
......@@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = reason
await self._end_stream_wait(wait)
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
"""
Run abort() if we do not received the disconnected event
after a waiting time.
......@@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol):
# that means the disconnect has already been handled
pass
def abort(self):
def abort(self) -> None:
"""
Forcibly close the connection
"""
......@@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol):
self.transport.abort()
self.event("killed")
def reconnect(self, wait=2.0, reason="Reconnecting"):
def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
"""Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect()
"""
log.debug("reconnecting...")
async def handler(event):
async def handler(event: Any) -> None:
# We yield here to allow synchronous handlers to work first
await asyncio.sleep(0, loop=self.loop)
self.connect()
self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason)
def configure_socket(self):
def configure_socket(self) -> None:
"""Set timeout and other options for self.socket.
Meant to be overridden.
"""
pass
def configure_dns(self, resolver, domain=None, port=None):
def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
"""
Configure and set options for a :class:`~dns.resolver.Resolver`
instance, and other DNS related tasks. For example, you
......@@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
def get_ssl_context(self):
def get_ssl_context(self) -> ssl.SSLContext:
"""
Get SSL context.
"""
......@@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context
async def start_tls(self):
async def start_tls(self) -> bool:
"""Perform handshakes for TLS.
If the handshake is successful, the XML stream will need
to be restarted.
"""
if self.transport is None:
raise ValueError("Transport should not be None")