diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 64e898f40c..a44baea365 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -149,7 +149,7 @@ class EventContext:
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
- prev_state_ids = yield self.get_prev_state_ids(store)
+ prev_state_ids = yield self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
@@ -167,12 +167,13 @@ class EventContext:
}
@staticmethod
- def deserialize(store, input):
+ def deserialize(storage, input):
"""Converts a dict that was produced by `serialize` back into a
EventContext.
Args:
- store (DataStore): Used to convert AS ID to AS object
+ storage (Storage): Used to convert AS ID to AS object and fetch
+ state.
input (dict): A dict produced by `serialize`
Returns:
@@ -181,6 +182,7 @@ class EventContext:
context = _AsyncEventContextImpl(
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
+ storage=storage,
prev_state_id=input["prev_state_id"],
event_type=input["event_type"],
event_state_key=input["event_state_key"],
@@ -193,7 +195,7 @@ class EventContext:
app_service_id = input["app_service_id"]
if app_service_id:
- context.app_service = store.get_app_service_by_id(app_service_id)
+ context.app_service = storage.main.get_app_service_by_id(app_service_id)
return context
@@ -216,7 +218,7 @@ class EventContext:
return self._state_group
@defer.inlineCallbacks
- def get_current_state_ids(self, store):
+ def get_current_state_ids(self):
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@@ -234,11 +236,11 @@ class EventContext:
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")
- yield self._ensure_fetched(store)
+ yield self._ensure_fetched()
return self._current_state_ids
@defer.inlineCallbacks
- def get_prev_state_ids(self, store):
+ def get_prev_state_ids(self):
"""
Gets the room state map, excluding this event.
@@ -250,7 +252,7 @@ class EventContext:
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
- yield self._ensure_fetched(store)
+ yield self._ensure_fetched()
return self._prev_state_ids
def get_cached_current_state_ids(self):
@@ -270,7 +272,7 @@ class EventContext:
return self._current_state_ids
- def _ensure_fetched(self, store):
+ def _ensure_fetched(self):
return defer.succeed(None)
@@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext):
Attributes:
+ _storage (Storage)
+
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
@@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext):
that was replaced.
"""
+ # This needs to have a default as we're inheriting
+ _storage = attr.ib(default=None)
_prev_state_id = attr.ib(default=None)
_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
- def _ensure_fetched(self, store):
+ def _ensure_fetched(self):
if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store
- )
+ self._fetching_state_deferred = run_in_background(self._fill_out_state)
return make_deferred_yieldable(self._fetching_state_deferred)
@defer.inlineCallbacks
- def _fill_out_state(self, store):
+ def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
- self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
+ self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+ self.state_group
+ )
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 714a9b1579..86f7e5f8aa 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -53,7 +53,7 @@ class ThirdPartyEventRules(object):
if self.third_party_rules is None:
return True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
# Retrieve the state events from the database.
state_events = {}
|