summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7944.misc1
-rw-r--r--synapse/rest/client/v2_alpha/_base.py47
2 files changed, 22 insertions, 26 deletions
diff --git a/changelog.d/7944.misc b/changelog.d/7944.misc
new file mode 100644
index 0000000000..afbc91a494
--- /dev/null
+++ b/changelog.d/7944.misc
@@ -0,0 +1 @@
+Convert the interactive_auth_handler wrapper to async/await.
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index b21538766d..f016b4f1bd 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -17,8 +17,7 @@
 """
 import logging
 import re
-
-from twisted.internet import defer
+from typing import Iterable, Pattern
 
 from synapse.api.errors import InteractiveAuthIncompleteError
 from synapse.api.urls import CLIENT_API_PREFIX
@@ -27,15 +26,23 @@ from synapse.types import JsonDict
 logger = logging.getLogger(__name__)
 
 
-def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
+def client_patterns(
+    path_regex: str,
+    releases: Iterable[int] = (0,),
+    unstable: bool = True,
+    v1: bool = False,
+) -> Iterable[Pattern]:
     """Creates a regex compiled client path with the correct client path
     prefix.
 
     Args:
-        path_regex (str): The regex string to match. This should NOT have a ^
+        path_regex: The regex string to match. This should NOT have a ^
             as this will be prefixed.
+        releases: An iterable of releases to include this endpoint under.
+        unstable: If true, include this endpoint under the "unstable" prefix.
+        v1: If true, include this endpoint under the "api/v1" prefix.
     Returns:
-        SRE_Pattern
+        An iterable of patterns.
     """
     patterns = []
 
@@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
 def interactive_auth_handler(orig):
     """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
 
-    Takes a on_POST method which returns a deferred (errcode, body) response
+    Takes a on_POST method which returns an Awaitable (errcode, body) response
     and adds exception handling to turn a InteractiveAuthIncompleteError into
     a 401 response.
 
     Normal usage is:
 
     @interactive_auth_handler
-    @defer.inlineCallbacks
-    def on_POST(self, request):
+    async def on_POST(self, request):
         # ...
-        yield self.auth_handler.check_auth
-            """
+        await self.auth_handler.check_auth
+    """
 
-    def wrapped(*args, **kwargs):
-        res = defer.ensureDeferred(orig(*args, **kwargs))
-        res.addErrback(_catch_incomplete_interactive_auth)
-        return res
+    async def wrapped(*args, **kwargs):
+        try:
+            return await orig(*args, **kwargs)
+        except InteractiveAuthIncompleteError as e:
+            return 401, e.result
 
     return wrapped
-
-
-def _catch_incomplete_interactive_auth(f):
-    """helper for interactive_auth_handler
-
-    Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
-
-    Args:
-        f (failure.Failure):
-    """
-    f.trap(InteractiveAuthIncompleteError)
-    return 401, f.value.result