summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-10-11 15:26:09 +0100
committerErik Johnston <erik@matrix.org>2019-10-11 15:26:09 +0100
commit3c2d6c708cd93df7fc945e10014049e9f9b36f46 (patch)
treefe45c9c825d9db33ec4fd50f57f6cf1b0a02f6e1
parentNewsfile (diff)
downloadsynapse-3c2d6c708cd93df7fc945e10014049e9f9b36f46.tar.xz
Add maybe_awaitable and fix __init__ bugs
-rw-r--r--synapse/rest/admin/__init__.py7
-rw-r--r--synapse/util/async_helpers.py29
2 files changed, 34 insertions, 2 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index f7b9483008..939418ee2b 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -44,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
 from synapse.rest.admin.users import UserAdminServlet
 from synapse.types import UserID, create_requester
+from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger(__name__)
@@ -310,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet):
                 errcode=Codes.BAD_JSON,
             )
 
-        purge_id = await self.pagination_handler.start_purge_history(
+        purge_id = self.pagination_handler.start_purge_history(
             room_id, token, delete_local_events=delete_local_events
         )
 
@@ -480,7 +481,9 @@ class ShutdownRoomRestServlet(RestServlet):
             ratelimit=False,
         )
 
-        aliases_for_room = await self.store.get_aliases_for_room(room_id)
+        aliases_for_room = await maybe_awaitable(
+            self.store.get_aliases_for_room(room_id)
+        )
 
         await self.store.update_aliases_for_room(
             room_id, new_room_id, requester_user_id
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 0d3bdd88ce..804dbca443 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union
 
 from six.moves import range
 
+import attr
+
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.python import failure
@@ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     deferred.addCallbacks(success_cb, failure_cb)
 
     return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+    """Simple awaitable that returns the provided value.
+    """
+
+    value = attr.ib()
+
+    def __await__(self):
+        return self
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+    """Convert a value to an awaitable if not already an awaitable.
+    """
+
+    if hasattr(value, "__await__"):
+        return value
+
+    return DoneAwaitable(value)