summary refs log tree commit diff
path: root/synapse/events/snapshot.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/snapshot.py')
-rw-r--r--synapse/events/snapshot.py175
1 files changed, 134 insertions, 41 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 5e02ef1a5c..a6d7bf5700 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -19,18 +19,12 @@ from frozendict import frozendict
 
 from twisted.internet import defer
 
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
 
 class EventContext(object):
     """
     Attributes:
-        current_state_ids (dict[(str, str), str]):
-            The current state map including the current event.
-            (type, state_key) -> event_id
-
-        prev_state_ids (dict[(str, str), str]):
-            The current state map excluding the current event.
-            (type, state_key) -> event_id
-
         state_group (int|None): state group id, if the state has been stored
             as a state group. This is usually only None if e.g. the event is
             an outlier.
@@ -47,36 +41,74 @@ class EventContext(object):
 
         prev_state_events (?): XXX: is this ever set to anything other than
             the empty list?
+
+        _current_state_ids (dict[(str, str), str]|None):
+            The current state map including the current event. None if outlier
+            or we haven't fetched the state from DB yet.
+            (type, state_key) -> event_id
+
+        _prev_state_ids (dict[(str, str), str]|None):
+            The current state map excluding the current event. None if outlier
+            or we haven't fetched the state from DB yet.
+            (type, state_key) -> event_id
+
+        _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+            been calculated. None if we haven't started calculating yet
+
+        _event_type (str): The type of the event the context is associated with.
+            Only set when state has not been fetched yet.
+
+        _event_state_key (str|None): The state_key of the event the context is
+            associated with. Only set when state has not been fetched yet.
+
+        _prev_state_id (str|None): If the event associated with the context is
+            a state event, then `_prev_state_id` is the event_id of the state
+            that was replaced.
+            Only set when state has not been fetched yet.
     """
 
     __slots__ = [
-        "current_state_ids",
-        "prev_state_ids",
         "state_group",
         "rejected",
         "prev_group",
         "delta_ids",
         "prev_state_events",
         "app_service",
+        "_current_state_ids",
+        "_prev_state_ids",
+        "_prev_state_id",
+        "_event_type",
+        "_event_state_key",
+        "_fetching_state_deferred",
     ]
 
-    def __init__(self, state_group, current_state_ids, prev_state_ids,
-                 prev_group=None, delta_ids=None):
+    @staticmethod
+    def with_state(state_group, current_state_ids, prev_state_ids,
+                   prev_group=None, delta_ids=None):
+        context = EventContext()
+
         # The current state including the current event
-        self.current_state_ids = current_state_ids
+        context._current_state_ids = current_state_ids
         # The current state excluding the current event
-        self.prev_state_ids = prev_state_ids
-        self.state_group = state_group
+        context._prev_state_ids = prev_state_ids
+        context.state_group = state_group
+
+        context._prev_state_id = None
+        context._event_type = None
+        context._event_state_key = None
+        context._fetching_state_deferred = defer.succeed(None)
 
         # A previously persisted state group and a delta between that
         # and this state.
-        self.prev_group = prev_group
-        self.delta_ids = delta_ids
+        context.prev_group = prev_group
+        context.delta_ids = delta_ids
 
-        self.prev_state_events = []
+        context.prev_state_events = []
 
-        self.rejected = False
-        self.app_service = None
+        context.rejected = False
+        context.app_service = None
+
+        return context
 
     def serialize(self, event):
         """Converts self to a type that can be serialized as JSON, and then
@@ -123,30 +155,17 @@ class EventContext(object):
         Returns:
             EventContext
         """
+        context = EventContext()
+
         # We use the state_group and prev_state_id stuff to pull the
         # current_state_ids out of the DB and construct prev_state_ids.
-        prev_state_id = input["prev_state_id"]
-        event_type = input["event_type"]
-        event_state_key = input["event_state_key"]
-
-        state_group = input["state_group"]
+        context._prev_state_id = input["prev_state_id"]
+        context._event_type = input["event_type"]
+        context._event_state_key = input["event_state_key"]
 
-        current_state_ids = yield store.get_state_ids_for_group(
-            state_group,
-        )
-        if prev_state_id and event_state_key:
-            prev_state_ids = dict(current_state_ids)
-            prev_state_ids[(event_type, event_state_key)] = prev_state_id
-        else:
-            prev_state_ids = current_state_ids
-
-        context = EventContext(
-            state_group=state_group,
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
-            prev_group=input["prev_group"],
-            delta_ids=_decode_state_dict(input["delta_ids"]),
-        )
+        context.state_group = input["state_group"]
+        context.prev_group = input["prev_group"]
+        context.delta_ids = _decode_state_dict(input["delta_ids"])
 
         context.rejected = input["rejected"]
         context.prev_state_events = input["prev_state_events"]
@@ -157,6 +176,80 @@ class EventContext(object):
 
         defer.returnValue(context)
 
+    @defer.inlineCallbacks
+    def get_current_state_ids(self, store):
+        """Gets the current state IDs
+
+        Returns:
+            Deferred[dict[(str, str), str]|None]: Returns None if state_group
+            is None, which happens when the associated event is an outlier.
+        """
+
+        if not self._fetching_state_deferred:
+            self._fetching_state_deferred = run_in_background(
+                self._fill_out_state, store,
+            )
+
+        yield make_deferred_yieldable(self._fetching_state_deferred)
+
+        defer.returnValue(self._current_state_ids)
+
+    @defer.inlineCallbacks
+    def get_prev_state_ids(self, store):
+        """Gets the prev state IDs
+
+        Returns:
+            Deferred[dict[(str, str), str]|None]: Returns None if state_group
+            is None, which happens when the associated event is an outlier.
+        """
+
+        if not self._fetching_state_deferred:
+            self._fetching_state_deferred = run_in_background(
+                self._fill_out_state, store,
+            )
+
+        yield make_deferred_yieldable(self._fetching_state_deferred)
+
+        defer.returnValue(self._prev_state_ids)
+
+    @defer.inlineCallbacks
+    def _fill_out_state(self, store):
+        """Called to populate the _current_state_ids and _prev_state_ids
+        attributes by loading from the database.
+        """
+        if self.state_group is None:
+            return
+
+        self._current_state_ids = yield store.get_state_ids_for_group(
+            self.state_group,
+        )
+        if self._prev_state_id and self._event_state_key is not None:
+            self._prev_state_ids = dict(self._current_state_ids)
+
+            key = (self._event_type, self._event_state_key)
+            self._prev_state_ids[key] = self._prev_state_id
+        else:
+            self._prev_state_ids = self._current_state_ids
+
+    @defer.inlineCallbacks
+    def update_state(self, state_group, prev_state_ids, current_state_ids,
+                     delta_ids):
+        """Replace the state in the context
+        """
+
+        # We need to make sure we wait for any ongoing fetching of state
+        # to complete so that the updated state doesn't get clobbered
+        if self._fetching_state_deferred:
+            yield make_deferred_yieldable(self._fetching_state_deferred)
+
+        self.state_group = state_group
+        self._prev_state_ids = prev_state_ids
+        self._current_state_ids = current_state_ids
+        self.delta_ids = delta_ids
+
+        # We need to ensure that that we've marked as having fetched the state
+        self._fetching_state_deferred = defer.succeed(None)
+
 
 def _encode_state_dict(state_dict):
     """Since dicts of (type, state_key) -> event_id cannot be serialized in