summary refs log tree commit diff
path: root/synapse/events/snapshot.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/events/snapshot.py36
1 files changed, 21 insertions, 15 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 64e898f40c..a44baea365 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -149,7 +149,7 @@ class EventContext:
         # the prev_state_ids, so if we're a state event we include the event
         # id that we replaced in the state.
         if event.is_state():
-            prev_state_ids = yield self.get_prev_state_ids(store)
+            prev_state_ids = yield self.get_prev_state_ids()
             prev_state_id = prev_state_ids.get((event.type, event.state_key))
         else:
             prev_state_id = None
@@ -167,12 +167,13 @@ class EventContext:
         }
 
     @staticmethod
-    def deserialize(store, input):
+    def deserialize(storage, input):
         """Converts a dict that was produced by `serialize` back into a
         EventContext.
 
         Args:
-            store (DataStore): Used to convert AS ID to AS object
+            storage (Storage): Used to convert AS ID to AS object and fetch
+                state.
             input (dict): A dict produced by `serialize`
 
         Returns:
@@ -181,6 +182,7 @@ class EventContext:
         context = _AsyncEventContextImpl(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
+            storage=storage,
             prev_state_id=input["prev_state_id"],
             event_type=input["event_type"],
             event_state_key=input["event_state_key"],
@@ -193,7 +195,7 @@ class EventContext:
 
         app_service_id = input["app_service_id"]
         if app_service_id:
-            context.app_service = store.get_app_service_by_id(app_service_id)
+            context.app_service = storage.main.get_app_service_by_id(app_service_id)
 
         return context
 
@@ -216,7 +218,7 @@ class EventContext:
         return self._state_group
 
     @defer.inlineCallbacks
-    def get_current_state_ids(self, store):
+    def get_current_state_ids(self):
         """
         Gets the room state map, including this event - ie, the state in ``state_group``
 
@@ -234,11 +236,11 @@ class EventContext:
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
-        yield self._ensure_fetched(store)
+        yield self._ensure_fetched()
         return self._current_state_ids
 
     @defer.inlineCallbacks
-    def get_prev_state_ids(self, store):
+    def get_prev_state_ids(self):
         """
         Gets the room state map, excluding this event.
 
@@ -250,7 +252,7 @@ class EventContext:
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
-        yield self._ensure_fetched(store)
+        yield self._ensure_fetched()
         return self._prev_state_ids
 
     def get_cached_current_state_ids(self):
@@ -270,7 +272,7 @@ class EventContext:
 
         return self._current_state_ids
 
-    def _ensure_fetched(self, store):
+    def _ensure_fetched(self):
         return defer.succeed(None)
 
 
@@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext):
 
     Attributes:
 
+        _storage (Storage)
+
         _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
             been calculated. None if we haven't started calculating yet
 
@@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext):
             that was replaced.
     """
 
+    # This needs to have a default as we're inheriting
+    _storage = attr.ib(default=None)
     _prev_state_id = attr.ib(default=None)
     _event_type = attr.ib(default=None)
     _event_state_key = attr.ib(default=None)
     _fetching_state_deferred = attr.ib(default=None)
 
-    def _ensure_fetched(self, store):
+    def _ensure_fetched(self):
         if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store
-            )
+            self._fetching_state_deferred = run_in_background(self._fill_out_state)
 
         return make_deferred_yieldable(self._fetching_state_deferred)
 
     @defer.inlineCallbacks
-    def _fill_out_state(self, store):
+    def _fill_out_state(self):
         """Called to populate the _current_state_ids and _prev_state_ids
         attributes by loading from the database.
         """
         if self.state_group is None:
             return
 
-        self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
+        self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+            self.state_group
+        )
         if self._prev_state_id and self._event_state_key is not None:
             self._prev_state_ids = dict(self._current_state_ids)