summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6298.misc1
-rw-r--r--synapse/events/snapshot.py107
-rw-r--r--synapse/handlers/federation.py38
-rw-r--r--tests/test_federation.py4
4 files changed, 65 insertions, 85 deletions
diff --git a/changelog.d/6298.misc b/changelog.d/6298.misc
new file mode 100644
index 0000000000..d4190730b2
--- /dev/null
+++ b/changelog.d/6298.misc
@@ -0,0 +1 @@
+Refactor EventContext for clarity.
\ No newline at end of file
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 27cd8a63ff..a269de5482 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -37,9 +37,6 @@ class EventContext:
         delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
             (type, state_key) -> event_id. ``None`` for an outlier.
 
-        prev_state_events (?): XXX: is this ever set to anything other than
-            the empty list?
-
         app_service: FIXME
 
         _current_state_ids (dict[(str, str), str]|None):
@@ -51,36 +48,16 @@ class EventContext:
             The current state map excluding the current event. None if outlier
             or we haven't fetched the state from DB yet.
             (type, state_key) -> event_id
-
-        _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
-            been calculated. None if we haven't started calculating yet
-
-        _event_type (str): The type of the event the context is associated with.
-            Only set when state has not been fetched yet.
-
-        _event_state_key (str|None): The state_key of the event the context is
-            associated with. Only set when state has not been fetched yet.
-
-        _prev_state_id (str|None): If the event associated with the context is
-            a state event, then `_prev_state_id` is the event_id of the state
-            that was replaced.
-            Only set when state has not been fetched yet.
     """
 
     state_group = attr.ib(default=None)
     rejected = attr.ib(default=False)
     prev_group = attr.ib(default=None)
     delta_ids = attr.ib(default=None)
-    prev_state_events = attr.ib(default=attr.Factory(list))
     app_service = attr.ib(default=None)
 
-    _current_state_ids = attr.ib(default=None)
     _prev_state_ids = attr.ib(default=None)
-    _prev_state_id = attr.ib(default=None)
-
-    _event_type = attr.ib(default=None)
-    _event_state_key = attr.ib(default=None)
-    _fetching_state_deferred = attr.ib(default=None)
+    _current_state_ids = attr.ib(default=None)
 
     @staticmethod
     def with_state(
@@ -90,7 +67,6 @@ class EventContext:
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
             state_group=state_group,
-            fetching_state_deferred=defer.succeed(None),
             prev_group=prev_group,
             delta_ids=delta_ids,
         )
@@ -125,7 +101,6 @@ class EventContext:
             "rejected": self.rejected,
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
-            "prev_state_events": self.prev_state_events,
             "app_service_id": self.app_service.id if self.app_service else None,
         }
 
@@ -141,7 +116,7 @@ class EventContext:
         Returns:
             EventContext
         """
-        context = EventContext(
+        context = _AsyncEventContextImpl(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
             prev_state_id=input["prev_state_id"],
@@ -151,7 +126,6 @@ class EventContext:
             prev_group=input["prev_group"],
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
-            prev_state_events=input["prev_state_events"],
         )
 
         app_service_id = input["app_service_id"]
@@ -170,14 +144,7 @@ class EventContext:
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
-
-        if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store
-            )
-
-        yield make_deferred_yieldable(self._fetching_state_deferred)
-
+        yield self._ensure_fetched(store)
         return self._current_state_ids
 
     @defer.inlineCallbacks
@@ -190,14 +157,7 @@ class EventContext:
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
-
-        if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store
-            )
-
-        yield make_deferred_yieldable(self._fetching_state_deferred)
-
+        yield self._ensure_fetched(store)
         return self._prev_state_ids
 
     def get_cached_current_state_ids(self):
@@ -211,6 +171,44 @@ class EventContext:
 
         return self._current_state_ids
 
+    def _ensure_fetched(self, store):
+        return defer.succeed(None)
+
+
+@attr.s(slots=True)
+class _AsyncEventContextImpl(EventContext):
+    """
+    An implementation of EventContext which fetches _current_state_ids and
+    _prev_state_ids from the database on demand.
+
+    Attributes:
+
+        _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+            been calculated. None if we haven't started calculating yet
+
+        _event_type (str): The type of the event the context is associated with.
+
+        _event_state_key (str): The state_key of the event the context is
+            associated with.
+
+        _prev_state_id (str|None): If the event associated with the context is
+            a state event, then `_prev_state_id` is the event_id of the state
+            that was replaced.
+    """
+
+    _prev_state_id = attr.ib(default=None)
+    _event_type = attr.ib(default=None)
+    _event_state_key = attr.ib(default=None)
+    _fetching_state_deferred = attr.ib(default=None)
+
+    def _ensure_fetched(self, store):
+        if not self._fetching_state_deferred:
+            self._fetching_state_deferred = run_in_background(
+                self._fill_out_state, store
+            )
+
+        return make_deferred_yieldable(self._fetching_state_deferred)
+
     @defer.inlineCallbacks
     def _fill_out_state(self, store):
         """Called to populate the _current_state_ids and _prev_state_ids
@@ -228,27 +226,6 @@ class EventContext:
         else:
             self._prev_state_ids = self._current_state_ids
 
-    @defer.inlineCallbacks
-    def update_state(
-        self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
-    ):
-        """Replace the state in the context
-        """
-
-        # We need to make sure we wait for any ongoing fetching of state
-        # to complete so that the updated state doesn't get clobbered
-        if self._fetching_state_deferred:
-            yield make_deferred_yieldable(self._fetching_state_deferred)
-
-        self.state_group = state_group
-        self._prev_state_ids = prev_state_ids
-        self.prev_group = prev_group
-        self._current_state_ids = current_state_ids
-        self.delta_ids = delta_ids
-
-        # We need to ensure that that we've marked as having fetched the state
-        self._fetching_state_deferred = defer.succeed(None)
-
 
 def _encode_state_dict(state_dict):
     """Since dicts of (type, state_key) -> event_id cannot be serialized in
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index dab6be9573..8cafcfdab0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import auth_types_for_event
+from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.logging.context import (
     make_deferred_yieldable,
@@ -1871,14 +1872,7 @@ class FederationHandler(BaseHandler):
                 if c and c.type == EventTypes.Create:
                     auth_events[(c.type, c.state_key)] = c
 
-        try:
-            yield self.do_auth(origin, event, context, auth_events=auth_events)
-        except AuthError as e:
-            logger.warning(
-                "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
-            )
-
-            context.rejected = RejectedReason.AUTH_ERROR
+        context = yield self.do_auth(origin, event, context, auth_events=auth_events)
 
         if not context.rejected:
             yield self._check_for_soft_fail(event, state, backfilled)
@@ -2047,12 +2041,12 @@ class FederationHandler(BaseHandler):
 
                 Also NB that this function adds entries to it.
         Returns:
-            defer.Deferred[None]
+            defer.Deferred[EventContext]: updated context object
         """
         room_version = yield self.store.get_room_version(event.room_id)
 
         try:
-            yield self._update_auth_events_and_context_for_auth(
+            context = yield self._update_auth_events_and_context_for_auth(
                 origin, event, context, auth_events
             )
         except Exception:
@@ -2070,7 +2064,9 @@ class FederationHandler(BaseHandler):
             event_auth.check(room_version, event, auth_events=auth_events)
         except AuthError as e:
             logger.warning("Failed auth resolution for %r because %s", event, e)
-            raise e
+            context.rejected = RejectedReason.AUTH_ERROR
+
+        return context
 
     @defer.inlineCallbacks
     def _update_auth_events_and_context_for_auth(
@@ -2094,7 +2090,7 @@ class FederationHandler(BaseHandler):
             auth_events (dict[(str, str)->synapse.events.EventBase]):
 
         Returns:
-            defer.Deferred[None]
+            defer.Deferred[EventContext]: updated context
         """
         event_auth_events = set(event.auth_event_ids())
 
@@ -2133,7 +2129,7 @@ class FederationHandler(BaseHandler):
                     # The other side isn't around or doesn't implement the
                     # endpoint, so lets just bail out.
                     logger.info("Failed to get event auth from remote: %s", e)
-                    return
+                    return context
 
                 seen_remotes = yield self.store.have_seen_events(
                     [e.event_id for e in remote_auth_chain]
@@ -2174,7 +2170,7 @@ class FederationHandler(BaseHandler):
 
         if event.internal_metadata.is_outlier():
             logger.info("Skipping auth_event fetch for outlier")
-            return
+            return context
 
         # FIXME: Assumes we have and stored all the state for all the
         # prev_events
@@ -2183,7 +2179,7 @@ class FederationHandler(BaseHandler):
         )
 
         if not different_auth:
-            return
+            return context
 
         logger.info(
             "auth_events refers to events which are not in our calculated auth "
@@ -2230,10 +2226,12 @@ class FederationHandler(BaseHandler):
 
             auth_events.update(new_state)
 
-            yield self._update_context_for_auth_events(
+            context = yield self._update_context_for_auth_events(
                 event, context, auth_events, event_key
             )
 
+        return context
+
     @defer.inlineCallbacks
     def _update_context_for_auth_events(self, event, context, auth_events, event_key):
         """Update the state_ids in an event context after auth event resolution,
@@ -2242,14 +2240,16 @@ class FederationHandler(BaseHandler):
         Args:
             event (Event): The event we're handling the context for
 
-            context (synapse.events.snapshot.EventContext): event context
-                to be updated
+            context (synapse.events.snapshot.EventContext): initial event context
 
             auth_events (dict[(str, str)->str]): Events to update in the event
                 context.
 
             event_key ((str, str)): (type, state_key) for the current event.
                 this will not be included in the current_state in the context.
+
+        Returns:
+            Deferred[EventContext]: new event context
         """
         state_updates = {
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
@@ -2274,7 +2274,7 @@ class FederationHandler(BaseHandler):
             current_state_ids=current_state_ids,
         )
 
-        yield context.update_state(
+        return EventContext.with_state(
             state_group=state_group,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
diff --git a/tests/test_federation.py b/tests/test_federation.py
index d1acb16f30..7d82b58466 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         self.handler = self.homeserver.get_handlers().federation_handler
-        self.handler.do_auth = lambda *a, **b: succeed(True)
+        self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+            context
+        )
         self.client = self.homeserver.get_federation_client()
         self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
             pdus