diff options
author | Erik Johnston <erik@matrix.org> | 2019-10-11 15:26:09 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2019-10-11 15:26:09 +0100 |
commit | 3c2d6c708cd93df7fc945e10014049e9f9b36f46 (patch) | |
tree | fe45c9c825d9db33ec4fd50f57f6cf1b0a02f6e1 | |
parent | Newsfile (diff) | |
download | synapse-3c2d6c708cd93df7fc945e10014049e9f9b36f46.tar.xz |
Add maybe_awaitable and fix __init__ bugs
-rw-r--r-- | synapse/rest/admin/__init__.py | 7 | ||||
-rw-r--r-- | synapse/util/async_helpers.py | 29 |
2 files changed, 34 insertions, 2 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index f7b9483008..939418ee2b 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import UserAdminServlet from synapse.types import UserID, create_requester +from synapse.util.async_helpers import maybe_awaitable from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) @@ -310,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet): errcode=Codes.BAD_JSON, ) - purge_id = await self.pagination_handler.start_purge_history( + purge_id = self.pagination_handler.start_purge_history( room_id, token, delete_local_events=delete_local_events ) @@ -480,7 +481,9 @@ class ShutdownRoomRestServlet(RestServlet): ratelimit=False, ) - aliases_for_room = await self.store.get_aliases_for_room(room_id) + aliases_for_room = await maybe_awaitable( + self.store.get_aliases_for_room(room_id) + ) await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 0d3bdd88ce..804dbca443 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union from six.moves import range +import attr + from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.python import failure @@ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): deferred.addCallbacks(success_cb, failure_cb) return new_d + + +@attr.s(slots=True, frozen=True) +class DoneAwaitable(object): + """Simple awaitable that returns the provided value. + """ + + value = attr.ib() + + def __await__(self): + return self + + def __iter__(self): + return self + + def __next__(self): + raise StopIteration(self.value) + + +def maybe_awaitable(value): + """Convert a value to an awaitable if not already an awaitable. + """ + + if hasattr(value, "__await__"): + return value + + return DoneAwaitable(value) |