# Slixmpp: The Slick XMPP Library
# Copyright (C) 2010  Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
import logging

from xml.parsers.expat import ExpatError

from slixmpp.xmlstream.stanzabase import ET
from slixmpp.xmlstream.matcher.base import MatcherBase


# Module-level logger; MatchXMLMask uses it to report masks that fail to parse.
log = logging.getLogger(__name__)


class MatchXMLMask(MatcherBase):

    """
    The XMLMask matcher selects stanzas whose XML matches a given
    XML pattern, or mask. For example, message stanzas with body elements
    could be matched using the mask:

    .. code-block:: xml

        <message xmlns="jabber:client"><body /></message>

    Use of XMLMask is discouraged, and
    :class:`~slixmpp.xmlstream.matcher.xpath.MatchXPath` or
    :class:`~slixmpp.xmlstream.matcher.stanzapath.StanzaPath`
    should be used instead.

    Mask elements written without a namespace also match stanza elements
    that live in the matcher's default namespace, at every nesting level.

    :param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML
                     object or XML string to use as a mask.
    :param default_ns: Namespace assumed for mask elements that do not
                       specify one. Defaults to ``"jabber:client"``.
    """

    def __init__(self, criteria, default_ns='jabber:client'):
        MatcherBase.__init__(self, criteria)
        if isinstance(criteria, str):
            # Parse the mask once so every match() reuses the same tree.
            self._criteria = ET.fromstring(self._criteria)
        self.default_ns = default_ns

    def setDefaultNS(self, ns):
        """Set the default namespace to use during comparisons.

        :param ns: The new namespace to use as the default.
        """
        self.default_ns = ns

    def match(self, xml):
        """Compare a stanza object or XML object against the stored XML mask.

        Overrides MatcherBase.match.

        :param xml: The stanza object or XML object to compare against.
        :returns: ``True`` if ``xml`` satisfies the mask, else ``False``.
        """
        # Stanza objects expose their underlying element as ``.xml``.
        if hasattr(xml, 'xml'):
            xml = xml.xml
        return self._mask_cmp(xml, self._criteria, True)

    def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'):
        """Compare an XML object against an XML mask.

        ``source`` matches ``mask`` when all of the following hold:

        * its tag equals the mask's tag, either as written or qualified with
          ``self.default_ns``;
        * if the mask has non-whitespace text, the source's stripped text is
          equal to it;
        * it carries every attribute of the mask with the same value (extra
          attributes are allowed);
        * every child of the mask is matched by at least one direct child of
          the source.

        Repeated mask children may be satisfied by the same source child.

        :param source: The :class:`~xml.etree.ElementTree.Element` XML object
                       to compare against the mask. ``None`` never matches.
        :param mask: The :class:`~xml.etree.ElementTree.Element` XML object
                     serving as the mask, or an XML string to parse into one.
        :param use_ns: Kept for backward compatibility. Namespace handling
                       currently always uses ``self.default_ns``.
        :param default_ns: Kept for backward compatibility; not consulted.
        :returns: ``True`` if ``source`` satisfies ``mask``, else ``False``.
        """
        if source is None:
            # If the element was not found. May happen during recursive calls.
            return False

        # Convert the mask to an XML object if it is a string.
        if not hasattr(mask, 'attrib'):
            try:
                mask = ET.fromstring(mask)
            except (ExpatError, ET.ParseError) as err:
                # ElementTree raises ParseError, not ExpatError; catch both.
                log.warning("Expat error: %s\nIn parsing: %s", err, mask)
                # An unparsable mask cannot match anything; continuing would
                # fail on ``mask.tag`` below.
                return False

        # Accept the source tag either exactly as in the mask or qualified
        # with the default namespace (for un-namespaced mask elements).
        mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
        if source.tag not in (mask.tag, mask_ns_tag):
            return False

        # If the mask includes meaningful text, the source must carry the
        # same text. A missing source text counts as empty, and a
        # whitespace-only mask text (pretty-printing) imposes no constraint.
        mask_text = (mask.text or '').strip()
        if mask_text and (source.text or '').strip() != mask_text:
            return False

        # Compare attributes. The stanza must include the attributes
        # defined by the mask, but may include others.
        for name, value in mask.attrib.items():
            if source.attrib.get(name, "__None__") != value:
                return False

        # Recursively check subelements. Scan all direct children and let
        # the tag check above filter them, so the default namespace applies
        # at every level (findall() on the raw mask tag would miss
        # namespaced children of an un-namespaced mask).
        for subelement in mask:
            if not any(self._mask_cmp(child, subelement, use_ns)
                       for child in source):
                return False

        # Everything matches.
        return True