Commit 695b2ee0 authored by mathieui's avatar mathieui Committed by Link Mauve

decorators: make decorators work with coroutines

Tried the least ugly solution I could thing of.
parent f5ad5199
...@@ -3,11 +3,13 @@ Module containing various decorators ...@@ -3,11 +3,13 @@ Module containing various decorators
""" """
from __future__ import annotations from __future__ import annotations
from asyncio import iscoroutinefunction
from typing import ( from typing import (
cast, cast,
Any, Any,
Callable, Callable,
Dict,
List, List,
Optional, Optional,
TypeVar, TypeVar,
...@@ -21,6 +23,37 @@ if TYPE_CHECKING: ...@@ -21,6 +23,37 @@ if TYPE_CHECKING:
T = TypeVar('T', bound=Callable[..., Any]) T = TypeVar('T', bound=Callable[..., Any])
BeforeFunc = Callable[[List[Any], Dict[str, Any]], Any]
AfterFunc = Callable[[List[Any], Dict[str, Any]], Any]
def wrap_generic(func: Callable, before: BeforeFunc=None, after: AfterFunc=None):
"""
Generic wrapper which can both wrap coroutines and normal functions.
"""
def wrap(*args, **kwargs):
args = list(args)
if before is not None:
result = before(args, kwargs)
if result is not None:
return result
result = func(*args, **kwargs)
if after is not None:
result = after(result, args, kwargs)
return result
async def awrap(*args, **kwargs):
args = list(args)
if before is not None:
result = before(args, kwargs)
if result is not None:
return result
result = await func(*args, **kwargs)
if after is not None:
result = after(result, args, kwargs)
return result
if iscoroutinefunction(func):
return awrap
return wrap
class RefreshWrapper: class RefreshWrapper:
...@@ -32,12 +65,12 @@ class RefreshWrapper: ...@@ -32,12 +65,12 @@ class RefreshWrapper:
Decorator to refresh the UI if the wrapped function Decorator to refresh the UI if the wrapped function
returns True returns True
""" """
def after(result: Any, args, kwargs) -> Any:
def wrap(*args: Any, **kwargs: Any) -> Any: if self.core and result:
ret = func(*args, **kwargs)
if self.core and ret:
self.core.refresh_window() self.core.refresh_window()
return ret return result
wrap = wrap_generic(func, after=after)
return cast(T, wrap) return cast(T, wrap)
...@@ -45,13 +78,12 @@ class RefreshWrapper: ...@@ -45,13 +78,12 @@ class RefreshWrapper:
""" """
Decorator that refreshs the UI no matter what after the function Decorator that refreshs the UI no matter what after the function
""" """
def after(result: Any, args, kwargs) -> Any:
def wrap(*args: Any, **kwargs: Any) -> Any:
ret = func(*args, **kwargs)
if self.core: if self.core:
self.core.refresh_window() self.core.refresh_window()
return ret return result
wrap = wrap_generic(func, after=after)
return cast(T, wrap) return cast(T, wrap)
def update(self, func: T) -> T: def update(self, func: T) -> T:
...@@ -59,12 +91,11 @@ class RefreshWrapper: ...@@ -59,12 +91,11 @@ class RefreshWrapper:
Decorator that only updates the screen Decorator that only updates the screen
""" """
def wrap(*args: Any, **kwargs: Any) -> Any: def after(result: Any, args, kwargs) -> Any:
ret = func(*args, **kwargs)
if self.core: if self.core:
self.core.doupdate() self.core.doupdate()
return ret return result
wrap = wrap_generic(func, after=after)
return cast(T, wrap) return cast(T, wrap)
...@@ -82,21 +113,18 @@ class CommandArgParser: ...@@ -82,21 +113,18 @@ class CommandArgParser:
"""Just call the function with a single string, which is the original string """Just call the function with a single string, which is the original string
untouched untouched
""" """
return func
def wrap(self: Any, args: Any, *a: Any, **kw: Any) -> Any:
return func(self, args, *a, **kw)
return cast(T, wrap)
@staticmethod @staticmethod
def ignored(func: T) -> T: def ignored(func: T) -> T:
""" """
Call the function without any argument Call the function without textual arguments
""" """
def before(args: List[Any], kwargs: Dict[Any, Any]) -> None:
if len(args) >= 2:
del args[1]
def wrap(self: Any, args: Any = None, *a: Any, **kw: Any) -> Any: wrap = wrap_generic(func, before=before)
return func(self, *a, **kw)
return cast(T, wrap) return cast(T, wrap)
@staticmethod @staticmethod
...@@ -149,14 +177,16 @@ class CommandArgParser: ...@@ -149,14 +177,16 @@ class CommandArgParser:
default_args_outer = defaults or [] default_args_outer = defaults or []
def first(func: T) -> T: def first(func: T) -> T:
def second(self: Any, args: str, *a: Any, **kw: Any) -> Any: def before(args: List, kwargs: Dict[str, Any]) -> Any:
default_args = default_args_outer default_args = default_args_outer
if args and args.strip(): cmdargs = args[1]
split_args = common.shell_split(args) if cmdargs and cmdargs.strip():
split_args = common.shell_split(cmdargs)
else: else:
split_args = [] split_args = []
if len(split_args) < mandatory: if len(split_args) < mandatory:
return func(self, None, *a, **kw) args[1] = None
return
res, split_args = split_args[:mandatory], split_args[ res, split_args = split_args[:mandatory], split_args[
mandatory:] mandatory:]
if optional == -1: if optional == -1:
...@@ -171,22 +201,25 @@ class CommandArgParser: ...@@ -171,22 +201,25 @@ class CommandArgParser:
res += default_args res += default_args
if split_args and res and not ignore_trailing_arguments: if split_args and res and not ignore_trailing_arguments:
res[-1] += " " + " ".join(split_args) res[-1] += " " + " ".join(split_args)
return func(self, res, *a, **kw) args[1] = res
return
return cast(T, second) wrap = wrap_generic(func, before=before)
return cast(T, wrap)
return first return first
command_args_parser = CommandArgParser() command_args_parser = CommandArgParser()
def deny_anonymous(func: Callable) -> Callable: def deny_anonymous(func: Callable) -> Callable:
"""Decorator to disable commands when using an anonymous account.""" """Decorator to disable commands when using an anonymous account."""
def wrap(self: RosterInfoTab, *args: Any, **kwargs: Any) -> Any:
if self.core.xmpp.anon: def before(args: Any, kwargs: Any) -> Any:
return self.core.information( core = args[0].core
if core.xmpp.anon:
core.information(
'This command is not available for anonymous accounts.', 'This command is not available for anonymous accounts.',
'Info' 'Info'
) )
return func(self, *args, **kwargs) return False
wrap = wrap_generic(func, before=before)
return cast(T, wrap) return cast(T, wrap)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment