Commit 7c86c43f authored by mathieui's avatar mathieui

Merge branch 'mam-update' into 'master'

MAM Update

See merge request !149
parents 8a1f9bec 0115feaa
Pipeline #3803 passed with stages
in 3 minutes and 5 seconds
......@@ -92,6 +92,5 @@ Plugin index
xep_0428
xep_0437
xep_0439
xep_0441
xep_0444
......@@ -14,5 +14,6 @@ Stanza elements
.. automodule:: slixmpp.plugins.xep_0313.stanza
:members:
:member-order: bysource
:undoc-members:
XEP-0441: Message Archive Management Preferences
================================================
.. module:: slixmpp.plugins.xep_0441
.. autoclass:: XEP_0441
:members:
:exclude-members: session_bind, plugin_init, plugin_end
Stanza elements
---------------
.. automodule:: slixmpp.plugins.xep_0441.stanza
:members:
:undoc-members:
......@@ -22,11 +22,14 @@ class TestMAM(SlixIntegration):
"""Make sure we can get messages from our archive"""
# send messages first
tok = randint(1, 999999)
self.clients[0].make_message(mto=self.clients[1].boundjid, mbody='coucou').send()
self.clients[0].make_message(
mto=self.clients[1].boundjid,
mbody=f'coucou {tok}'
).send()
await self.clients[1].wait_until('message')
self.clients[1].make_message(
mto=self.clients[0].boundjid,
mbody='coucou coucou %s' % tok,
mbody=f'coucou coucou {tok}',
).send()
await self.clients[0].wait_until('message')
......@@ -48,8 +51,42 @@ class TestMAM(SlixIntegration):
if count >= 2:
break
self.assertEqual(msgs[0]['body'], 'coucou')
self.assertEqual(msgs[1]['body'], 'coucou coucou %s' % tok)
self.assertEqual(msgs[0]['body'], f'coucou {tok}')
self.assertEqual(msgs[1]['body'], f'coucou coucou {tok}')
async def test_mam_iterate(self):
"""Make sure we can iterate over messages from our archive"""
# send messages first
tok = randint(1, 999999)
self.clients[0].make_message(
mto=self.clients[1].boundjid,
mbody=f'coucou {tok}'
).send()
await self.clients[1].wait_until('message')
self.clients[1].make_message(
mto=self.clients[0].boundjid,
mbody='coucou coucou %s' % tok,
).send()
await self.clients[0].wait_until('message')
# Get archive
retrieve = self.clients[0]['xep_0313'].iterate(
with_jid=JID(self.envjid('CI_ACCOUNT2')),
reverse=True,
rsm={'max': 1}
)
msgs = []
count = 0
async for msg in retrieve:
msgs.append(
msg['mam_result']['forwarded']['stanza']
)
count += 1
if count >= 2:
break
self.assertEqual(msgs[0]['body'], f'coucou coucou {tok}')
self.assertEqual(msgs[1]['body'], f'coucou {tok}')
suite = unittest.TestLoader().loadTestsFromTestCase(TestMAM)
......@@ -110,5 +110,6 @@ __all__ = [
'xep_0428', # Message Fallback
'xep_0437', # Room Activity Indicators
'xep_0439', # Quick Response
'xep_0441', # Message Archive Management Preferences
'xep_0444', # Message Reactions
]
......@@ -135,6 +135,9 @@ class ResultIterator(AsyncIterator):
not r[self.recv_interface]['rsm']['last']:
raise StopAsyncIteration
if self.post_cb:
self.post_cb(r)
if r[self.recv_interface]['rsm']['count'] and \
r[self.recv_interface]['rsm']['first_index']:
count = int(r[self.recv_interface]['rsm']['count'])
......@@ -147,9 +150,6 @@ class ResultIterator(AsyncIterator):
self.start = r[self.recv_interface]['rsm']['first']
else:
self.start = r[self.recv_interface]['rsm']['last']
if self.post_cb:
self.post_cb(r)
return r
except XMPPError:
raise StopAsyncIteration
......
......@@ -5,8 +5,10 @@
# See the file LICENSE for copying permissio
from slixmpp.plugins.base import register_plugin
from slixmpp.plugins.xep_0313.stanza import Result, MAM, Preferences
from slixmpp.plugins.xep_0313.stanza import Result, MAM, Metadata
from slixmpp.plugins.xep_0313.mam import XEP_0313
register_plugin(XEP_0313)
__all__ = ['XEP_0313', 'Result', 'MAM', 'Metadata']
......@@ -5,8 +5,17 @@
# See the file LICENSE for copying permission
import logging
from asyncio import Future
from collections.abc import AsyncGenerator
from datetime import datetime
from typing import Any, Dict, Callable, Optional, Awaitable
from typing import (
Any,
Awaitable,
Callable,
Dict,
Optional,
Tuple,
)
from slixmpp import JID
from slixmpp.stanza import Message, Iq
......@@ -15,6 +24,7 @@ from slixmpp.xmlstream.matcher import MatchXMLMask
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins import BasePlugin
from slixmpp.plugins.xep_0313 import stanza
from slixmpp.plugins.xep_0004.stanza import Form
log = logging.getLogger(__name__)
......@@ -28,17 +38,25 @@ class XEP_0313(BasePlugin):
name = 'xep_0313'
description = 'XEP-0313: Message Archive Management'
dependencies = {'xep_0030', 'xep_0050', 'xep_0059', 'xep_0297'}
dependencies = {
'xep_0004', 'xep_0030', 'xep_0050', 'xep_0059', 'xep_0297'
}
stanza = stanza
def plugin_init(self):
register_stanza_plugin(stanza.MAM, Form)
register_stanza_plugin(Iq, stanza.MAM)
register_stanza_plugin(Iq, stanza.Preferences)
register_stanza_plugin(Message, stanza.Result)
register_stanza_plugin(Iq, stanza.Fin)
register_stanza_plugin(stanza.Result, self.xmpp['xep_0297'].stanza.Forwarded)
register_stanza_plugin(
stanza.Result,
self.xmpp['xep_0297'].stanza.Forwarded
)
register_stanza_plugin(stanza.MAM, self.xmpp['xep_0059'].stanza.Set)
register_stanza_plugin(stanza.Fin, self.xmpp['xep_0059'].stanza.Set)
register_stanza_plugin(Iq, stanza.Metadata)
register_stanza_plugin(stanza.Metadata, stanza.Start)
register_stanza_plugin(stanza.Metadata, stanza.End)
def retrieve(
self,
......@@ -66,16 +84,10 @@ class XEP_0313(BasePlugin):
:param bool iterator: Use RSM and iterate over a paginated query
:param dict rsm: RSM custom options
"""
iq = self.xmpp.Iq()
iq, stanza_mask = self._pre_mam_retrieve(
jid, start, end, with_jid, ifrom
)
query_id = iq['id']
iq['to'] = jid
iq['from'] = ifrom
iq['type'] = 'set'
iq['mam']['queryid'] = query_id
iq['mam']['start'] = start
iq['mam']['end'] = end
iq['mam']['with'] = with_jid
amount = 10
if rsm:
for key, value in rsm.items():
......@@ -84,12 +96,6 @@ class XEP_0313(BasePlugin):
amount = value
cb_data = {}
stanza_mask = self.xmpp.Message()
stanza_mask.xml.remove(stanza_mask.xml.find('{urn:xmpp:sid:0}origin-id'))
del stanza_mask['id']
del stanza_mask['lang']
stanza_mask['from'] = jid
stanza_mask['mam_result']['queryid'] = query_id
xml_mask = str(stanza_mask)
def pre_cb(query: Iq) -> None:
......@@ -106,11 +112,14 @@ class XEP_0313(BasePlugin):
results = cb_data['collector'].stop()
if result['type'] == 'result':
result['mam']['results'] = results
result['mam_fin']['results'] = results
if iterator:
return self.xmpp['xep_0059'].iterate(iq, 'mam', 'results', amount=amount,
reverse=reverse, recv_interface='mam_fin',
pre_cb=pre_cb, post_cb=post_cb)
return self.xmpp['xep_0059'].iterate(
iq, 'mam', 'results', amount=amount,
reverse=reverse, recv_interface='mam_fin',
pre_cb=pre_cb, post_cb=post_cb
)
collector = Collector(
'MAM_Results_%s' % query_id,
......@@ -126,26 +135,144 @@ class XEP_0313(BasePlugin):
return iq.send(timeout=timeout, callback=wrapped_cb)
def get_preferences(self, timeout=None, callback=None):
iq = self.xmpp.Iq()
iq['type'] = 'get'
async def iterate(
self,
jid: Optional[JID] = None,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
with_jid: Optional[JID] = None,
ifrom: Optional[JID] = None,
reverse: bool = False,
rsm: Optional[Dict[str, Any]] = None,
total: Optional[int] = None,
) -> AsyncGenerator:
"""
Iterate over each message of MAM query.
:param jid: Entity holding the MAM records
:param start: MAM query start time
:param end: MAM query end time
:param with_jid: Filter results on this JID
:param ifrom: To change the from address of the query
:param reverse: Get the results in reverse order
:param rsm: RSM custom options
:param total: A number of messages received after which the query
should stop.
"""
iq, stanza_mask = self._pre_mam_retrieve(
jid, start, end, with_jid, ifrom
)
query_id = iq['id']
iq['mam_prefs']['query_id'] = query_id
return iq.send(timeout=timeout, callback=callback)
def set_preferences(self, jid=None, default=None, always=None, never=None,
ifrom=None, timeout=None, callback=None):
iq = self.xmpp.Iq()
iq['type'] = 'set'
iq['to'] = jid
iq['from'] = ifrom
iq['mam_prefs']['default'] = default
iq['mam_prefs']['always'] = always
iq['mam_prefs']['never'] = never
return iq.send(timeout=timeout, callback=callback)
def get_configuration_commands(self, jid, **kwargs):
return self.xmpp['xep_0030'].get_items(
jid=jid,
node='urn:xmpp:mam#configure',
**kwargs)
amount = 10
if rsm:
for key, value in rsm.items():
iq['mam']['rsm'][key] = str(value)
if key == 'max':
amount = value
cb_data = {}
def pre_cb(query: Iq) -> None:
stanza_mask['mam_result']['queryid'] = query['id']
xml_mask = str(stanza_mask)
query['mam']['queryid'] = query['id']
collector = Collector(
'MAM_Results_%s' % query_id,
MatchXMLMask(xml_mask))
self.xmpp.register_handler(collector)
cb_data['collector'] = collector
def post_cb(result: Iq) -> None:
results = cb_data['collector'].stop()
if result['type'] == 'result':
result['mam']['results'] = results
result['mam_fin']['results'] = results
iterator = self.xmpp['xep_0059'].iterate(
iq, 'mam', 'results', amount=amount,
reverse=reverse, recv_interface='mam_fin',
pre_cb=pre_cb, post_cb=post_cb
)
recv_count = 0
async for page in iterator:
messages = [message for message in page['mam']['results']]
if reverse:
messages.reverse()
for message in messages:
yield message
recv_count += 1
if total is not None and recv_count >= total:
break
if total is not None and recv_count >= total:
break
def _pre_mam_retrieve(
self,
jid: Optional[JID] = None,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
with_jid: Optional[JID] = None,
ifrom: Optional[JID] = None,
) -> Tuple[Iq, Message]:
"""Build the IQ and stanza mask for MAM results
"""
iq = self.xmpp.make_iq_set(ito=jid, ifrom=ifrom)
query_id = iq['id']
iq['mam']['queryid'] = query_id
iq['mam']['start'] = start
iq['mam']['end'] = end
iq['mam']['with'] = with_jid
stanza_mask = self.xmpp.Message()
auto_origin = stanza_mask.xml.find('{urn:xmpp:sid:0}origin-id')
if auto_origin is not None:
stanza_mask.xml.remove(auto_origin)
del stanza_mask['id']
del stanza_mask['lang']
stanza_mask['from'] = jid
stanza_mask['mam_result']['queryid'] = query_id
return (iq, stanza_mask)
async def get_fields(self, jid: Optional[JID] = None, **iqkwargs) -> Form:
"""Get MAM query fields.
.. versionaddedd:: 1.8.0
:param jid: JID to retrieve the policy from.
:return: The Form of allowed options
"""
ifrom = iqkwargs.pop('ifrom', None)
iq = self.xmpp.make_iq_get(ito=jid, ifrom=ifrom)
iq.enable('mam')
result = await iq.send(**iqkwargs)
return result['mam']['form']
async def get_configuration_commands(self, jid: Optional[JID],
**discokwargs) -> Future:
"""Get the list of MAM advanced configuration commands.
.. versionchanged:: 1.8.0
:param jid: JID to get the commands from.
"""
if jid is None:
jid = self.xmpp.boundjid.bare
return await self.xmpp['xep_0030'].get_items(
jid=jid,
node='urn:xmpp:mam#configure',
**discokwargs
)
def get_archive_metadata(self, jid: Optional[JID] = None,
**iqkwargs) -> Future:
"""Get the archive metadata from a JID.
:param jid: JID to get the metadata from.
"""
ifrom = iqkwargs.pop('ifrom', None)
iq = self.xmpp.make_iq_get(ito=jid, ifrom=ifrom)
iq.enable('mam_metadata')
return iq.send(**iqkwargs)
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2012 Nathanael C. Fritz, Lance J.T. Stout
# This file is part of Slixmpp.
# See the file LICENSE for copying permissio
import datetime as dt
from datetime import datetime
from typing import (
Any,
Iterable,
List,
Optional,
Set,
Union,
)
from slixmpp.stanza import Message
from slixmpp.jid import JID
from slixmpp.xmlstream import ElementBase, ET
from slixmpp.plugins import xep_0082, xep_0004
from slixmpp.plugins import xep_0082
class MAM(ElementBase):
"""A MAM Query element.
.. code-block:: xml
<iq type='set' id='juliet1'>
<query xmlns='urn:xmpp:mam:2'>
<x xmlns='jabber:x:data' type='submit'>
<field var='FORM_TYPE' type='hidden'>
<value>urn:xmpp:mam:2</value>
</field>
<field var='with'>
<value>juliet@capulet.lit</value>
</field>
</x>
</query>
</iq>
"""
name = 'query'
namespace = 'urn:xmpp:mam:2'
plugin_attrib = 'mam'
interfaces = {'queryid', 'start', 'end', 'with', 'results'}
sub_interfaces = {'start', 'end', 'with'}
#: Available interfaces:
#:
#: - ``queryid``: The MAM query id
#: - ``start`` and ``end``: Temporal boundaries of the query
#: - ``with``: JID of the other entity the conversation is with
#: - ``after_id``: Fetch stanzas after this specific ID
#: - ``before_id``: Fetch stanzas before this specific ID
#: - ``ids``: Fetch the stanzas matching those IDs
#: - ``results``: pseudo-interface used to accumulate MAM results during
#: fetch, not relevant for the stanza itself.
interfaces = {
'queryid', 'start', 'end', 'with', 'results',
'before_id', 'after_id', 'ids',
}
sub_interfaces = {'start', 'end', 'with', 'before_id', 'after_id', 'ids'}
def setup(self, xml=None):
ElementBase.setup(self, xml)
self._form = xep_0004.stanza.Form()
self._form['type'] = 'submit'
field = self._form.add_field(var='FORM_TYPE', ftype='hidden',
value='urn:xmpp:mam:2')
self.append(self._form)
self._results = []
def __get_fields(self):
return self._form.get_fields()
def get_start(self):
fields = self.__get_fields()
self._results: List[Message] = []
def _setup_form(self):
found = self.xml.find(
'{jabber:x:data}x/'
'{jabber:x:data}field[@var="FORM_TYPE"]/'
"{jabber:x:data}value[.='urn:xmpp:mam:2']"
)
if found is None:
self['form']['type'] = 'submit'
self['form'].add_field(
var='FORM_TYPE', ftype='hidden', value='urn:xmpp:mam:2'
)
def get_fields(self):
form = self.get_plugin('form', check=True)
if not form:
return {}
return form.get_fields()
def get_start(self) -> Optional[datetime]:
fields = self.get_fields()
field = fields.get('start')
if field:
return xep_0082.parse(field['value'])
return None
def set_start(self, value):
if isinstance(value, dt.datetime):
def set_start(self, value: Union[str, datetime]):
self._setup_form()
if isinstance(value, datetime):
value = xep_0082.format_datetime(value)
fields = self.__get_fields()
field = fields.get('start')
if field:
field['value'] = value
else:
field = self._form.add_field(var='start')
field['value'] = value
self.set_custom_field('start', value)
def get_end(self):
fields = self.__get_fields()
def get_end(self) -> Optional[datetime]:
fields = self.get_fields()
field = fields.get('end')
if field:
return xep_0082.parse(field['value'])
return None
def set_end(self, value):
if isinstance(value, dt.datetime):
def set_end(self, value: Union[str, datetime]):
if isinstance(value, datetime):
value = xep_0082.format_datetime(value)
fields = self.__get_fields()
field = fields.get('end')
if field:
field['value'] = value
else:
field = self._form.add_field(var='end')
field['value'] = value
self.set_custom_field('end', value)
def get_with(self):
fields = self.__get_fields()
def get_with(self) -> Optional[JID]:
fields = self.get_fields()
field = fields.get('with')
if field:
return JID(field['value'])
return None
def set_with(self, value):
fields = self.__get_fields()
field = fields.get('with')
def set_with(self, value: JID):
self.set_custom_field('with', value)
def set_custom_field(self, fieldname: str, value: Any):
self._setup_form()
fields = self.get_fields()
field = fields.get(fieldname)
if field:
field['with'] = str(value)
field['value'] = str(value)
else:
field = self._form.add_field(var='with')
field = self['form'].add_field(var=fieldname)
field['value'] = str(value)
def get_custom_field(self, fieldname: str) -> Optional[str]:
fields = self.get_fields()
field = fields.get(fieldname)
if field:
return field['value']
return None
def set_before_id(self, value: str):
self.set_custom_field('before-id', value)
def get_before_id(self):
self.get_custom_field('before-id')
def set_after_id(self, value: str):
self.set_custom_field('after-id', value)
def get_after_id(self):
self.get_custom_field('after-id')
def set_ids(self, value: List[str]):
self._setup_form()
fields = self.get_fields()
field = fields.get('ids')
if field:
field['ids'] = value
else:
field = self['form'].add_field(var='ids')
field['value'] = value
def get_ids(self):
self.get_custom_field('id')
# The results interface is meant only as an easy
# way to access the set of collected message responses
# from the query.
def get_results(self):
def get_results(self) -> List[Message]:
return self._results
def set_results(self, values):
def set_results(self, values: List[Message]):
self._results = values