diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 5e02ef1a5c..f9568638a1 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -19,18 +19,12 @@ from frozendict import frozendict
from twisted.internet import defer
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
class EventContext(object):
"""
Attributes:
- current_state_ids (dict[(str, str), str]):
- The current state map including the current event.
- (type, state_key) -> event_id
-
- prev_state_ids (dict[(str, str), str]):
- The current state map excluding the current event.
- (type, state_key) -> event_id
-
state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is
an outlier.
@@ -47,36 +41,71 @@ class EventContext(object):
prev_state_events (?): XXX: is this ever set to anything other than
the empty list?
+
+ _current_state_ids (dict[(str, str), str]|None):
+ The current state map including the current event. None if outlier
+ or we haven't fetched the state from DB yet.
+ (type, state_key) -> event_id
+
+ _prev_state_ids (dict[(str, str), str]|None):
+ The current state map excluding the current event. None if outlier
+ or we haven't fetched the state from DB yet.
+ (type, state_key) -> event_id
+
+ _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+ been calculated. None if we haven't started calculating yet
+
+ _prev_state_id (str|None): If set then the event associated with the
+ context overrode the _prev_state_id
+
+ _event_type (str): The type of the event the context is associated with
+
+ _event_state_key (str|None): The state_key of the event the context is
+ associated with
"""
__slots__ = [
- "current_state_ids",
- "prev_state_ids",
"state_group",
"rejected",
"prev_group",
"delta_ids",
"prev_state_events",
"app_service",
+ "_current_state_ids",
+ "_prev_state_ids",
+ "_prev_state_id",
+ "_event_type",
+ "_event_state_key",
+ "_fetching_state_deferred",
]
- def __init__(self, state_group, current_state_ids, prev_state_ids,
- prev_group=None, delta_ids=None):
+ @staticmethod
+ def with_state(state_group, current_state_ids, prev_state_ids,
+ prev_group=None, delta_ids=None):
+ context = EventContext()
+
# The current state including the current event
- self.current_state_ids = current_state_ids
+ context._current_state_ids = current_state_ids
# The current state excluding the current event
- self.prev_state_ids = prev_state_ids
- self.state_group = state_group
+ context._prev_state_ids = prev_state_ids
+ context.state_group = state_group
+
+ context._prev_state_id = None
+ context._event_type = None
+ context._event_state_key = None
+ context._fetching_state_deferred = defer.succeed(None)
# A previously persisted state group and a delta between that
# and this state.
- self.prev_group = prev_group
- self.delta_ids = delta_ids
+ context.prev_group = prev_group
+ context.delta_ids = delta_ids
+
+ context.prev_state_events = []
- self.prev_state_events = []
+ context.rejected = False
+ context.app_service = None
- self.rejected = False
- self.app_service = None
+ return context
def serialize(self, event):
"""Converts self to a type that can be serialized as JSON, and then
@@ -123,30 +152,17 @@ class EventContext(object):
Returns:
EventContext
"""
+ context = EventContext()
+
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
- prev_state_id = input["prev_state_id"]
- event_type = input["event_type"]
- event_state_key = input["event_state_key"]
+ context._prev_state_id = input["prev_state_id"]
+ context._event_type = input["event_type"]
+ context._event_state_key = input["event_state_key"]
- state_group = input["state_group"]
-
- current_state_ids = yield store.get_state_ids_for_group(
- state_group,
- )
- if prev_state_id and event_state_key:
- prev_state_ids = dict(current_state_ids)
- prev_state_ids[(event_type, event_state_key)] = prev_state_id
- else:
- prev_state_ids = current_state_ids
-
- context = EventContext(
- state_group=state_group,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
- prev_group=input["prev_group"],
- delta_ids=_decode_state_dict(input["delta_ids"]),
- )
+ context.state_group = input["state_group"]
+ context.prev_group = input["prev_group"]
+ context.delta_ids = _decode_state_dict(input["delta_ids"])
context.rejected = input["rejected"]
context.prev_state_events = input["prev_state_events"]
@@ -157,6 +173,61 @@ class EventContext(object):
defer.returnValue(context)
+ @defer.inlineCallbacks
+ def get_current_state_ids(self, store):
+ """Gets the current state IDs
+
+ Returns:
+ Deferred[dict[(str, str), str]|None]: Returns None if state_group
+ is None, which happens when the associated event is an outlier.
+ """
+
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store,
+ )
+
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ defer.returnValue(self._current_state_ids)
+
+ @defer.inlineCallbacks
+ def get_prev_state_ids(self, store):
+ """Gets the prev state IDs
+
+ Returns:
+ Deferred[dict[(str, str), str]|None]: Returns None if state_group
+ is None, which happens when the associated event is an outlier.
+ """
+
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store,
+ )
+
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ defer.returnValue(self._prev_state_ids)
+
+ @defer.inlineCallbacks
+ def _fill_out_state(self, store):
+ """Called to populate the _current_state_ids and _prev_state_ids
+ attributes by loading from the database.
+ """
+ if self.state_group is None:
+ return
+
+ self._current_state_ids = yield store.get_state_ids_for_group(
+ self.state_group,
+ )
+ if self._prev_state_id and self._event_state_key is not None:
+ self._prev_state_ids = dict(self._current_state_ids)
+
+ key = (self._event_type, self._event_state_key)
+ self._prev_state_ids[key] = self._prev_state_id
+ else:
+ self._prev_state_ids = self._current_state_ids
+
def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|