summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7951.misc1
-rw-r--r--synapse/groups/attestations.py25
-rw-r--r--synapse/visibility.py30
-rw-r--r--tests/test_visibility.py12
4 files changed, 31 insertions, 37 deletions
diff --git a/changelog.d/7951.misc b/changelog.d/7951.misc
new file mode 100644
index 0000000000..cbba4fa826
--- /dev/null
+++ b/changelog.d/7951.misc
@@ -0,0 +1 @@
+Convert groups and visibility code to async / await.
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index dab13c243f..e674bf44a2 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -41,8 +41,6 @@ from typing import Tuple
 
 from signedjson.sign import sign_json
 
-from twisted.internet import defer
-
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id
@@ -72,8 +70,9 @@ class GroupAttestationSigning(object):
         self.server_name = hs.hostname
         self.signing_key = hs.signing_key
 
-    @defer.inlineCallbacks
-    def verify_attestation(self, attestation, group_id, user_id, server_name=None):
+    async def verify_attestation(
+        self, attestation, group_id, user_id, server_name=None
+    ):
         """Verifies that the given attestation matches the given parameters.
 
         An optional server_name can be supplied to explicitly set which server's
@@ -102,7 +101,7 @@ class GroupAttestationSigning(object):
         if valid_until_ms < now:
             raise SynapseError(400, "Attestation expired")
 
-        yield self.keyring.verify_json_for_server(
+        await self.keyring.verify_json_for_server(
             server_name, attestation, now, "Group attestation"
         )
 
@@ -142,8 +141,7 @@ class GroupAttestionRenewer(object):
                 self._start_renew_attestations, 30 * 60 * 1000
             )
 
-    @defer.inlineCallbacks
-    def on_renew_attestation(self, group_id, user_id, content):
+    async def on_renew_attestation(self, group_id, user_id, content):
         """When a remote updates an attestation
         """
         attestation = content["attestation"]
@@ -151,11 +149,11 @@ class GroupAttestionRenewer(object):
         if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
             raise SynapseError(400, "Neither user not group are on this server")
 
-        yield self.attestations.verify_attestation(
+        await self.attestations.verify_attestation(
             attestation, user_id=user_id, group_id=group_id
         )
 
-        yield self.store.update_remote_attestion(group_id, user_id, attestation)
+        await self.store.update_remote_attestion(group_id, user_id, attestation)
 
         return {}
 
@@ -172,8 +170,7 @@ class GroupAttestionRenewer(object):
             now + UPDATE_ATTESTATION_TIME_MS
         )
 
-        @defer.inlineCallbacks
-        def _renew_attestation(group_user: Tuple[str, str]):
+        async def _renew_attestation(group_user: Tuple[str, str]):
             group_id, user_id = group_user
             try:
                 if not self.is_mine_id(group_id):
@@ -186,16 +183,16 @@ class GroupAttestionRenewer(object):
                         user_id,
                         group_id,
                     )
-                    yield self.store.remove_attestation_renewal(group_id, user_id)
+                    await self.store.remove_attestation_renewal(group_id, user_id)
                     return
 
                 attestation = self.attestations.create_attestation(group_id, user_id)
 
-                yield self.transport_client.renew_group_attestation(
+                await self.transport_client.renew_group_attestation(
                     destination, group_id, user_id, content={"attestation": attestation}
                 )
 
-                yield self.store.update_attestation_renewal(
+                await self.store.update_attestation_renewal(
                     group_id, user_id, attestation
                 )
             except (RequestSendFailed, HttpResponseException) as e:
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 0f042c5696..e3da7744d2 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -16,8 +16,6 @@
 import logging
 import operator
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.utils import prune_event
 from synapse.storage import Storage
@@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = (
 )
 
 
-@defer.inlineCallbacks
-def filter_events_for_client(
+async def filter_events_for_client(
     storage: Storage,
     user_id,
     events,
@@ -67,19 +64,19 @@ def filter_events_for_client(
             also be called to check whether a user can see the state at a given point.
 
     Returns:
-        Deferred[list[synapse.events.EventBase]]
+        list[synapse.events.EventBase]
     """
     # Filter out events that have been soft failed so that we don't relay them
     # to clients.
     events = [e for e in events if not e.internal_metadata.is_soft_failed()]
 
     types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
-    event_id_to_state = yield storage.state.get_state_for_events(
+    event_id_to_state = await storage.state.get_state_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(types),
     )
 
-    ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
+    ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
         "m.ignored_user_list", user_id
     )
 
@@ -90,7 +87,7 @@ def filter_events_for_client(
         else []
     )
 
-    erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
+    erased_senders = await storage.main.are_users_erased((e.sender for e in events))
 
     if filter_send_to_client:
         room_ids = {e.room_id for e in events}
@@ -99,7 +96,7 @@ def filter_events_for_client(
         for room_id in room_ids:
             retention_policies[
                 room_id
-            ] = yield storage.main.get_retention_policy_for_room(room_id)
+            ] = await storage.main.get_retention_policy_for_room(room_id)
 
     def allowed(event):
         """
@@ -254,8 +251,7 @@ def filter_events_for_client(
     return list(filtered_events)
 
 
-@defer.inlineCallbacks
-def filter_events_for_server(
+async def filter_events_for_server(
     storage: Storage,
     server_name,
     events,
@@ -277,7 +273,7 @@ def filter_events_for_server(
             backfill or not.
 
     Returns
-        Deferred[list[FrozenEvent]]
+        list[FrozenEvent]
     """
 
     def is_sender_erased(event, erased_senders):
@@ -321,7 +317,7 @@ def filter_events_for_server(
     # Lets check to see if all the events have a history visibility
     # of "shared" or "world_readable". If that's the case then we don't
     # need to check membership (as we know the server is in the room).
-    event_to_state_ids = yield storage.state.get_state_ids_for_events(
+    event_to_state_ids = await storage.state.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(
             types=((EventTypes.RoomHistoryVisibility, ""),)
@@ -339,14 +335,14 @@ def filter_events_for_server(
     if not visibility_ids:
         all_open = True
     else:
-        event_map = yield storage.main.get_events(visibility_ids)
+        event_map = await storage.main.get_events(visibility_ids)
         all_open = all(
             e.content.get("history_visibility") in (None, "shared", "world_readable")
             for e in event_map.values()
         )
 
     if not check_history_visibility_only:
-        erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
+        erased_senders = await storage.main.are_users_erased((e.sender for e in events))
     else:
         # We don't want to check whether users are erased, which is equivalent
         # to no users having been erased.
@@ -375,7 +371,7 @@ def filter_events_for_server(
 
     # first, for each event we're wanting to return, get the event_ids
     # of the history vis and membership state at those events.
-    event_to_state_ids = yield storage.state.get_state_ids_for_events(
+    event_to_state_ids = await storage.state.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(
             types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@@ -405,7 +401,7 @@ def filter_events_for_server(
             return False
         return state_key[idx + 1 :] == server_name
 
-    event_map = yield storage.main.get_events(
+    event_map = await storage.main.get_events(
         [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
     )
 
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index b371efc0df..a7a36174ea 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             evt = yield self.inject_room_member(user, extra_content={"a": "b"})
             events_to_filter.append(evt)
 
-        filtered = yield filter_events_for_server(
-            self.storage, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(self.storage, "test_server", events_to_filter)
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
@@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
 
         # ... and the filtering happens.
-        filtered = yield filter_events_for_server(
-            self.storage, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(self.storage, "test_server", events_to_filter)
         )
 
         for i in range(0, len(events_to_filter)):
@@ -265,8 +265,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         storage.main = test_store
         storage.state = test_store
 
-        filtered = yield filter_events_for_server(
-            test_store, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(test_store, "test_server", events_to_filter)
         )
         logger.info("Filtering took %f seconds", time.time() - start)