"""
    Slixmpp: The Slick XMPP Library
    Copyright (C) 2010  Nathanael C. Fritz
    This file is part of Slixmpp.

    See the file LICENSE for copying permission.
"""
8

9 10
import logging

Nathan Fritz's avatar
Nathan Fritz committed
11 12
from xml.parsers.expat import ExpatError

louiz’'s avatar
louiz’ committed
13 14
from slixmpp.xmlstream.stanzabase import ET
from slixmpp.xmlstream.matcher.base import MatcherBase
15 16


17 18 19
# Module-level logger, named after this module's ``__name__``.
log = logging.getLogger(__name__)


20 21 22 23 24 25 26
class MatchXMLMask(MatcherBase):

    """
    The XMLMask matcher selects stanzas whose XML matches a given
    XML pattern, or mask. For example, message stanzas with body elements
    could be matched using the mask:

    .. code-block:: xml

        <message xmlns="jabber:client"><body /></message>

    Use of XMLMask is discouraged, and
    :class:`~slixmpp.xmlstream.matcher.xpath.MatchXPath` or
    :class:`~slixmpp.xmlstream.matcher.stanzapath.StanzaPath`
    should be used instead.

    :param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML
                     object or XML string to use as a mask.
    :param default_ns: Namespace that an unqualified mask tag is also
                       allowed to match under. Defaults to
                       ``"jabber:client"``.
    """

    def __init__(self, criteria, default_ns='jabber:client'):
        MatcherBase.__init__(self, criteria)
        # A string mask is parsed once here so that match() always works
        # with an element object.
        if isinstance(criteria, str):
            self._criteria = ET.fromstring(self._criteria)
        self.default_ns = default_ns

    def setDefaultNS(self, ns):
        """Set the default namespace to use during comparisons.

        :param ns: The new namespace to use as the default.
        """
        self.default_ns = ns

    def match(self, xml):
        """Compare a stanza object or XML object against the stored XML mask.

        Overrides MatcherBase.match.

        :param xml: The stanza object or XML object to compare against.
        :returns: ``True`` if the XML satisfies the mask, ``False`` otherwise.
        """
        # Objects exposing an ``xml`` attribute are unwrapped so that the
        # comparison runs on that attribute's value.
        if hasattr(xml, 'xml'):
            xml = xml.xml
        return self._mask_cmp(xml, self._criteria, True)

    def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'):
        """Compare an XML object against an XML mask.

        :param source: The :class:`~xml.etree.ElementTree.Element` XML object
                       to compare against the mask.
        :param mask: The :class:`~xml.etree.ElementTree.Element` XML object
                     serving as the mask. An XML string is also accepted
                     and parsed first.
        :param use_ns: Forwarded unchanged to the recursive calls; it is
                       not otherwise consulted by the comparison.
        :param default_ns: Not used by the comparison, which reads
                           ``self.default_ns`` instead. Kept so existing
                           callers keep working.
                           Defaults to ``"__no_ns__"``.
        :returns: ``True`` if ``source`` satisfies ``mask``; ``False`` if it
                  does not, if ``source`` is ``None``, or if a string mask
                  cannot be parsed.
        """
        if source is None:
            # The element was not found. May happen during recursive calls.
            return False

        # Convert the mask to an XML object if it is a string.
        if not hasattr(mask, 'attrib'):
            try:
                mask = ET.fromstring(mask)
            except (ExpatError, ET.ParseError) as err:
                # ElementTree reports malformed XML as ParseError rather
                # than ExpatError, so both are handled. An unparsable mask
                # cannot match anything; returning here also avoids
                # touching ``.tag`` on the raw string below.
                log.warning("Expat error: %s\nIn parsing: %s", err, mask)
                return False

        # The mask tag may match as written or qualified with the
        # default namespace.
        mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
        if source.tag not in (mask.tag, mask_ns_tag):
            return False

        # If both the mask and the source carry text, it must agree once
        # surrounding whitespace is stripped.
        if mask.text and source.text and \
           source.text.strip() != mask.text.strip():
            return False

        # Compare attributes. The stanza must include the attributes
        # defined by the mask, but may include others.
        for name, value in mask.attrib.items():
            if source.attrib.get(name, "__None__") != value:
                return False

        # Recursively check subelements: every child of the mask needs at
        # least one child of the source (looked up by the mask child's tag)
        # that satisfies it. One source child may satisfy several mask
        # children.
        for subelement in mask:
            if not any(self._mask_cmp(other, subelement, use_ns)
                       for other in source.findall(subelement.tag)):
                return False

        # Everything matches.
        return True