Commit fa063ddd authored by Maxime Buquet's avatar Maxime Buquet

Merge branch 'plugin-omemo' into 'master'

E2EE plugins support

See merge request !18
parents d7d4e30e abbb6a71
...@@ -240,6 +240,7 @@ class Core: ...@@ -240,6 +240,7 @@ class Core:
('groupchat_subject', self.handler.on_groupchat_subject), ('groupchat_subject', self.handler.on_groupchat_subject),
('http_confirm', self.handler.http_confirm), ('http_confirm', self.handler.http_confirm),
('message', self.handler.on_message), ('message', self.handler.on_message),
('message_encryption', self.handler.on_encrypted_message),
('message_error', self.handler.on_error_message), ('message_error', self.handler.on_error_message),
('message_xform', self.handler.on_data_form), ('message_xform', self.handler.on_data_form),
('no_auth', self.handler.on_no_auth), ('no_auth', self.handler.on_no_auth),
......
...@@ -271,6 +271,14 @@ class HandlerCore: ...@@ -271,6 +271,14 @@ class HandlerCore:
return return
self.on_normal_message(message) self.on_normal_message(message)
def on_encrypted_message(self, message):
"""
When receiving an encrypted message
"""
if message["body"]:
return # Already being handled by on_message.
self.on_message(message)
def on_error_message(self, message): def on_error_message(self, message):
""" """
When receiving any message with type="error" When receiving any message with type="error"
......
...@@ -75,9 +75,12 @@ class SafetyMetaclass(type): ...@@ -75,9 +75,12 @@ class SafetyMetaclass(type):
@staticmethod @staticmethod
def safe_func(f): def safe_func(f):
def helper(*args, **kwargs): def helper(*args, **kwargs):
passthrough = kwargs.pop('passthrough', False)
try: try:
return f(*args, **kwargs) return f(*args, **kwargs)
except: except:
if passthrough:
raise
if inspect.stack()[1][1] == inspect.getfile(f): if inspect.stack()[1][1] == inspect.getfile(f):
raise raise
elif SafetyMetaclass.core: elif SafetyMetaclass.core:
...@@ -86,9 +89,12 @@ class SafetyMetaclass(type): ...@@ -86,9 +89,12 @@ class SafetyMetaclass(type):
'Error') 'Error')
return None return None
async def async_helper(*args, **kwargs): async def async_helper(*args, **kwargs):
passthrough = kwargs.pop('passthrough', False)
try: try:
return await f(*args, **kwargs) return await f(*args, **kwargs)
except: except:
if passthrough:
raise
if inspect.stack()[1][1] == inspect.getfile(f): if inspect.stack()[1][1] == inspect.getfile(f):
raise raise
elif SafetyMetaclass.core: elif SafetyMetaclass.core:
......
...@@ -10,24 +10,24 @@ ...@@ -10,24 +10,24 @@
Interface for E2EE (End-to-end Encryption) plugins. Interface for E2EE (End-to-end Encryption) plugins.
""" """
from typing import ( from typing import Callable, Dict, List, Optional, Union, Tuple, Set
Callable,
Dict,
List,
Optional,
Union,
Tuple,
)
from slixmpp import InvalidJID, JID, Message from slixmpp import InvalidJID, JID, Message
from slixmpp.xmlstream import StanzaBase from slixmpp.xmlstream import StanzaBase
from poezio.tabs import ( from poezio.tabs import (
ChatTab,
ConversationTab, ConversationTab,
DynamicConversationTab, DynamicConversationTab,
PrivateTab,
MucTab, MucTab,
PrivateTab,
StaticConversationTab,
) )
from poezio.plugin import BasePlugin from poezio.plugin import BasePlugin
from poezio.theming import get_theme, dump_tuple
from poezio.config import config
from poezio.decorators import command_args_parser
from asyncio import iscoroutinefunction
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -36,6 +36,7 @@ log = logging.getLogger(__name__) ...@@ -36,6 +36,7 @@ log = logging.getLogger(__name__)
ChatTabs = Union[ ChatTabs = Union[
MucTab, MucTab,
DynamicConversationTab, DynamicConversationTab,
StaticConversationTab,
PrivateTab, PrivateTab,
] ]
...@@ -45,6 +46,12 @@ EME_TAG = 'encryption' ...@@ -45,6 +46,12 @@ EME_TAG = 'encryption'
JCLIENT_NS = 'jabber:client' JCLIENT_NS = 'jabber:client'
HINTS_NS = 'urn:xmpp:hints' HINTS_NS = 'urn:xmpp:hints'
class NothingToEncrypt(Exception):
"""
Exception to raise inside the _encrypt filter on stanzas that do not need
to be processed.
"""
class E2EEPlugin(BasePlugin): class E2EEPlugin(BasePlugin):
"""Interface for E2EE plugins. """Interface for E2EE plugins.
...@@ -72,7 +79,7 @@ class E2EEPlugin(BasePlugin): ...@@ -72,7 +79,7 @@ class E2EEPlugin(BasePlugin):
stanza_encryption = False stanza_encryption = False
#: Whitelist applied to messages when `stanza_encryption` is `False`. #: Whitelist applied to messages when `stanza_encryption` is `False`.
tag_whitelist = list(map(lambda x: '{%s}%s' % (x[0], x[1]), [ tag_whitelist = [
(JCLIENT_NS, 'body'), (JCLIENT_NS, 'body'),
(EME_NS, EME_TAG), (EME_NS, EME_TAG),
(HINTS_NS, 'store'), (HINTS_NS, 'store'),
...@@ -80,7 +87,7 @@ class E2EEPlugin(BasePlugin): ...@@ -80,7 +87,7 @@ class E2EEPlugin(BasePlugin):
(HINTS_NS, 'no-store'), (HINTS_NS, 'no-store'),
(HINTS_NS, 'no-permanent-store'), (HINTS_NS, 'no-permanent-store'),
# TODO: Add other encryption mechanisms tags here # TODO: Add other encryption mechanisms tags here
])) ]
#: Replaces body with `eme <https://xmpp.org/extensions/xep-0380.html>`_ #: Replaces body with `eme <https://xmpp.org/extensions/xep-0380.html>`_
#: if set. Should be suitable for most plugins except those using #: if set. Should be suitable for most plugins except those using
...@@ -109,7 +116,16 @@ class E2EEPlugin(BasePlugin): ...@@ -109,7 +116,16 @@ class E2EEPlugin(BasePlugin):
# time # time
_enabled_tabs = {} # type: Dict[JID, Callable] _enabled_tabs = {} # type: Dict[JID, Callable]
# Tabs that support this encryption mechanism
supported_tab_types = tuple() # type: Tuple[ChatTabs]
# States for each remote entity
trust_states = {'accepted': set(), 'rejected': set()} # type: Dict[str, Set[str]]
def init(self): def init(self):
self._all_trust_states = self.trust_states['accepted'].union(
self.trust_states['rejected']
)
if self.encryption_name is None and self.encryption_short_name is None: if self.encryption_name is None and self.encryption_short_name is None:
raise NotImplementedError raise NotImplementedError
...@@ -131,9 +147,9 @@ class E2EEPlugin(BasePlugin): ...@@ -131,9 +147,9 @@ class E2EEPlugin(BasePlugin):
# encrypted is encrypted, and no plain element slips in. # encrypted is encrypted, and no plain element slips in.
# Using a stream filter might be a bit too much, but at least we're # Using a stream filter might be a bit too much, but at least we're
# sure poezio is not sneaking anything past us. # sure poezio is not sneaking anything past us.
self.core.xmpp.add_filter('out', self._encrypt) self.core.xmpp.add_filter('out', self._encrypt_wrapper)
for tab_t in (DynamicConversationTab, PrivateTab, MucTab): for tab_t in self.supported_tab_types:
self.api.add_tab_command( self.api.add_tab_command(
tab_t, tab_t,
self.encryption_short_name, self.encryption_short_name,
...@@ -143,6 +159,33 @@ class E2EEPlugin(BasePlugin): ...@@ -143,6 +159,33 @@ class E2EEPlugin(BasePlugin):
help='Toggle automatic {} encryption for tab.'.format(self.encryption_name), help='Toggle automatic {} encryption for tab.'.format(self.encryption_name),
) )
trust_msg = 'Set {name} state to {state} for this fingerprint on this JID.'
for state in self._all_trust_states:
for tab_t in self.supported_tab_types:
self.api.add_tab_command(
tab_t,
self.encryption_short_name + '_' + state,
lambda args: self.__command_set_state_local(args, state),
usage='<fingerprint>',
short=trust_msg.format(name=self.encryption_short_name, state=state),
help=trust_msg.format(name=self.encryption_short_name, state=state),
)
self.api.add_command(
self.encryption_short_name + '_' + state,
lambda args: self.__command_set_state_global(args, state),
usage='<JID> <fingerprint>',
short=trust_msg.format(name=self.encryption_short_name, state=state),
help=trust_msg.format(name=self.encryption_short_name, state=state),
)
self.api.add_command(
self.encryption_short_name + '_fingerprint',
self._command_show_fingerprints,
usage='[jid]',
short='Show %s fingerprint(s) for a JID.' % self.encryption_short_name,
help='Show %s fingerprint(s) for a JID.' % self.encryption_short_name,
)
ConversationTab.add_information_element( ConversationTab.add_information_element(
self.encryption_short_name, self.encryption_short_name,
self._display_encryption_status, self._display_encryption_status,
...@@ -156,6 +199,15 @@ class E2EEPlugin(BasePlugin): ...@@ -156,6 +199,15 @@ class E2EEPlugin(BasePlugin):
self._display_encryption_status, self._display_encryption_status,
) )
self.__load_encrypted_states()
def __load_encrypted_states(self) -> None:
"""Load previously stored encryption states for jids."""
for section in config.sections():
value = config.get('encryption', section=section)
if value and value == self.encryption_short_name:
self._enabled_tabs[section] = self.encrypt
def cleanup(self): def cleanup(self):
ConversationTab.remove_information_element(self.encryption_short_name) ConversationTab.remove_information_element(self.encryption_short_name)
MucTab.remove_information_element(self.encryption_short_name) MucTab.remove_information_element(self.encryption_short_name)
...@@ -181,25 +233,120 @@ class E2EEPlugin(BasePlugin): ...@@ -181,25 +233,120 @@ class E2EEPlugin(BasePlugin):
if self._encryption_enabled(jid): if self._encryption_enabled(jid):
del self._enabled_tabs[jid] del self._enabled_tabs[jid]
config.remove_and_save('encryption', section=jid)
self.api.information( self.api.information(
'{} encryption disabled for {}'.format(self.encryption_name, jid), '{} encryption disabled for {}'.format(self.encryption_name, jid),
'Info', 'Info',
) )
else: else:
self._enabled_tabs[jid] = self.encrypt self._enabled_tabs[jid] = self.encrypt
config.set_and_save('encryption', self.encryption_short_name, section=jid)
self.api.information( self.api.information(
'{} encryption enabled for {}'.format(self.encryption_name, jid), '{} encryption enabled for {}'.format(self.encryption_name, jid),
'Info', 'Info',
) )
def _show_fingerprints(self, jid: JID) -> None:
"""Display encryption fingerprints for a JID."""
fprs = self.get_fingerprints(jid)
if len(fprs) == 1:
self.api.information(
'Fingerprint for %s: %s' % (jid, fprs[0]),
'Info',
)
elif fprs:
self.api.information(
'Fingerprints for %s:\n\t%s' % (jid, '\n\t'.join(fprs)),
'Info',
)
else:
self.api.information(
'No fingerprints to display',
'Info',
)
@command_args_parser.quoted(0, 1)
def _command_show_fingerprints(self, args: List[str]) -> None:
if not args and isinstance(self.api.current_tab(), self.supported_tab_types):
jid = self.api.current_tab().jid
elif args:
jid = args[0]
else:
self.api.information(
'%s_fingerprint: Couldn\'t deduce JID from context' % (
self.encryption_short_name),
'Error',
)
return None
self._show_fingerprints(JID(jid))
@command_args_parser.quoted(2)
def __command_set_state_global(self, args, state='') -> None:
jid, fpr = args
if state not in self._all_trust_states:
self.api.information(
'Unknown state for plugin %s: %s' % (
self.encryption_short_name, state),
'Error'
)
return
self.store_trust(jid, state, fpr)
@command_args_parser.quoted(1)
def __command_set_state_local(self, args, state='') -> None:
if isinstance(self.api.current_tab(), MucTab):
self.api.information(
'You can only trust each participant of a MUC individually.',
'Info',
)
return
jid = self.api.current_tab().jid
if not args:
self.api.information(
'No fingerprint provided to the command..',
'Error',
)
return
fpr = args[0]
if state not in self._all_trust_states:
self.api.information(
'Unknown state for plugin %s: %s' % (
self.encryption_short_name, state),
'Error',
)
return
self.store_trust(jid, state, fpr)
def _encryption_enabled(self, jid: JID) -> bool: def _encryption_enabled(self, jid: JID) -> bool:
return jid in self._enabled_tabs and self._enabled_tabs[jid] == self.encrypt return self._enabled_tabs.get(jid) == self.encrypt
async def _encrypt_wrapper(self, stanza: StanzaBase) -> Optional[StanzaBase]:
"""
Wrapper around _encrypt() to handle errors and display the message after encryption.
"""
try:
# pylint: disable=unexpected-keyword-arg
result = await self._encrypt(stanza, passthrough=True)
except NothingToEncrypt:
return stanza
except Exception as exc:
jid = stanza['to']
tab = self.core.tabs.by_name_and_class(jid, ChatTab)
msg = ' \n\x19%s}Could not send message: %s' % (
dump_tuple(get_theme().COLOR_CHAR_NACK),
exc,
)
tab.nack_message(msg, stanza['id'], stanza['from'])
# TODO: display exceptions to the user properly
log.error('Exception in encrypt:', exc_info=True)
return None
return result
def _decrypt(self, message: Message, tab: ChatTabs) -> None: def _decrypt(self, message: Message, tab: ChatTabs) -> None:
has_eme = False has_eme = False
if message.xml.find('{%s}%s' % (EME_NS, EME_TAG)) is not None and \ if message.xml.find('{%s}%s' % (EME_NS, EME_TAG)) is not None and \
message['eme']['namespace'] == self.eme_ns: message['eme']['namespace'] == self.eme_ns:
has_eme = True has_eme = True
has_encrypted_tag = False has_encrypted_tag = False
...@@ -219,15 +366,15 @@ class E2EEPlugin(BasePlugin): ...@@ -219,15 +366,15 @@ class E2EEPlugin(BasePlugin):
log.debug('Decrypted %s message: %r', self.encryption_name, message['body']) log.debug('Decrypted %s message: %r', self.encryption_name, message['body'])
return None return None
def _encrypt(self, stanza: StanzaBase) -> Optional[StanzaBase]: async def _encrypt(self, stanza: StanzaBase) -> Optional[StanzaBase]:
if not isinstance(stanza, Message) or stanza['type'] not in ('chat', 'groupchat'): if not isinstance(stanza, Message) or stanza['type'] not in ('chat', 'groupchat'):
return stanza raise NothingToEncrypt()
message = stanza message = stanza
tab = self.api.current_tab() jid = stanza['to']
jid = tab.jid tab = self.core.tabs.by_name_and_class(jid, ChatTab)
if not self._encryption_enabled(jid): if not self._encryption_enabled(jid):
return message raise NothingToEncrypt()
log.debug('Sending %s message: %r', self.encryption_name, message) log.debug('Sending %s message: %r', self.encryption_name, message)
...@@ -245,7 +392,11 @@ class E2EEPlugin(BasePlugin): ...@@ -245,7 +392,11 @@ class E2EEPlugin(BasePlugin):
return None return None
# Call the enabled encrypt method # Call the enabled encrypt method
self._enabled_tabs[jid](message, tab) func = self._enabled_tabs[jid]
if iscoroutinefunction(func):
await func(message, tab, passthrough=True)
else:
func(message, tab, passthrough=True)
if has_body: if has_body:
# Only add EME tag if the message has a body. # Only add EME tag if the message has a body.
...@@ -267,14 +418,26 @@ class E2EEPlugin(BasePlugin): ...@@ -267,14 +418,26 @@ class E2EEPlugin(BasePlugin):
if self.encrypted_tags is not None: if self.encrypted_tags is not None:
whitelist += self.encrypted_tags whitelist += self.encrypted_tags
tag_whitelist = {'{%s}%s' % tag for tag in whitelist}
for elem in message.xml[:]: for elem in message.xml[:]:
if elem.tag not in whitelist: if elem.tag not in tag_whitelist:
message.xml.remove(elem) message.xml.remove(elem)
log.debug('Encrypted %s message: %r', self.encryption_name, message) log.debug('Encrypted %s message: %r', self.encryption_name, message)
return message return message
def decrypt(self, _message: Message, tab: ChatTabs): def store_trust(self, jid: JID, state: str, fingerprint: str) -> None:
"""Store trust for a fingerprint and a jid."""
option_name = '%s:%s' % (self.encryption_short_name, fingerprint)
config.silent_set(option=option_name, value=state, section=jid)
def fetch_trust(self, jid: JID, fingerprint: str) -> str:
"""Fetch trust of a fingerprint and a jid."""
option_name = '%s:%s' % (self.encryption_short_name, fingerprint)
return config.get(option=option_name, section=jid)
async def decrypt(self, _message: Message, tab: ChatTabs):
"""Decryption method """Decryption method
This is a method the plugin must implement. It is expected that this This is a method the plugin must implement. It is expected that this
...@@ -288,7 +451,7 @@ class E2EEPlugin(BasePlugin): ...@@ -288,7 +451,7 @@ class E2EEPlugin(BasePlugin):
raise NotImplementedError raise NotImplementedError
def encrypt(self, _message: Message, tab: ChatTabs): async def encrypt(self, _message: Message, tab: ChatTabs):
"""Encryption method """Encryption method
This is a method the plugin must implement. It is expected that this This is a method the plugin must implement. It is expected that this
...@@ -301,3 +464,12 @@ class E2EEPlugin(BasePlugin): ...@@ -301,3 +464,12 @@ class E2EEPlugin(BasePlugin):
""" """
raise NotImplementedError raise NotImplementedError
def get_fingerprints(self, jid: JID) -> List[str]:
"""Show fingerprint(s) for this encryption method and JID.
To overload in plugins.
:returns: A list of fingerprints to display
"""
return []
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment