xmlstream.py 53.4 KB
Newer Older
Nathan Fritz's avatar
Nathan Fritz committed
1

2 3 4 5 6 7 8 9
# slixmpp.xmlstream.xmlstream
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# This module provides the tools for creating and
# interacting with generic XML streams, along with
# the necessary eventing infrastructure.
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
10 11
from typing import (
    Any,
mathieui's avatar
mathieui committed
12 13 14
    Dict,
    Awaitable,
    Generator,
15
    Coroutine,
16
    Callable,
mathieui's avatar
mathieui committed
17
    Iterator,
18 19 20 21
    List,
    Optional,
    Set,
    Union,
mathieui's avatar
mathieui committed
22
    Tuple,
mathieui's avatar
mathieui committed
23 24 25 26
    TypeVar,
    NoReturn,
    Type,
    cast,
27
)
28

mathieui's avatar
mathieui committed
29
import asyncio
louiz’'s avatar
louiz’ committed
30
import functools
Nathan Fritz's avatar
Nathan Fritz committed
31
import logging
32
import socket as Socket
33
import ssl
34
import weakref
35
import uuid
Nathan Fritz's avatar
Nathan Fritz committed
36

37
from contextlib import contextmanager
38
import xml.etree.ElementTree as ET
mathieui's avatar
mathieui committed
39 40 41 42 43 44 45 46 47 48
from asyncio import (
    AbstractEventLoop,
    BaseTransport,
    Future,
    Task,
    TimerHandle,
    Transport,
    iscoroutinefunction,
    wait,
)
49

mathieui's avatar
mathieui committed
50 51
from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring
52
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
louiz’'s avatar
louiz’ committed
53
from slixmpp.xmlstream.resolver import resolve, default_resolver
mathieui's avatar
mathieui committed
54 55 56
from slixmpp.xmlstream.handler.base import BaseHandler

T = TypeVar('T')
57

58
#: The time in seconds to wait before timing out waiting for response stanzas.
59
RESPONSE_TIMEOUT = 30
60

61
log = logging.getLogger(__name__)
mathieui's avatar
mathieui committed
62 63


mathieui's avatar
mathieui committed
64 65 66 67
class ContinueQueue(Exception):
    """
    Control-flow signal used by the send queue: raising it from within a
    nested loop behaves like a ``continue`` of the outer loop.
    """
68

mathieui's avatar
mathieui committed
69

louiz’'s avatar
louiz’ committed
70 71 72 73 74
class NotConnectedError(Exception):
    """
    Signals an attempt to put data on the wire while no connection
    to the server exists.
    """
Nathan Fritz's avatar
Nathan Fritz committed
75

mathieui's avatar
mathieui committed
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96

_T = TypeVar('_T', str, ElementBase, StanzaBase)


# Stanza filters receive a stanza and return it (possibly modified) or
# None, either directly or through an awaitable.
SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]


Filter = Union[SyncFilter, AsyncFilter]

# Registered filters grouped by direction name ('in', 'out' and
# 'out_sync' are the keys set up by XMLStream.__init__).
_FiltersDict = Dict[str, List[Filter]]

# Event handler: called with the event payload; may be sync or return a
# coroutine.
Handler = Callable[[Any], Union[Any, Coroutine[Any, Any, Any]]]


97
class XMLStream(asyncio.BaseProtocol):
98 99 100
    """
    An XML stream connection manager and event dispatcher.

101 102 103
    The XMLStream class abstracts away the issues of establishing a
    connection with a server and sending and receiving XML "stanzas".
    A stanza is a complete XML element that is a direct child of a root
104 105 106 107 108
    document element. Two streams are used, one for each communication
    direction, over the same socket. Once the connection is closed, both
    streams should be complete and valid XML documents.

    Three types of events are provided to manage the stream:
109 110 111 112
        :Stream: Triggered based on received stanzas, similar in concept
                 to events in a SAX XML parser.
        :Custom: Triggered manually.
        :Scheduled: Triggered based on time delays.
113

114 115 116 117
    Typically, stanzas are first processed by a stream event handler which
    will then trigger custom events to continue further processing,
    especially since custom event handlers may run in individual threads.

118 119 120 121
    :param string host: The name of the target server.
    :param int port: The port to use for the connection. Defaults to 0.
122 123
    """

mathieui's avatar
mathieui committed
124
    transport: Optional[Transport]
louiz’'s avatar
louiz’ committed
125

mathieui's avatar
mathieui committed
126 127 128 129 130 131 132 133 134 135 136 137 138 139
    # The socket that is used internally by the transport object
    socket: Optional[ssl.SSLSocket]

    # The backoff of the connect routine (increases exponentially
    # after each failure)
    _connect_loop_wait: float

    parser: Optional[ET.XMLPullParser]
    xml_depth: int
    xml_root: Optional[ET.Element]

    force_starttls: Optional[bool]
    disable_starttls: Optional[bool]

mathieui's avatar
mathieui committed
140
    waiting_queue: asyncio.Queue
mathieui's avatar
mathieui committed
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197

    # A dict of {name: handle}
    scheduled_events: Dict[str, TimerHandle]

    ssl_context: ssl.SSLContext

    # The event to trigger when the create_connection() succeeds. It can
    # be "connected" or "tls_success" depending on the step we are at.
    event_when_connected: str

    #: The list of accepted ciphers, in OpenSSL Format.
    #: It might be useful to override it for improved security
    #: over the python defaults.
    ciphers: Optional[str]

    #: Path to a file containing certificates for verifying the
    #: server SSL certificate. A non-``None`` value will trigger
    #: certificate checking.
    #:
    #: .. note::
    #:
    #:     On Mac OS X, certificates in the system keyring will
    #:     be consulted, even if they are not in the provided file.
    ca_certs: Optional[str]

    #: Path to a file containing a client certificate to use for
    #: authenticating via SASL EXTERNAL. If set, there must also
    #: be a corresponding `:attr:keyfile` value.
    certfile: Optional[str]

    #: Path to a file containing the private key for the selected
    #: client certificate to use for authenticating via SASL EXTERNAL.
    keyfile: Optional[str]

    # The asyncio event loop
    _loop: Optional[AbstractEventLoop]

    #: The default port to return when querying DNS records.
    default_port: int

    #: The domain to try when querying DNS records.
    default_domain: str

    #: The expected name of the server, for validation.
    _expected_server_name: str
    _service_name: str

    #: The desired, or actual, address of the connected server.
    address: Tuple[str, int]

    #: Enable connecting to the server directly over SSL, in
    #: particular when the service provides two ports: one for
    #: non-SSL traffic and another for SSL traffic.
    use_ssl: bool

    #: If set to ``True``, attempt to use IPv6.
    use_ipv6: bool
louiz’'s avatar
louiz’ committed
198

mathieui's avatar
mathieui committed
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
    #: If set to ``True``, allow using the ``aiodns`` DNS library
    #: if available. If set to ``False``, the builtin DNS resolver
    #: will be used, even if ``aiodns`` is installed.
    use_aiodns: bool

    #: Use CDATA for escaping instead of XML entities. Defaults
    #: to ``False``.
    use_cdata: bool

    #: The default namespace of the stream content, not of the
    #: stream wrapper itself.
    default_ns: str

    default_lang: Optional[str]
    peer_default_lang: Optional[str]

    #: The namespace of the enveloping stream element.
    stream_ns: str

    #: The default opening tag for the stream element.
    stream_header: str

    #: The default closing tag for the stream element.
    stream_footer: str

    #: If ``True``, periodically send a whitespace character over the
    #: wire to keep the connection alive. Mainly useful for connections
    #: traversing NAT.
    whitespace_keepalive: bool

    #: The default interval between keepalive signals when
    #: :attr:`whitespace_keepalive` is enabled.
    whitespace_keepalive_interval: int

    #: Flag for controlling if the session can be considered ended
    #: if the connection is terminated.
    end_session_on_disconnect: bool

    #: A mapping of XML namespaces to well-known prefixes.
    namespace_map: dict

    __root_stanza: List[Type[StanzaBase]]
    __handlers: List[BaseHandler]
    __event_handlers: Dict[str, List[Tuple[Handler, bool]]]
    __filters: _FiltersDict

    # Current connection attempt (Future)
mathieui's avatar
mathieui committed
246
    _current_connection_attempt: Optional[Future]
mathieui's avatar
mathieui committed
247 248 249 250 251 252 253 254 255 256 257 258 259

    #: A list of DNS results that have not yet been tried.
    _dns_answers: Optional[Iterator[Tuple[str, str, int]]]

    #: The service name to check with DNS SRV records. For
    #: example, setting this to ``'xmpp-client'`` would query the
    #: ``_xmpp-client._tcp`` service.
    dns_service: Optional[str]

    #: The reason why we are disconnecting from the server
    disconnect_reason: Optional[str]

    #: An asyncio Future being done when the stream is disconnected.
mathieui's avatar
mathieui committed
260
    disconnected: Future
mathieui's avatar
mathieui committed
261 262 263 264 265 266 267 268 269 270 271 272 273

    # If the session has been started or not
    _session_started: bool
    # If we want to bypass the send() check (e.g. unit tests)
    _always_send_everything: bool

    _run_out_filters: Optional[Future]
    __slow_tasks: List[Task]
    __queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]

    def __init__(self, host: str = '', port: int = 0):
        """Set up an unconnected stream; call :meth:`connect` to open it.

        :param host: Server to connect to; stored in :attr:`address`.
        :param port: Port to connect to; stored in :attr:`address` and
                     used as :attr:`default_port`.
        """
        # Per-connection objects, filled in by connection_made() and
        # cleared again by connection_lost().
        self.transport = None
        self.socket = None
        # Reconnection back-off in seconds (grown by _connect_routine()).
        self._connect_loop_wait = 0

        # XML parsing state, reset by init_parser() for each connection.
        self.parser = None
        self.xml_depth = 0
        self.xml_root = None

        # connect() overwrites these unless it is given None.
        self.force_starttls = None
        self.disable_starttls = None

        # Stanzas to send; _set_session_start() moves queued stanzas here.
        self.waiting_queue = asyncio.Queue()

        # A dict of {name: handle}
        self.scheduled_events = {}

        # Hostname and certificate checks are disabled on this default
        # context; per the class-level ``ca_certs`` notes, providing
        # ca_certs is what is meant to turn certificate checking on.
        self.ssl_context = ssl.create_default_context()
        self.ssl_context.check_hostname = False
        self.ssl_context.verify_mode = ssl.CERT_NONE

        # "connected" for the initial connection; may become
        # "tls_success" for the TLS step.
        self.event_when_connected = "connected"

        self.ciphers = None
        self.ca_certs = None
        # certfile is declared on the class but was previously never
        # initialised here, so reading it before assignment raised
        # AttributeError (unlike its sibling keyfile).
        self.certfile = None
        self.keyfile = None

        # Resolved lazily by the ``loop`` property.
        self._loop = None

        self.default_port = int(port)
        self.default_domain = ''

        self._expected_server_name = ''
        self._service_name = ''

        self.address = (host, int(port))

        self.use_ssl = False
        self.use_ipv6 = True
        self.use_aiodns = True
        self.use_cdata = False

        self.default_ns = ''

        self.default_lang = None
        self.peer_default_lang = None

        self.stream_ns = ''
        self.stream_header = "<stream>"
        self.stream_footer = "</stream>"

        self.whitespace_keepalive = True
        self.whitespace_keepalive_interval = 300

        self.end_session_on_disconnect = True
        self.namespace_map = {StanzaBase.xml_ns: 'xml'}

        self.__root_stanza = []
        self.__handlers = []
        self.__event_handlers = {}
        self.__filters = {
            'in': [], 'out': [], 'out_sync': []
        }

        self._current_connection_attempt = None

        self._dns_answers = None
        self.dns_service = None

        self.disconnect_reason = None
        self.disconnected = Future()
        self._session_started = False
        self._always_send_everything = False

        # Internal bookkeeping handlers; must be registered after the
        # handler registry above exists.
        self.add_event_handler('disconnected', self._remove_schedules)
        self.add_event_handler('disconnected', self._set_disconnected)
        self.add_event_handler('session_start', self._start_keepalive)
        self.add_event_handler('session_start', self._set_session_start)
        self.add_event_handler('session_resumed', self._set_session_start)

        self._run_out_filters = None
        self.__slow_tasks = []
        self.__queued_stanzas = []
357

358
    @property
    def loop(self) -> AbstractEventLoop:
        """The event loop this stream runs on.

        Defaults, on first access, to :func:`asyncio.get_event_loop`.
        """
        current = self._loop
        if current is None:
            current = self._loop = asyncio.get_event_loop()
        return current

    @loop.setter
    def loop(self, value: AbstractEventLoop) -> None:
        """Pin the stream to an explicit event loop."""
        self._loop = value

mathieui's avatar
mathieui committed
368
    def new_id(self) -> str:
        """Return a fresh random identifier as a 32-character hex string.

        Stanzas, handlers and matchers that need an ID that will not clash
        with others in this stream should obtain it here.
        """
        identifier = uuid.uuid4()
        return identifier.hex
376

mathieui's avatar
mathieui committed
377
    def _set_session_start(self, event: Any) -> None:
        """Mark the session as started and release stanzas held back.

        Everything queued before the session began is pushed, in order,
        onto :attr:`waiting_queue`; the hold list is then emptied.
        """
        self._session_started = True
        enqueue = self.waiting_queue.put_nowait
        for held in self.__queued_stanzas:
            enqueue(held)
        self.__queued_stanzas = []
385

mathieui's avatar
mathieui committed
386
    def _set_disconnected(self, event: Any) -> None:
        """``disconnected`` handler: the session is no longer active."""
        self._session_started = False

mathieui's avatar
mathieui committed
389
    def _set_disconnected_future(self) -> None:
        """Resolve :attr:`disconnected` and arm a fresh one for next time.

        The replacement future is created on :attr:`loop`, the loop that
        ``process()`` runs it on. Previously ``asyncio.Future()`` bound it
        to the default loop even when a custom loop had been assigned
        through the ``loop`` setter.
        """
        if not self.disconnected.done():
            self.disconnected.set_result(True)
        self.disconnected = self.loop.create_future()

mathieui's avatar
mathieui committed
395 396
    def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
                force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
        """Create a new socket and connect to the server.

        The connection itself is made asynchronously by a task running
        ``_connect_routine()`` on :attr:`loop`; this method fires
        ``connecting`` and returns immediately.

        Passing ``None`` for *use_ssl*, *force_starttls* or
        *disable_starttls* keeps the previously configured value.

        :param host: The name of the desired server for the connection.
        :param port: Port to connect to on the server. :attr:`address` is
                     only replaced when both *host* and *port* are given.
        :param use_ssl: Flag indicating if SSL should be used by connecting
                        directly to a port using SSL.  If it is False, the
                        connection will be upgraded to SSL/TLS later, using
                        STARTTLS.  Only use this value for old servers that
                        have a specific port for SSL/TLS.
        :param force_starttls: If True, the connection will be aborted if
                               the server does not initiate a STARTTLS
                               negotiation.  If None, the connection will be
                               upgraded to TLS only if the server initiates
                               the STARTTLS negotiation, otherwise it will
                               connect in clear.  If False it will never
                               upgrade to TLS, even if the server provides
                               it.  Use this for example if you're on
                               localhost.
        :param disable_starttls: Stored in :attr:`disable_starttls`.
        """
        if self._run_out_filters is None or self._run_out_filters.done():
            self._run_out_filters = asyncio.ensure_future(
                self.run_filters(),
                loop=self.loop,
            )

        self.disconnect_reason = None
        self.cancel_connection_attempt()
        self._connect_loop_wait = 0
        if host and port:
            self.address = (host, int(port))

        # default_domain drives the DNS lookup and the TLS server_hostname
        # in _connect_routine(); IP literals (IPv4 *and* IPv6) must not be
        # used for either.
        if not self._is_ip_literal(self.address[0]):
            self.default_domain = self.address[0]

        # Respect previous TLS usage.
        if use_ssl is not None:
            self.use_ssl = use_ssl
        if force_starttls is not None:
            self.force_starttls = force_starttls
        if disable_starttls is not None:
            self.disable_starttls = disable_starttls

        self.event("connecting")
        self._current_connection_attempt = asyncio.ensure_future(
            self._connect_routine(),
            loop=self.loop,
        )

    @staticmethod
    def _is_ip_literal(host: str) -> bool:
        """Return True if *host* is an IPv4 (``inet_aton`` syntax) or IPv6
        address literal rather than a hostname."""
        try:
            Socket.inet_aton(host)
        except OSError:
            pass
        else:
            return True
        try:
            Socket.inet_pton(Socket.AF_INET6, host)
        except OSError:
            return False
        return True
446

mathieui's avatar
mathieui committed
447
    async def _connect_routine(self) -> None:
        """Perform one connection attempt.

        Steps: optional back-off delay, pick the next DNS answer (if any),
        then ``create_connection()`` on :attr:`loop`. On an ``OSError``
        the back-off grows to ``wait * 2 + 1`` seconds and a new attempt
        is scheduled as a task. On a DNS resolution failure
        (``socket.gaierror``) no retry is scheduled.
        """
        self.event_when_connected = "connected"

        if self._connect_loop_wait > 0:
            self.event('reconnect_delay', self._connect_loop_wait)
            # asyncio.sleep() lost its ``loop`` argument in Python 3.10
            # (TypeError there). This coroutine is scheduled on self.loop,
            # so the running loop is the right one.
            await asyncio.sleep(self._connect_loop_wait)

        record = await self._pick_dns_answer(self.default_domain)
        if record is not None:
            host, address, dns_port = record
            port = dns_port if dns_port else self.address[1]
            self.address = (address, port)
            self._service_name = host
        else:
            # No DNS records left, stop iterating
            # and try (host, port) as a last resort
            self._dns_answers = None

        ssl_context: Optional[ssl.SSLContext]
        if self.use_ssl:
            ssl_context = self.get_ssl_context()
        else:
            ssl_context = None

        # The attempt may have been cancelled while we were resolving.
        if self._current_connection_attempt is None:
            return
        try:
            await self.loop.create_connection(lambda: self,
                                              self.address[0],
                                              self.address[1],
                                              ssl=ssl_context,
                                              server_hostname=self.default_domain if self.use_ssl else None)
            self._connect_loop_wait = 0
        except Socket.gaierror:
            # No retry is scheduled, so this attempt is over. Clear it
            # (before firing the event, so a handler may start a new one)
            # so that is_connecting() stops reporting True.
            self._current_connection_attempt = None
            self.event('connection_failed',
                       'No DNS record available for %s' % self.default_domain)
        except OSError as e:
            log.debug('Connection failed: %s', e)
            self.event("connection_failed", e)
            if self._current_connection_attempt is None:
                return
            self._connect_loop_wait = self._connect_loop_wait * 2 + 1
            self._current_connection_attempt = asyncio.ensure_future(
                self._connect_routine(),
                loop=self.loop,
            )
louiz’'s avatar
louiz’ committed
493

mathieui's avatar
mathieui committed
494
    def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
        """Run the event loop to process XMPP events.

        This covers receiving and sending data on the socket(s), calling
        registered callbacks and expired timers, handling signals, etc.

        :param forever: With no *timeout*, ``True`` runs the loop forever
                        and ``False`` runs it until the stream is
                        disconnected. With a *timeout*, ``False``
                        additionally lets a disconnect end processing early.
        :param timeout: If given, return after at most this many seconds.
        """
        if timeout is None:
            if forever:
                self.loop.run_forever()
            else:
                self.loop.run_until_complete(self.disconnected)
            return

        # asyncio.wait() dropped its ``loop`` argument in 3.10 and stopped
        # accepting bare coroutines in 3.11, and asyncio.sleep() lost
        # ``loop`` too. So wrap the timer in a task bound to self.loop
        # explicitly.
        sleeper = asyncio.ensure_future(asyncio.sleep(timeout), loop=self.loop)
        tasks: List[Future] = [sleeper]
        if not forever:
            tasks.append(self.disconnected)
        self.loop.run_until_complete(asyncio.wait(tasks))
        # If the stream disconnected first, don't leave the timer pending.
        sleeper.cancel()
511

mathieui's avatar
mathieui committed
512
    def init_parser(self) -> None:
        """Reset the XML parsing state; required for every new connection.

        A fresh pull parser reporting ``start`` and ``end`` events replaces
        the old one, the previous root element is forgotten and the element
        depth goes back to zero.
        """
        self.parser = ET.XMLPullParser(events=("start", "end"))
        self.xml_root = None
        self.xml_depth = 0
louiz’'s avatar
louiz’ committed
519

mathieui's avatar
mathieui committed
520
    def connection_made(self, transport: BaseTransport) -> None:
        """asyncio callback: the connection to the server is established.

        Fires :attr:`event_when_connected` first, then records the transport
        and its SSL object (or plain socket). It also clears the pending
        connection attempt, resets the parser and opens our side of the
        stream with :attr:`stream_header`.
        """
        self.event(self.event_when_connected)
        stream_transport = cast(Transport, transport)
        self.transport = stream_transport
        if stream_transport is None:
            raise ValueError("Transport cannot be none")
        plain_socket = stream_transport.get_extra_info("socket")
        self.socket = stream_transport.get_extra_info("ssl_object",
                                                      default=plain_socket)
        self._current_connection_attempt = None
        self.init_parser()
        self.send_raw(self.stream_header)
        self._dns_answers = None
louiz’'s avatar
louiz’ committed
535

mathieui's avatar
mathieui committed
536
    def data_received(self, data: bytes) -> None:
        """asyncio callback: feed incoming bytes to the XML pull parser.

        Each resulting parser event is then dispatched:
        - the opening of the root element starts the stream;
        - a complete direct child of the root is spawned as a stanza;
        - the closing of the root ends the stream and aborts the connection.

        Malformed XML is answered with a ``not-well-formed`` stream error
        followed by a disconnect.
        """
        parser = self.parser
        if parser is None:
            log.warning('Received data before the connection is established: %r',
                        data)
            return
        parser.feed(data)
        try:
            for kind, element in parser.read_events():
                if kind == 'start':
                    if self.xml_depth == 0:
                        # Opening tag of the stream's root element.
                        self.xml_root = element
                        log.debug('RECV: %s', tostring(self.xml_root,
                                                       xmlns=self.default_ns,
                                                       stream=self,
                                                       top_level=True,
                                                       open_only=True))
                        self.start_stream_handler(self.xml_root)
                    self.xml_depth += 1
                elif kind == 'end':
                    self.xml_depth -= 1
                    if self.xml_depth == 0:
                        # Closing tag of the root: the peer ended the stream.
                        log.debug("End of stream received")
                        self.disconnect_reason = "End of stream"
                        self.abort()
                        return
                    if self.xml_depth == 1:
                        # Stanzas are the direct children of the root.
                        self._spawn_event(element)
                        if self.xml_root is not None:
                            # Drop handled children so the root does not
                            # grow for the lifetime of the stream.
                            self.xml_root.clear()
        except ET.ParseError:
            log.error('Parse error: %r', data)

            # Imported here: a module-level import would be circular.
            from slixmpp.stanza.stream_error import StreamError
            stream_error = StreamError()
            stream_error['condition'] = 'not-well-formed'
            stream_error['text'] = 'Server sent: %r' % data
            self.send(stream_error)
            self.disconnect()
589

mathieui's avatar
mathieui committed
590
    def is_connecting(self) -> bool:
        """Return True while a connection attempt future is recorded."""
        attempt = self._current_connection_attempt
        return attempt is not None

mathieui's avatar
mathieui committed
593
    def is_connected(self) -> bool:
        """Return True when a transport is currently attached."""
        transport = self.transport
        return transport is not None

mathieui's avatar
mathieui committed
596
    def eof_received(self) -> None:
        """asyncio callback: the remote end closed the TCP connection cleanly.

        Only fires ``eof_received``. The implicit ``None`` return asks
        asyncio to close the transport, per the Protocol contract.
        """
        self.event("eof_received")
600

mathieui's avatar
mathieui committed
601
    def connection_lost(self, exception: Optional[BaseException]) -> None:
        """asyncio callback: the TCP connection is closed, whoever closed it.

        Per-connection objects are dropped first. Then the session is ended
        (if :attr:`end_session_on_disconnect`), :attr:`disconnected` is
        resolved, and ``disconnected`` fires with the recorded reason or,
        failing that, *exception*.
        """
        log.info("connection_lost: %s", (exception,))
        # These all belong to the TCP connection that just went away.
        self.parser = self.transport = self.socket = None
        # Events are fired only once the cleanup above is done.
        if self.end_session_on_disconnect:
            self._reset_sendq()
            self.event('session_end')
        self._set_disconnected_future()
        self.event("disconnected", self.disconnect_reason or exception)
617

mathieui's avatar
mathieui committed
618
    def cancel_connection_attempt(self) -> None:
        """
        Cancel the pending create_connection() attempt, if any, right away.

        Handy on flaky networks, where a connection can drop and a
        reconnect is wanted while a previous attempt is still running.
        """
        attempt = self._current_connection_attempt
        if not attempt:
            return
        attempt.cancel()
        self._current_connection_attempt = None

mathieui's avatar
mathieui committed
629
    def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
        """Close the XML stream, giving the server at most *wait* seconds
        to acknowledge it.

        The server acknowledges by closing its own stream. If that does not
        happen within *wait* seconds, abort() is called. If stanzas are
        still queued (and *ignore_send_queue* is false), the queue first
        gets up to *wait* seconds to drain.

        When not connected, the :attr:`disconnected` future is reset and
        the ``disconnected`` event fired; nothing else happens.

        :param wait: Time to wait for a response from the server.
                     ``True`` is accepted for compatibility and means 2.0.
        :param reason: An optional reason for the disconnect.
        :param ignore_send_queue: If true, do not wait for in-flight
                                  stanzas; close the stream immediately.
        :return: A future that ends when all code involved in the
                 disconnect has finished.
        """
        # Compat: docs/getting_started/sendlogout.rst has long promoted
        # `disconnect(wait=True)`, which means nothing to the scheduling
        # below, so translate it to the default delay.
        if wait is True:
            wait = 2.0

        if not self.transport:
            self._set_disconnected_future()
            self.event("disconnected", reason)
            finished: Future = Future()
            finished.set_result(None)
            return finished

        self.disconnect_reason = reason
        if self.waiting_queue.empty() or ignore_send_queue:
            self.cancel_connection_attempt()
            closing = self._end_stream_wait(wait, reason=reason)
        else:
            closing = self._consume_send_queue_before_disconnecting(reason, wait)
        return asyncio.ensure_future(closing, loop=self.loop)
Nathan Fritz's avatar
Nathan Fritz committed
671

mathieui's avatar
mathieui committed
672
    async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
        """Let the send queue drain for up to *wait* seconds, then close.

        If draining uses up the whole budget, the stream is closed without
        waiting any further for the server's reply.
        """
        remaining = wait
        try:
            await asyncio.wait_for(self.waiting_queue.join(), wait)
        except asyncio.TimeoutError:
            remaining = 0  # the budget went entirely on draining the queue
        self.disconnect_reason = reason
        await self._end_stream_wait(remaining)

mathieui's avatar
mathieui committed
684
    async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
        """Send the stream footer, then abort() unless ``disconnected``
        arrives within *wait* seconds.

        :param wait: Seconds to wait for the server to close its side
                     (defaults to 2).
        :param reason: Not used by this method.
        """
        try:
            self.send_raw(self.stream_footer)
            await self.wait_until('disconnected', wait)
        except NotConnectedError:
            # Already gone while sending the footer: the disconnect has
            # been handled, nothing left to do.
            return
        except asyncio.TimeoutError:
            self.abort()
700

mathieui's avatar
mathieui committed
701
    def abort(self) -> None:
        """
        Tear the connection down at once, with no closing handshake.

        Cancels any pending connection attempt, closes then aborts the
        transport and fires ``killed``. Does nothing when not connected.
        """
        if not self.transport:
            return
        self.cancel_connection_attempt()
        self.transport.close()
        self.transport.abort()
        self.event("killed")
710

mathieui's avatar
mathieui committed
711
    def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
        """Disconnect, then call connect() once ``disconnected`` fires.

        That happens either when the server acknowledges the stream close
        or after the *wait* timeout.

        :param wait: Passed to :meth:`disconnect`.
        :param reason: Disconnect reason passed to :meth:`disconnect`.
        """
        log.debug("reconnecting...")

        async def handler(event: Any) -> None:
            # Yield once so synchronous 'disconnected' handlers run first.
            # asyncio.sleep() no longer accepts ``loop`` (removed in
            # Python 3.10, where it raised TypeError); it uses the
            # running loop.
            await asyncio.sleep(0)
            self.connect()
        self.add_event_handler('disconnected', handler, disposable=True)
        self.disconnect(wait, reason)
722

mathieui's avatar
mathieui committed
723
    def configure_socket(self) -> None:
        """Hook for adjusting ``self.socket`` (timeouts and other options).

        The base implementation does nothing; subclasses override it when
        they need custom socket settings.
        """
729

mathieui's avatar
mathieui committed
730
    def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
        """
        Hook for customising DNS resolution before a lookup is made.

        The base implementation does nothing. Override it to tune a
        :class:`~dns.resolver.Resolver` instance or to do other DNS-related
        preparation. For example, you can check
        :meth:`~socket.socket.getaddrinfo` to decide whether ``res_init()``
        from ``libresolv.so.2`` needs to be called.

        :param resolver: A :class:`~dns.resolver.Resolver` instance
                         or ``None`` if ``dnspython`` is not installed.
        :param domain: The initial domain under consideration.
        :param port: The initial port under consideration.
        """

mathieui's avatar
mathieui committed
747
    def get_ssl_context(self) -> ssl.SSLContext:
        """Apply the configured TLS options to ``self.ssl_context`` and return it.

        - ``self.ciphers``: when not ``None``, passed to ``set_ciphers``.
        - ``self.certfile`` / ``self.keyfile``: when both are truthy, loaded
          as the certificate chain. A load failure is only logged at debug
          level.
        - ``self.ca_certs``: when not ``None``, verification is switched to
          ``ssl.CERT_REQUIRED`` and the file is loaded as the CA bundle.

        :returns: ``self.ssl_context`` itself, mutated in place.
        """
        context = self.ssl_context

        cipher_spec = self.ciphers
        if cipher_spec is not None:
            context.set_ciphers(cipher_spec)

        if self.keyfile and self.certfile:
            try:
                context.load_cert_chain(self.certfile, self.keyfile)
            except (ssl.SSLError, OSError):
                log.debug('Error loading the cert chain:', exc_info=True)
            else:
                log.debug('Loaded cert file %s and key file %s',
                          self.certfile, self.keyfile)

        ca_bundle = self.ca_certs
        if ca_bundle is not None:
            context.verify_mode = ssl.CERT_REQUIRED
            context.load_verify_locations(cafile=ca_bundle)

        return context

mathieui's avatar
mathieui committed
767
    async def start_tls(self) -> bool:
        """Perform handshakes for TLS.

        If the handshake is successful, the XML stream will need
        to be restarted. On success, the peer certificate is emitted
        (PEM-encoded) through the ``ssl_cert`` event.

        :returns: ``True`` when the TLS layer is established, ``False`` when
                  it fails. On an ``ssl.SSLError`` the stream is
                  disconnected unless an ``ssl_invalid_chain`` handler
                  exists, in which case that event is fired instead. On
                  another ``OSError`` the stream is disconnected.
        :raises ValueError: if there is no transport to upgrade.
        """
        if self.transport is None:
            raise ValueError("Transport should not be None")
        self.event_when_connected = "tls_success"
        ssl_context = self.get_ssl_context()
        # loop.start_tls() exists since Python 3.7; older loops need a new
        # connection wrapping the raw socket instead.
        has_start_tls = hasattr(self.loop, 'start_tls')
        try:
            if has_start_tls:
                transp = await self.loop.start_tls(self.transport,
                                                   self, ssl_context)
            else:
                # Use the context returned by get_ssl_context() on this path
                # too, so subclass overrides are honoured.
                transp, _ = await self.loop.create_connection(
                    lambda: self,
                    ssl=ssl_context,
                    sock=self.socket,
                    server_hostname=self.default_domain
                )
        except ssl.SSLError as e:
            log.debug('SSL: Unable to connect', exc_info=True)
            log.error('CERT: Invalid certificate trust chain.')
            if not self.event_handled('ssl_invalid_chain'):
                self.disconnect()
            else:
                self.event('ssl_invalid_chain', e)
            return False
        except OSError:
            log.debug("Connection error:", exc_info=True)
            self.disconnect()
            return False
        der_cert = transp.get_extra_info("ssl_object").getpeercert(True)
        pem_cert = ssl.DER_cert_to_PEM_cert(der_cert)
        self.event('ssl_cert', pem_cert)
        # If we use the builtin start_tls, the connection_made() protocol
        # method is not called automatically
        if has_start_tls:
            self.connection_made(transp)
        return True
809

mathieui's avatar
mathieui committed
810
    def _start_keepalive(self, event: Any) -> None:
        """Schedule a repeating whitespace ping to keep the connection alive.

        ``event`` is not used. A single space is sent with ``send_raw``
        every ``self.whitespace_keepalive_interval`` seconds, under the
        schedule name ``'Whitespace Keepalive'``. The interval can be
        changed with e.g.::

            self.whitespace_keepalive_interval = 300

        Note: this method does not read ``self.whitespace_keepalive``.
        That flag is presumably checked where this handler is registered;
        verify against the caller.
        """
        interval = self.whitespace_keepalive_interval
        self.schedule('Whitespace Keepalive', interval, self.send_raw,
                      args=(' ',), repeat=True)

mathieui's avatar
mathieui committed
827
    def _remove_schedules(self, event: Any) -> None:
        """Cancel the whitespace keepalive, which is useless once disconnected.

        ``event`` is not used.
        """
        keepalive_name = 'Whitespace Keepalive'
        self.cancel_schedule(keepalive_name)
Lance Stout's avatar
Lance Stout committed
830

mathieui's avatar
mathieui committed
831
    def start_stream_handler(self, xml: ET.Element) -> None:
        """Hook run once the stream header has been sent.

        Subclasses override this to perform initialization such as
        handshakes. The base implementation does nothing.

        :param xml: Element supplied by the caller (presumably the stream
                    header element; TODO confirm against the caller).
        """

mathieui's avatar
mathieui committed
839
    def register_stanza(self, stanza_class: Type[StanzaBase]) -> None:
Lance Stout's avatar
Lance Stout committed
840 841
        """Add a stanza object class as a known root stanza.

842 843
        A root stanza is one that appears as a direct child of the stream's
        root element.
844 845

        Stanzas that appear as substanzas of a root stanza do not need to
Lance Stout's avatar
Lance Stout committed
846
        be registered here. That is done using register_stanza_plugin() from
louiz’'s avatar
louiz’ committed
847
        slixmpp.xmlstream.stanzabase.
848 849 850 851 852

        Stanzas that are not registered will not be converted into
        stanza objects, but may still be processed using handlers and
        matchers.

853
        :param stanza_class: The top-level stanza object's class.
854 855 856
        """
        self.__root_stanza.append(stanza_class)

mathieui's avatar
mathieui committed
857
    def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None:
Lance Stout's avatar
Lance Stout committed
858 859
        """Remove a stanza from being a known root stanza.

860 861
        A root stanza is one that appears as a direct child of the stream's
        root element.
862 863 864 865 866

        Stanzas that are not registered will not be converted into
        stanza objects, but may still be processed using handlers and
        matchers.
        """
Lance Stout's avatar
Lance Stout committed
867
        self.__root_stanza.remove(stanza_class)
868

mathieui's avatar
mathieui committed
869
    def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None:
870 871 872 873 874
        """Add a filter for incoming or outgoing stanzas.

        These filters are applied before incoming stanzas are
        passed to any handlers, and before outgoing stanzas
        are put in the send queue.
875 876 877 878 879 880 881 882 883 884

        Each filter must accept a single stanza, and return
        either a stanza or ``None``. If the filter returns
        ``None``, then the stanza will be dropped from being
        processed for events or from being sent.

        :param mode: One of ``'in'`` or ``'out'``.
        :param handler: The filter function.
        :param int order: The position to insert the filter in
                          the list of active filters.
885 886 887 888 889 890
        """
        if order:
            self.__filters[mode].insert(order, handler)
        else:
            self.__filters[mode].append(handler)

mathieui's avatar
mathieui committed
891
    def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None:
Lance Stout's avatar
Lance Stout committed
892 893 894
        """Remove an incoming or outgoing filter."""
        self.__filters[mode].remove(handler)

mathieui's avatar
mathieui committed
895
    def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None:
896
        """Add a stream event handler that will be executed when a matching
897 898
        stanza is received.

Lance Stout's avatar
Lance Stout committed
899
        :param handler:
louiz’'s avatar
louiz’ committed
900
                The :class:`~slixmpp.xmlstream.handler.base.BaseHandler`
Lance Stout's avatar
Lance Stout committed
901
                derived object to execute.
902 903 904
        """
        if handler.stream is None:
            self.__handlers.append(handler)
905
            handler.stream = weakref.ref(self)
906

mathieui's avatar
mathieui committed
907
    def remove_handler(self, name: str) -> bool:
        """Remove the first stream event handler with the given name.

        Only the first handler whose ``name`` matches is removed; any later
        handlers sharing that name stay registered.

        :param name: The name of the handler.
        :returns: ``True`` if a handler was removed, ``False`` otherwise.
        """
        for idx, handler in enumerate(self.__handlers):
            if handler.name == name:
                # Safe despite iterating: we return right after mutating.
                self.__handlers.pop(idx)
                return True
        return False

mathieui's avatar
mathieui committed
920
    async def get_dns_records(self, domain: str, port: Optional[int] = None) -> List[Tuple[str, str, int]]:
        """Resolve the DNS records for a domain.

        A resolver is created with ``default_resolver`` and handed to
        configure_dns() for customisation. The lookup itself is delegated
        to ``resolve``, using this stream's ``dns_service``, ``use_ipv6``,
        ``use_aiodns`` and ``loop`` settings.

        :param domain: The domain in question.
        :param port: Port to use when the results carry none; defaults to
                     ``self.default_port``.
        :returns: Whatever ``resolve`` returns for the lookup.
        """
        target_port = self.default_port if port is None else port

        resolver = default_resolver(loop=self.loop)
        self.configure_dns(resolver, domain=domain, port=target_port)

        return await resolve(
            domain,
            target_port,
            service=self.dns_service,
            resolver=resolver,
            use_ipv6=self.use_ipv6,
            use_aiodns=self.use_aiodns,
            loop=self.loop,
        )
939

mathieui's avatar
mathieui committed
940
    async def _pick_dns_answer(self, domain: str, port: Optional[int] = None) -> Optional[Tuple[str, str, int]]:
        """Return the next unused DNS answer, or ``None`` once all are used.

        The first call, made while ``self._dns_answers`` is ``None``, fetches
        the records with get_dns_records() and keeps an iterator over them.
        Each call consumes one answer. The iterator is not reset here, so
        once it is exhausted every further call returns ``None``.

        :param domain: The domain in question.
        :param port: If the results don't include a port, use this one.
        """
        if self._dns_answers is None:
            records = await self.get_dns_records(domain, port)
            self._dns_answers = iter(records)
        return next(self._dns_answers, None)
Lance Stout's avatar
Lance Stout committed
957

mathieui's avatar
mathieui committed
958
    def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None:
959
        """Add a custom event handler that will be executed whenever
960 961
        its event is manually triggered.

962 963 964 965 966
        :param name: The name of the event that will trigger
                     this handler.
        :param pointer: The function to execute.
        :param disposable: If set to ``True``, the handler will be
                           discarded after one use. Defaults to ``False``.
967 968 969
        """
        if not name in self.__event_handlers:
            self.__event_handlers[name] = []
louiz’'s avatar
louiz’ committed