diff options
Diffstat (limited to 'synapse/rest/client/v2_alpha/_base.py')
-rw-r--r-- | synapse/rest/client/v2_alpha/_base.py | 58 |
1 files changed, 31 insertions, 27 deletions
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index bc11b4dda4..f016b4f1bd 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -17,24 +17,32 @@ """ import logging import re - -from twisted.internet import defer +from typing import Iterable, Pattern from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.types import JsonDict logger = logging.getLogger(__name__) -def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): +def client_patterns( + path_regex: str, + releases: Iterable[int] = (0,), + unstable: bool = True, + v1: bool = False, +) -> Iterable[Pattern]: """Creates a regex compiled client path with the correct client path prefix. Args: - path_regex (str): The regex string to match. This should NOT have a ^ + path_regex: The regex string to match. This should NOT have a ^ as this will be prefixed. + releases: An iterable of releases to include this endpoint under. + unstable: If true, include this endpoint under the "unstable" prefix. + v1: If true, include this endpoint under the "api/v1" prefix. Returns: - SRE_Pattern + An iterable of patterns. """ patterns = [] @@ -51,7 +59,15 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): return patterns -def set_timeline_upper_limit(filter_json, filter_timeline_limit): +def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: + """ + Enforces a maximum limit of a timeline query. + + Params: + filter_json: The timeline query to modify. + filter_timeline_limit: The maximum limit to allow, passing -1 will + disable enforcing a maximum limit. + """ if filter_timeline_limit < 0: return # no upper limits timeline = filter_json.get("room", {}).get("timeline", {}) @@ -64,34 +80,22 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit): def interactive_auth_handler(orig): """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors - Takes a on_POST method which returns a deferred (errcode, body) response + Takes a on_POST method which returns an Awaitable (errcode, body) response and adds exception handling to turn a InteractiveAuthIncompleteError into a 401 response. Normal usage is: @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): # ... - yield self.auth_handler.check_auth - """ + await self.auth_handler.check_auth + """ - def wrapped(*args, **kwargs): - res = defer.ensureDeferred(orig(*args, **kwargs)) - res.addErrback(_catch_incomplete_interactive_auth) - return res + async def wrapped(*args, **kwargs): + try: + return await orig(*args, **kwargs) + except InteractiveAuthIncompleteError as e: + return 401, e.result return wrapped - - -def _catch_incomplete_interactive_auth(f): - """helper for interactive_auth_handler - - Catches InteractiveAuthIncompleteErrors and turns them into 401 responses - - Args: - f (failure.Failure): - """ - f.trap(InteractiveAuthIncompleteError) - return 401, f.value.result |