summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12689.misc1
-rw-r--r--synapse/events/snapshot.py177
-rw-r--r--synapse/handlers/federation.py6
-rw-r--r--synapse/handlers/federation_event.py6
-rw-r--r--synapse/handlers/message.py6
-rw-r--r--synapse/push/action_generator.py4
-rw-r--r--synapse/state/__init__.py9
-rw-r--r--synapse/storage/databases/main/events.py6
-rw-r--r--synapse/storage/persist_events.py42
-rw-r--r--tests/handlers/test_federation_event.py4
-rw-r--r--tests/storage/test_event_chain.py2
-rw-r--r--tests/test_state.py3
-rw-r--r--tests/test_visibility.py4
13 files changed, 70 insertions, 200 deletions
diff --git a/changelog.d/12689.misc b/changelog.d/12689.misc
new file mode 100644
index 0000000000..daa484ea30
--- /dev/null
+++ b/changelog.d/12689.misc
@@ -0,0 +1 @@
+Refactor `EventContext` class.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 8120c305df..9ccd24b298 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -17,11 +17,8 @@ import attr
 from frozendict import frozendict
 from typing_extensions import Literal
 
-from twisted.internet.defer import Deferred
-
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import JsonDict, StateMap
 
 if TYPE_CHECKING:
@@ -61,6 +58,9 @@ class EventContext:
             If ``state_group`` is None (ie, the event is an outlier),
             ``state_group_before_event`` will always also be ``None``.
 
+        state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
+            then this is the delta of the state between the two groups.
+
         prev_group: If it is known, ``state_group``'s prev_group. Note that this being
             None does not necessarily mean that ``state_group`` does not have
             a prev_group!
@@ -79,73 +79,47 @@ class EventContext:
         app_service: If this event is being sent by a (local) application service, that
             app service.
 
-        _current_state_ids: The room state map, including this event - ie, the state
-            in ``state_group``.
-
-            (type, state_key) -> event_id
-
-            For an outlier, this is {}
-
-            Note that this is a private attribute: it should be accessed via
-            ``get_current_state_ids``. _AsyncEventContext impl calculates this
-            on-demand: it will be None until that happens.
-
-        _prev_state_ids: The room state map, excluding this event - ie, the state
-            in ``state_group_before_event``. For a non-state
-            event, this will be the same as _current_state_events.
-
-            Note that it is a completely different thing to prev_group!
-
-            (type, state_key) -> event_id
-
-            For an outlier, this is {}
-
-            As with _current_state_ids, this is a private attribute. It should be
-            accessed via get_prev_state_ids.
-
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
 
+    _storage: "Storage"
     rejected: Union[Literal[False], str] = False
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
+    _state_delta_due_to_event: Optional[StateMap[str]] = None
     prev_group: Optional[int] = None
     delta_ids: Optional[StateMap[str]] = None
     app_service: Optional[ApplicationService] = None
 
-    _current_state_ids: Optional[StateMap[str]] = None
-    _prev_state_ids: Optional[StateMap[str]] = None
-
     partial_state: bool = False
 
     @staticmethod
     def with_state(
+        storage: "Storage",
         state_group: Optional[int],
         state_group_before_event: Optional[int],
-        current_state_ids: Optional[StateMap[str]],
-        prev_state_ids: Optional[StateMap[str]],
+        state_delta_due_to_event: Optional[StateMap[str]],
         partial_state: bool,
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ) -> "EventContext":
         return EventContext(
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
+            storage=storage,
             state_group=state_group,
             state_group_before_event=state_group_before_event,
+            state_delta_due_to_event=state_delta_due_to_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
             partial_state=partial_state,
         )
 
     @staticmethod
-    def for_outlier() -> "EventContext":
+    def for_outlier(
+        storage: "Storage",
+    ) -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
-        return EventContext(
-            current_state_ids={},
-            prev_state_ids={},
-        )
+        return EventContext(storage=storage)
 
     async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
@@ -158,24 +132,14 @@ class EventContext:
             The serialized event.
         """
 
-        # We don't serialize the full state dicts, instead they get pulled out
-        # of the DB on the other side. However, the other side can't figure out
-        # the prev_state_ids, so if we're a state event we include the event
-        # id that we replaced in the state.
-        if event.is_state():
-            prev_state_ids = await self.get_prev_state_ids()
-            prev_state_id = prev_state_ids.get((event.type, event.state_key))
-        else:
-            prev_state_id = None
-
         return {
-            "prev_state_id": prev_state_id,
-            "event_type": event.type,
-            "event_state_key": event.get_state_key(),
             "state_group": self._state_group,
             "state_group_before_event": self.state_group_before_event,
             "rejected": self.rejected,
             "prev_group": self.prev_group,
+            "state_delta_due_to_event": _encode_state_dict(
+                self._state_delta_due_to_event
+            ),
             "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
             "partial_state": self.partial_state,
@@ -193,16 +157,16 @@ class EventContext:
         Returns:
             The event context.
         """
-        context = _AsyncEventContextImpl(
+        context = EventContext(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
             storage=storage,
-            prev_state_id=input["prev_state_id"],
-            event_type=input["event_type"],
-            event_state_key=input["event_state_key"],
             state_group=input["state_group"],
             state_group_before_event=input["state_group_before_event"],
             prev_group=input["prev_group"],
+            state_delta_due_to_event=_decode_state_dict(
+                input["state_delta_due_to_event"]
+            ),
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
             partial_state=input.get("partial_state", False),
@@ -250,8 +214,15 @@ class EventContext:
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
-        await self._ensure_fetched()
-        return self._current_state_ids
+        assert self._state_delta_due_to_event is not None
+
+        prev_state_ids = await self.get_prev_state_ids()
+
+        if self._state_delta_due_to_event:
+            prev_state_ids = dict(prev_state_ids)
+            prev_state_ids.update(self._state_delta_due_to_event)
+
+        return prev_state_ids
 
     async def get_prev_state_ids(self) -> StateMap[str]:
         """
@@ -266,94 +237,10 @@ class EventContext:
             Maps a (type, state_key) to the event ID of the state event matching
             this tuple.
         """
-        await self._ensure_fetched()
-        # There *should* be previous state IDs now.
-        assert self._prev_state_ids is not None
-        return self._prev_state_ids
-
-    def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
-        """Gets the current state IDs if we have them already cached.
-
-        It is an error to access this for a rejected event, since rejected state should
-        not make it into the room state. This method will raise an exception if
-        ``rejected`` is set.
-
-        Returns:
-            Returns None if we haven't cached the state or if state_group is None
-            (which happens when the associated event is an outlier).
-
-            Otherwise, returns the the current state IDs.
-        """
-        if self.rejected:
-            raise RuntimeError("Attempt to access state_ids of rejected event")
-
-        return self._current_state_ids
-
-    async def _ensure_fetched(self) -> None:
-        return None
-
-
-@attr.s(slots=True)
-class _AsyncEventContextImpl(EventContext):
-    """
-    An implementation of EventContext which fetches _current_state_ids and
-    _prev_state_ids from the database on demand.
-
-    Attributes:
-
-        _storage
-
-        _fetching_state_deferred: Resolves when *_state_ids have been calculated.
-            None if we haven't started calculating yet
-
-        _event_type: The type of the event the context is associated with.
-
-        _event_state_key: The state_key of the event the context is associated with.
-
-        _prev_state_id: If the event associated with the context is a state event,
-            then `_prev_state_id` is the event_id of the state that was replaced.
-    """
-
-    # This needs to have a default as we're inheriting
-    _storage: "Storage" = attr.ib(default=None)
-    _prev_state_id: Optional[str] = attr.ib(default=None)
-    _event_type: str = attr.ib(default=None)
-    _event_state_key: Optional[str] = attr.ib(default=None)
-    _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
-
-    async def _ensure_fetched(self) -> None:
-        if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(self._fill_out_state)
-
-        await make_deferred_yieldable(self._fetching_state_deferred)
-
-    async def _fill_out_state(self) -> None:
-        """Called to populate the _current_state_ids and _prev_state_ids
-        attributes by loading from the database.
-        """
-        if self.state_group is None:
-            # No state group means the event is an outlier. Usually the state_ids dicts are also
-            # pre-set to empty dicts, but they get reset when the context is serialized, so set
-            # them to empty dicts again here.
-            self._current_state_ids = {}
-            self._prev_state_ids = {}
-            return
-
-        current_state_ids = await self._storage.state.get_state_ids_for_group(
-            self.state_group
+        assert self.state_group_before_event is not None
+        return await self._storage.state.get_state_ids_for_group(
+            self.state_group_before_event
         )
-        # Set this separately so mypy knows current_state_ids is not None.
-        self._current_state_ids = current_state_ids
-        if self._event_state_key is not None:
-            self._prev_state_ids = dict(current_state_ids)
-
-            key = (self._event_type, self._event_state_key)
-            if self._prev_state_id:
-                self._prev_state_ids[key] = self._prev_state_id
-            else:
-                self._prev_state_ids.pop(key, None)
-        else:
-            self._prev_state_ids = current_state_ids
 
 
 def _encode_state_dict(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38dc5b1f6e..be5099b507 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -659,7 +659,7 @@ class FederationHandler:
         # in the invitee's sync stream. It is stripped out for all other local users.
         event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -848,7 +848,7 @@ class FederationHandler:
             )
         )
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -877,7 +877,7 @@ class FederationHandler:
 
         await self.federation_client.send_leave(host_list, event)
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 6cf927e4ff..6d11b32b61 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1423,7 +1423,7 @@ class FederationEventHandler:
                 # we're not bothering about room state, so flag the event as an outlier.
                 event.internal_metadata.outlier = True
 
-                context = EventContext.for_outlier()
+                context = EventContext.for_outlier(self._storage)
                 try:
                     validate_event_for_room_version(room_version_obj, event)
                     check_auth_rules_for_event(room_version_obj, event, auth)
@@ -1874,10 +1874,10 @@ class FederationEventHandler:
         )
 
         return EventContext.with_state(
+            storage=self._storage,
             state_group=state_group,
             state_group_before_event=context.state_group_before_event,
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
+            state_delta_due_to_event=state_updates,
             prev_group=prev_group,
             delta_ids=state_updates,
             partial_state=context.partial_state,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c28b792e6f..e47799e7f9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -757,6 +757,10 @@ class EventCreationHandler:
             The previous version of the event is returned, if it is found in the
             event context. Otherwise, None is returned.
         """
+        if event.internal_metadata.is_outlier():
+            # This can happen due to out of band memberships
+            return None
+
         prev_state_ids = await context.get_prev_state_ids()
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
         if not prev_event_id:
@@ -1001,7 +1005,7 @@ class EventCreationHandler:
         # after it is created
         if builder.internal_metadata.outlier:
             event.internal_metadata.outlier = True
-            context = EventContext.for_outlier()
+            context = EventContext.for_outlier(self.storage)
         elif (
             event.type == EventTypes.MSC2716_INSERTION
             and state_event_ids
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 60758df016..730d9cd354 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,5 +40,9 @@ class ActionGenerator:
     async def handle_push_actions_for_event(
         self, event: EventBase, context: EventContext
     ) -> None:
+        if event.internal_metadata.is_outlier():
+            # This can happen due to out of band memberships
+            return
+
         with Measure(self.clock, "action_for_event_by_user"):
             await self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index cad3b42640..54e41d5375 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -130,6 +130,7 @@ class StateHandler:
         self.state_store = hs.get_storage().state
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._storage = hs.get_storage()
 
     @overload
     async def get_current_state(
@@ -361,10 +362,10 @@ class StateHandler:
 
         if not event.is_state():
             return EventContext.with_state(
+                storage=self._storage,
                 state_group_before_event=state_group_before_event,
                 state_group=state_group_before_event,
-                current_state_ids=state_ids_before_event,
-                prev_state_ids=state_ids_before_event,
+                state_delta_due_to_event={},
                 prev_group=state_group_before_event_prev_group,
                 delta_ids=deltas_to_state_group_before_event,
                 partial_state=partial_state,
@@ -393,10 +394,10 @@ class StateHandler:
         )
 
         return EventContext.with_state(
+            storage=self._storage,
             state_group=state_group_after_event,
             state_group_before_event=state_group_before_event,
-            current_state_ids=state_ids_after_event,
-            prev_state_ids=state_ids_before_event,
+            state_delta_due_to_event=delta_ids,
             prev_group=state_group_before_event,
             delta_ids=delta_ids,
             partial_state=partial_state,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 6c12653bb3..f544bcfff0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -128,7 +128,6 @@ class PersistEventsStore:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremities: Dict[str, Set[str]],
         use_negative_stream_ordering: bool = False,
@@ -139,8 +138,6 @@ class PersistEventsStore:
 
         Args:
             events_and_contexts:
-            current_state_for_room: Map from room_id to the current state of
-                the room based on forward extremities
             state_delta_for_room: Map from room_id to the delta to apply to
                 room state
             new_forward_extremities: Map from room_id to set of event IDs
@@ -215,9 +212,6 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, new_state in current_state_for_room.items():
-                self.store.get_current_state_ids.prefill((room_id,), new_state)
-
             for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 97118045a1..a7f6338e05 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -487,12 +487,6 @@ class EventsPersistenceStorage:
             # extremities in each room
             new_forward_extremities: Dict[str, Set[str]] = {}
 
-            # map room_id->(type,state_key)->event_id tracking the full
-            # state in each room after adding these events.
-            # This is simply used to prefill the get_current_state_ids
-            # cache
-            current_state_for_room: Dict[str, StateMap[str]] = {}
-
             # map room_id->(to_delete, to_insert) where to_delete is a list
             # of type/state keys to remove from current state, and to_insert
             # is a map (type,key)->event_id giving the state delta in each
@@ -628,14 +622,8 @@ class EventsPersistenceStorage:
 
                             state_delta_for_room[room_id] = delta
 
-                        # If we have the current_state then lets prefill
-                        # the cache with it.
-                        if current_state is not None:
-                            current_state_for_room[room_id] = current_state
-
             await self.persist_events_store._persist_events_and_state_updates(
                 chunk,
-                current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
@@ -733,7 +721,8 @@ class EventsPersistenceStorage:
 
             The first state map is the full new current state and the second
             is the delta to the existing current state. If both are None then
-            there has been no change.
+            there has been no change. Either or neither can be None if there
+            has been a change.
 
             The function may prune some old entries from the set of new
             forward extremities if it's safe to do so.
@@ -743,9 +732,6 @@ class EventsPersistenceStorage:
             the new current state is only returned if we've already calculated
             it.
         """
-        # map from state_group to ((type, key) -> event_id) state map
-        state_groups_map = {}
-
         # Map from (prev state group, new state group) -> delta state dict
         state_group_deltas = {}
 
@@ -759,16 +745,6 @@ class EventsPersistenceStorage:
                     )
                 continue
 
-            if ctx.state_group in state_groups_map:
-                continue
-
-            # We're only interested in pulling out state that has already
-            # been cached in the context. We'll pull stuff out of the DB later
-            # if necessary.
-            current_state_ids = ctx.get_cached_current_state_ids()
-            if current_state_ids is not None:
-                state_groups_map[ctx.state_group] = current_state_ids
-
             if ctx.prev_group:
                 state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
 
@@ -826,18 +802,14 @@ class EventsPersistenceStorage:
             delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
             if delta_ids is not None:
                 # We have a delta from the existing to new current state,
-                # so lets just return that. If we happen to already have
-                # the current state in memory then lets also return that,
-                # but it doesn't matter if we don't.
-                new_state = state_groups_map.get(new_state_group)
-                return new_state, delta_ids, new_latest_event_ids
+                # so lets just return that.
+                return None, delta_ids, new_latest_event_ids
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
-        missing_state = new_state_groups - set(state_groups_map)
-        if missing_state:
-            group_to_state = await self.state_store._get_state_for_groups(missing_state)
-            state_groups_map.update(group_to_state)
+        state_groups_map = await self.state_store._get_state_for_groups(
+            new_state_groups
+        )
 
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 489ba57736..e64b28f28b 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -148,7 +148,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
             prev_event.internal_metadata.outlier = True
             persistence = self.hs.get_storage().persistence
             self.get_success(
-                persistence.persist_event(prev_event, EventContext.for_outlier())
+                persistence.persist_event(
+                    prev_event, EventContext.for_outlier(self.hs.get_storage())
+                )
             )
         else:
 
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 401020fd63..c7661e7186 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -393,7 +393,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
             # We need to persist the events to the events and state_events
             # tables.
             persist_events_store._store_event_txn(
-                txn, [(e, EventContext()) for e in events]
+                txn, [(e, EventContext(self.hs.get_storage())) for e in events]
             )
 
             # Actually call the function that calculates the auth chain stuff.
diff --git a/tests/test_state.py b/tests/test_state.py
index e4baa69137..651ec1c7d4 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -88,6 +88,9 @@ class _DummyStore:
 
         return groups
 
+    async def get_state_ids_for_group(self, state_group):
+        return self._group_to_state[state_group]
+
     async def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
     ):
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0230f9ebb..7a9b01ef9d 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -234,7 +234,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
         event.internal_metadata.outlier = True
         self.get_success(
-            self.storage.persistence.persist_event(event, EventContext.for_outlier())
+            self.storage.persistence.persist_event(
+                event, EventContext.for_outlier(self.storage)
+            )
         )
         return event