diff --git a/changelog.d/7948.misc b/changelog.d/7948.misc
index 7c2e2b18b7..dfe4c03171 100644
--- a/changelog.d/7948.misc
+++ b/changelog.d/7948.misc
@@ -1 +1 @@
-Convert push to async/await.
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/7949.misc b/changelog.d/7949.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/7949.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/7951.misc b/changelog.d/7951.misc
index cbba4fa826..dfe4c03171 100644
--- a/changelog.d/7951.misc
+++ b/changelog.d/7951.misc
@@ -1 +1 @@
-Convert groups and visibility code to async / await.
+Convert various parts of the codebase to async/await.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b53e8451e5..2178e623da 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -82,7 +82,7 @@ class Auth(object):
@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 0bb216419a..69b53ca2bc 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -17,8 +17,6 @@ from typing import Optional
import attr
from nacl.signing import SigningKey
-from twisted.internet import defer
-
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
@@ -95,31 +93,30 @@ class EventBuilder(object):
def is_state(self):
return self._state_key is not None
- @defer.inlineCallbacks
- def build(self, prev_event_ids):
+ async def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event
Args:
prev_event_ids (list[str]): The event IDs to use as the prev events
Returns:
- Deferred[FrozenEvent]
+ FrozenEvent
"""
- state_ids = yield defer.ensureDeferred(
- self._state.get_current_state_ids(self.room_id, prev_event_ids)
+ state_ids = await self._state.get_current_state_ids(
+ self.room_id, prev_event_ids
)
- auth_ids = yield self._auth.compute_auth_events(self, state_ids)
+ auth_ids = await self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
- auth_events = yield self._store.add_event_hashes(auth_ids)
- prev_events = yield self._store.add_event_hashes(prev_event_ids)
+ auth_events = await self._store.add_event_hashes(auth_ids)
+ prev_events = await self._store.add_event_hashes(prev_event_ids)
else:
auth_events = auth_ids
prev_events = prev_event_ids
- old_depth = yield self._store.get_max_depth_of(prev_event_ids)
+ old_depth = await self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f94cdcbaba..cca93e3a46 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
import attr
from frozendict import frozendict
-from twisted.internet import defer
-
from synapse.appservice import ApplicationService
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap
+if TYPE_CHECKING:
+ from synapse.storage.data_stores.main import DataStore
+
@attr.s(slots=True)
class EventContext:
@@ -129,8 +131,7 @@ class EventContext:
delta_ids=delta_ids,
)
- @defer.inlineCallbacks
- def serialize(self, event, store):
+ async def serialize(self, event: EventBase, store: "DataStore") -> dict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -146,7 +147,7 @@ class EventContext:
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
- prev_state_ids = yield self.get_prev_state_ids()
+ prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
@@ -214,8 +215,7 @@ class EventContext:
return self._state_group
- @defer.inlineCallbacks
- def get_current_state_ids(self):
+ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@@ -224,32 +224,31 @@ class EventContext:
``rejected`` is set.
Returns:
- Deferred[dict[(str, str), str]|None]: Returns None if state_group
- is None, which happens when the associated event is an outlier.
+ Returns None if state_group is None, which happens when the associated
+ event is an outlier.
- Maps a (type, state_key) to the event ID of the state event matching
- this tuple.
+ Maps a (type, state_key) to the event ID of the state event matching
+ this tuple.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")
- yield self._ensure_fetched()
+ await self._ensure_fetched()
return self._current_state_ids
- @defer.inlineCallbacks
- def get_prev_state_ids(self):
+ async def get_prev_state_ids(self):
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
Returns:
- Deferred[dict[(str, str), str]|None]: Returns None if state_group
+ dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
- yield self._ensure_fetched()
+ await self._ensure_fetched()
return self._prev_state_ids
def get_cached_current_state_ids(self):
@@ -269,8 +268,8 @@ class EventContext:
return self._current_state_ids
- def _ensure_fetched(self):
- return defer.succeed(None)
+ async def _ensure_fetched(self):
+ return None
@attr.s(slots=True)
@@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
- def _ensure_fetched(self):
+ async def _ensure_fetched(self):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)
- return make_deferred_yieldable(self._fetching_state_deferred)
+ return await make_deferred_yieldable(self._fetching_state_deferred)
- @defer.inlineCallbacks
- def _fill_out_state(self):
+ async def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
- self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+ self._current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
)
if self._event_state_key is not None:
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 459132d388..2956a64234 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.types import Requester
class ThirdPartyEventRules(object):
@@ -39,76 +41,79 @@ class ThirdPartyEventRules(object):
config=config, http_client=hs.get_simple_http_client()
)
- @defer.inlineCallbacks
- def check_event_allowed(self, event, context):
+ async def check_event_allowed(
+ self, event: EventBase, context: EventContext
+ ) -> bool:
"""Check if a provided event should be allowed in the given context.
Args:
- event (synapse.events.EventBase): The event to be checked.
- context (synapse.events.snapshot.EventContext): The context of the event.
+ event: The event to be checked.
+ context: The context of the event.
Returns:
- defer.Deferred[bool]: True if the event should be allowed, False if not.
+ True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
return True
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database.
state_events = {}
for key, event_id in prev_state_ids.items():
- state_events[key] = yield self.store.get_event(event_id, allow_none=True)
+ state_events[key] = await self.store.get_event(event_id, allow_none=True)
- ret = yield self.third_party_rules.check_event_allowed(event, state_events)
+ ret = await self.third_party_rules.check_event_allowed(event, state_events)
return ret
- @defer.inlineCallbacks
- def on_create_room(self, requester, config, is_requester_admin):
+ async def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ) -> bool:
"""Intercept requests to create room to allow, deny or update the
request config.
Args:
- requester (Requester)
- config (dict): The creation config from the client.
- is_requester_admin (bool): If the requester is an admin
+ requester
+ config: The creation config from the client.
+ is_requester_admin: If the requester is an admin
Returns:
- defer.Deferred[bool]: Whether room creation is allowed or denied.
+ Whether room creation is allowed or denied.
"""
if self.third_party_rules is None:
return True
- ret = yield self.third_party_rules.on_create_room(
+ ret = await self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
return ret
- @defer.inlineCallbacks
- def check_threepid_can_be_invited(self, medium, address, room_id):
+ async def check_threepid_can_be_invited(
+ self, medium: str, address: str, room_id: str
+ ) -> bool:
"""Check if a provided 3PID can be invited in the given room.
Args:
- medium (str): The 3PID's medium.
- address (str): The 3PID's address.
- room_id (str): The room we want to invite the threepid to.
+ medium: The 3PID's medium.
+ address: The 3PID's address.
+ room_id: The room we want to invite the threepid to.
Returns:
- defer.Deferred[bool], True if the 3PID can be invited, False if not.
+ True if the 3PID can be invited, False if not.
"""
if self.third_party_rules is None:
return True
- state_ids = yield self.store.get_filtered_current_state_ids(room_id)
- room_state_events = yield self.store.get_events(state_ids.values())
+ state_ids = await self.store.get_filtered_current_state_ids(room_id)
+ room_state_events = await self.store.get_events(state_ids.values())
state_events = {}
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
- ret = yield self.third_party_rules.check_threepid_can_be_invited(
+ ret = await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
return ret
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 11f0d34ec8..2d42e268c6 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -18,8 +18,6 @@ from typing import Any, Mapping, Union
from frozendict import frozendict
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
@@ -337,8 +335,9 @@ class EventClientSerializer(object):
hs.config.experimental_msc1849_support_enabled
)
- @defer.inlineCallbacks
- def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
+ async def serialize_event(
+ self, event, time_now, bundle_aggregations=True, **kwargs
+ ):
"""Serializes a single event.
Args:
@@ -348,7 +347,7 @@ class EventClientSerializer(object):
**kwargs: Arguments to pass to `serialize_event`
Returns:
- Deferred[dict]: The serialized event
+ dict: The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
@@ -363,8 +362,8 @@ class EventClientSerializer(object):
if not event.internal_metadata.is_redacted() and (
self.experimental_msc1849_support_enabled and bundle_aggregations
):
- annotations = yield self.store.get_aggregation_groups_for_event(event_id)
- references = yield self.store.get_relations_for_event(
+ annotations = await self.store.get_aggregation_groups_for_event(event_id)
+ references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
)
@@ -378,7 +377,7 @@ class EventClientSerializer(object):
edit = None
if event.type == EventTypes.Message:
- edit = yield self.store.get_applicable_edit(event_id)
+ edit = await self.store.get_applicable_edit(event_id)
if edit:
# If there is an edit replace the content, preserving existing
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f5f683bfd4..0d7d1adcea 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2470,7 +2470,7 @@ class FederationHandler(BaseHandler):
}
current_state_ids = await context.get_current_state_ids()
- current_state_ids = dict(current_state_ids)
+ current_state_ids = dict(current_state_ids) # type: ignore
current_state_ids.update(state_updates)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index c287c4e269..ca065e819e 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""
event_payloads = []
for event, context in event_and_contexts:
- serialized_context = yield context.serialize(event, store)
+ serialized_context = yield defer.ensureDeferred(
+ context.serialize(event, store)
+ )
event_payloads.append(
{
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index c981723c1a..b30e4d5039 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
extra_users (list(UserID)): Any extra users to notify about event
"""
- serialized_context = yield context.serialize(event, store)
+ serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
payload = {
"event": event.get_pdu_json(),
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..0f0e1cd09b 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def build(self, prev_event_ids):
- built_event = yield self._base_builder.build(prev_event_ids)
+ built_event = yield defer.ensureDeferred(
+ self._base_builder.build(prev_event_ids)
+ )
built_event._event_id = self._event_id
built_event._dict["event_id"] = self._event_id
diff --git a/tests/test_state.py b/tests/test_state.py
index 4858e8fc59..b5c3667d2a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
- prev_state_ids = yield ctx_e.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase):
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
|