# Slixmpp: The Slick XMPP Library
# Copyright (C) 2010  Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
import logging

from typing import Union
from xml.etree.ElementTree import Element
from xml.parsers.expat import ExpatError

from slixmpp.xmlstream.stanzabase import ET, StanzaBase
from slixmpp.xmlstream.matcher.base import MatcherBase
# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class MatchXMLMask(MatcherBase):

    """
    The XMLMask matcher selects stanzas whose XML matches a given
    XML pattern, or mask. For example, message stanzas with body elements
    could be matched using the mask:

    .. code-block:: xml

        <message xmlns="jabber:client"><body /></message>

    Use of XMLMask is discouraged, and
    :class:`~slixmpp.xmlstream.matcher.xpath.MatchXPath` or
    :class:`~slixmpp.xmlstream.matcher.stanzapath.StanzaPath`
    should be used instead.

    :param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML
                     object or XML string to use as a mask.
    :param default_ns: Namespace tried as a qualifier for the mask's tag
                       names during comparison. Defaults to
                       ``"jabber:client"``.
    """

    # Parsed form of the mask that incoming stanzas are compared against.
    _criteria: Element

    def __init__(self, criteria: Union[str, Element],
                 default_ns: str = 'jabber:client'):
        # NOTE(review): presumably the base initializer stores ``criteria``
        # as ``self._criteria``; only the string case is converted here.
        MatcherBase.__init__(self, criteria)
        if isinstance(criteria, str):
            self._criteria = ET.fromstring(criteria)
        self.default_ns = default_ns

    def setDefaultNS(self, ns: str) -> None:
        """Set the default namespace to use during comparisons.

        :param ns: The new namespace to use as the default.
        """
        self.default_ns = ns

    def match(self, xml: Union[StanzaBase, Element]) -> bool:
        """Compare a stanza object or XML object against the stored XML mask.

        Overrides MatcherBase.match.

        :param xml: The stanza object or XML object to compare against.
        :return: ``True`` if ``xml`` satisfies the mask, else ``False``.
        """
        # Stanza objects expose their element as ``.xml``; a bare element
        # has no such attribute and is compared directly.
        real_xml = getattr(xml, 'xml', xml)
        return self._mask_cmp(real_xml, self._criteria, True)

    def _mask_cmp(self, source: Element, mask: Element, use_ns: bool = False,
                  default_ns: str = '__no_ns__') -> bool:
        """Compare an XML object against an XML mask.

        The source matches when its tag equals the mask's tag (either as
        written or qualified with ``self.default_ns``), its text equals the
        mask's text if the mask has any, it carries every attribute of the
        mask with the same value, and every child of the mask is matched by
        at least one child of the source found under that exact tag.

        :param source: The :class:`~xml.etree.ElementTree.Element` XML object
                       to compare against the mask.
        :param mask: The :class:`~xml.etree.ElementTree.Element` XML object
                     serving as the mask.
        :param use_ns: Indicates if namespaces should be respected during
                       the comparison. Only forwarded to recursive calls;
                       the comparison itself does not consult it.
        :param default_ns: Kept for interface compatibility and not used;
                           the namespace applied is ``self.default_ns``.
                           Defaults to ``"__no_ns__"``.
        :return: ``True`` if ``source`` satisfies ``mask``, else ``False``.
        """
        if source is None:
            # If the element was not found. May happen during recursive calls.
            return False

        mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
        if source.tag not in (mask.tag, mask_ns_tag):
            return False

        # If the mask includes text, the source must carry the same text.
        # A source without any text is treated as empty rather than being
        # exempted from the comparison.
        if mask.text and (source.text or '').strip() != mask.text.strip():
            return False

        # Compare attributes. The stanza must include the attributes
        # defined by the mask, but may include others.
        for name, value in mask.attrib.items():
            if name not in source.attrib or source.attrib[name] != value:
                return False

        # Recursively check subelements: each child of the mask needs at
        # least one same-tag child of the source that matches it.
        for subelement in mask:
            candidates = source.findall(subelement.tag)
            if not any(self._mask_cmp(other, subelement, use_ns)
                       for other in candidates):
                return False

        # Everything matches.
        return True