summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/6564.misc1
-rw-r--r--synapse/api/auth.py2
-rw-r--r--synapse/events/snapshot.py36
-rw-r--r--synapse/events/third_party_rules.py2
-rw-r--r--synapse/handlers/_base.py2
-rw-r--r--synapse/handlers/federation.py14
-rw-r--r--synapse/handlers/message.py10
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/room_member.py4
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py4
-rw-r--r--synapse/replication/http/federation.py5
-rw-r--r--synapse/replication/http/send_event.py3
-rw-r--r--synapse/storage/data_stores/main/push_rule.py2
-rw-r--r--synapse/storage/data_stores/main/roommember.py2
-rw-r--r--tests/test_state.py28
15 files changed, 64 insertions, 53 deletions
diff --git a/changelog.d/6564.misc b/changelog.d/6564.misc
new file mode 100644
index 0000000000..f644f5868b
--- /dev/null
+++ b/changelog.d/6564.misc
@@ -0,0 +1 @@
+Change `EventContext` to use the `Storage` class, in preparation for moving state database queries to a separate data store.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9fd52a8c77..abbc7079a3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -79,7 +79,7 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def check_from_context(self, room_version, event, context, do_sig_check=True):
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         auth_events_ids = yield self.compute_auth_events(
             event, prev_state_ids, for_verification=True
         )
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 64e898f40c..a44baea365 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -149,7 +149,7 @@ class EventContext:
         # the prev_state_ids, so if we're a state event we include the event
         # id that we replaced in the state.
         if event.is_state():
-            prev_state_ids = yield self.get_prev_state_ids(store)
+            prev_state_ids = yield self.get_prev_state_ids()
             prev_state_id = prev_state_ids.get((event.type, event.state_key))
         else:
             prev_state_id = None
@@ -167,12 +167,13 @@ class EventContext:
         }
 
     @staticmethod
-    def deserialize(store, input):
+    def deserialize(storage, input):
         """Converts a dict that was produced by `serialize` back into a
         EventContext.
 
         Args:
-            store (DataStore): Used to convert AS ID to AS object
+            storage (Storage): Used to convert AS ID to AS object and fetch
+                state.
             input (dict): A dict produced by `serialize`
 
         Returns:
@@ -181,6 +182,7 @@ class EventContext:
         context = _AsyncEventContextImpl(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
+            storage=storage,
             prev_state_id=input["prev_state_id"],
             event_type=input["event_type"],
             event_state_key=input["event_state_key"],
@@ -193,7 +195,7 @@ class EventContext:
 
         app_service_id = input["app_service_id"]
         if app_service_id:
-            context.app_service = store.get_app_service_by_id(app_service_id)
+            context.app_service = storage.main.get_app_service_by_id(app_service_id)
 
         return context
 
@@ -216,7 +218,7 @@ class EventContext:
         return self._state_group
 
     @defer.inlineCallbacks
-    def get_current_state_ids(self, store):
+    def get_current_state_ids(self):
         """
         Gets the room state map, including this event - ie, the state in ``state_group``
 
@@ -234,11 +236,11 @@ class EventContext:
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
-        yield self._ensure_fetched(store)
+        yield self._ensure_fetched()
         return self._current_state_ids
 
     @defer.inlineCallbacks
-    def get_prev_state_ids(self, store):
+    def get_prev_state_ids(self):
         """
         Gets the room state map, excluding this event.
 
@@ -250,7 +252,7 @@ class EventContext:
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
-        yield self._ensure_fetched(store)
+        yield self._ensure_fetched()
         return self._prev_state_ids
 
     def get_cached_current_state_ids(self):
@@ -270,7 +272,7 @@ class EventContext:
 
         return self._current_state_ids
 
-    def _ensure_fetched(self, store):
+    def _ensure_fetched(self):
         return defer.succeed(None)
 
 
@@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext):
 
     Attributes:
 
+        _storage (Storage)
+
         _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
             been calculated. None if we haven't started calculating yet
 
@@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext):
             that was replaced.
     """
 
+    # This needs to have a default as we're inheriting
+    _storage = attr.ib(default=None)
     _prev_state_id = attr.ib(default=None)
     _event_type = attr.ib(default=None)
     _event_state_key = attr.ib(default=None)
     _fetching_state_deferred = attr.ib(default=None)
 
-    def _ensure_fetched(self, store):
+    def _ensure_fetched(self):
         if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store
-            )
+            self._fetching_state_deferred = run_in_background(self._fill_out_state)
 
         return make_deferred_yieldable(self._fetching_state_deferred)
 
     @defer.inlineCallbacks
-    def _fill_out_state(self, store):
+    def _fill_out_state(self):
         """Called to populate the _current_state_ids and _prev_state_ids
         attributes by loading from the database.
         """
         if self.state_group is None:
             return
 
-        self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
+        self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+            self.state_group
+        )
         if self._prev_state_id and self._event_state_key is not None:
             self._prev_state_ids = dict(self._current_state_ids)
 
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 714a9b1579..86f7e5f8aa 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -53,7 +53,7 @@ class ThirdPartyEventRules(object):
         if self.third_party_rules is None:
             return True
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         # Retrieve the state events from the database.
         state_events = {}
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d15c6282fb..51413d910e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -134,7 +134,7 @@ class BaseHandler(object):
             guest_access = event.content.get("guest_access", "forbidden")
             if guest_access != "can_join":
                 if context:
-                    current_state_ids = yield context.get_current_state_ids(self.store)
+                    current_state_ids = yield context.get_current_state_ids()
                     current_state = yield self.store.get_events(
                         list(current_state_ids.values())
                     )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 60bb00fc6a..05ae40dde7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -718,7 +718,7 @@ class FederationHandler(BaseHandler):
                 # changing their profile info.
                 newly_joined = True
 
-                prev_state_ids = await context.get_prev_state_ids(self.store)
+                prev_state_ids = await context.get_prev_state_ids()
 
                 prev_state_id = prev_state_ids.get((event.type, event.state_key))
                 if prev_state_id:
@@ -1418,7 +1418,7 @@ class FederationHandler(BaseHandler):
                 user = UserID.from_string(event.state_key)
                 yield self.user_joined_room(user, event.room_id)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
         auth_chain = yield self.store.get_auth_chain(state_ids)
@@ -1927,7 +1927,7 @@ class FederationHandler(BaseHandler):
         context = yield self.state_handler.compute_event_context(event, old_state=state)
 
         if not auth_events:
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             auth_events_ids = yield self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
@@ -2336,12 +2336,12 @@ class FederationHandler(BaseHandler):
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
         }
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
         current_state_ids = dict(current_state_ids)
 
         current_state_ids.update(state_updates)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         prev_state_ids = dict(prev_state_ids)
 
         prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
@@ -2625,7 +2625,7 @@ class FederationHandler(BaseHandler):
             event.content["third_party_invite"]["signed"]["token"],
         )
         original_invite = None
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         original_invite_id = prev_state_ids.get(key)
         if original_invite_id:
             original_invite = yield self.store.get_event(
@@ -2673,7 +2673,7 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
 
         invite_event = None
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index bf9add7fe2..4ad752205f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -515,7 +515,7 @@ class EventCreationHandler(object):
             # federation as well as those created locally. As of room v3, aliases events
             # can be created by users that are not in the room, therefore we have to
             # tolerate them in event_auth.check().
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
             prev_event = (
                 yield self.store.get_event(prev_event_id, allow_none=True)
@@ -665,7 +665,7 @@ class EventCreationHandler(object):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
         if not prev_event_id:
             return
@@ -914,7 +914,7 @@ class EventCreationHandler(object):
                 def is_inviter_member_event(e):
                     return e.type == EventTypes.Member and e.sender == event.sender
 
-                current_state_ids = yield context.get_current_state_ids(self.store)
+                current_state_ids = yield context.get_current_state_ids()
 
                 state_to_include_ids = [
                     e_id
@@ -967,7 +967,7 @@ class EventCreationHandler(object):
                 if original_event.room_id != event.room_id:
                     raise SynapseError(400, "Cannot redact event from a different room")
 
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             auth_events_ids = yield self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
@@ -989,7 +989,7 @@ class EventCreationHandler(object):
                 event.internal_metadata.recheck_redaction = False
 
         if event.type == EventTypes.Create:
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
+            prev_state_ids = yield context.get_prev_state_ids()
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d3a1a7b4a6..89c9118b26 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler):
             requester, tombstone_event, tombstone_context
         )
 
-        old_room_state = yield tombstone_context.get_current_state_ids(self.store)
+        old_room_state = yield tombstone_context.get_current_state_ids()
 
         # update any aliases
         yield self._move_aliases_to_new_room(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7b7270fc61..44c5e3239c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -193,7 +193,7 @@ class RoomMemberHandler(object):
             requester, event, context, extra_users=[target], ratelimit=ratelimit
         )
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
@@ -601,7 +601,7 @@ class RoomMemberHandler(object):
         if prev_event is not None:
             return
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         if event.membership == Membership.JOIN:
             if requester.is_guest:
                 guest_can_join = yield self._can_guest_join(prev_state_ids)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7881780760..7d9f5a38d9 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object):
 
     @defer.inlineCallbacks
     def _get_power_levels_and_sender_level(self, event, context):
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         pl_event_id = prev_state_ids.get(POWER_KEY)
         if pl_event_id:
             # fastpath: if there's a power level event, that's all we need, and
@@ -304,7 +304,7 @@ class RulesForRoom(object):
 
                 push_rules_delta_state_cache_metric.inc_hits()
             else:
-                current_state_ids = yield context.get_current_state_ids(self.store)
+                current_state_ids = yield context.get_current_state_ids()
                 push_rules_delta_state_cache_metric.inc_misses()
 
             push_rules_state_size_counter.inc(len(current_state_ids))
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 9af4e7e173..49a3251372 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.clock = hs.get_clock()
         self.federation_handler = hs.get_handlers().federation_handler
 
@@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 EventType = event_type_from_format_version(format_ver)
                 event = EventType(event_dict, internal_metadata, rejected_reason)
 
-                context = EventContext.deserialize(self.store, event_payload["context"])
+                context = EventContext.deserialize(
+                    self.storage, event_payload["context"]
+                )
 
                 event_and_contexts.append((event, context))
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 9bafd60b14..84b92f16ad 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
         self.event_creation_handler = hs.get_event_creation_handler()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.clock = hs.get_clock()
 
     @staticmethod
@@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             event = EventType(event_dict, internal_metadata, rejected_reason)
 
             requester = Requester.deserialize(self.store, content["requester"])
-            context = EventContext.deserialize(self.store, content["context"])
+            context = EventContext.deserialize(self.storage, content["context"])
 
             ratelimit = content["ratelimit"]
             extra_users = [UserID.from_string(u) for u in content["extra_users"]]
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 5ba13aa973..e2673ae073 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -244,7 +244,7 @@ class PushRulesWorkerStore(
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        current_state_ids = yield context.get_current_state_ids(self)
+        current_state_ids = yield context.get_current_state_ids()
         result = yield self._bulk_get_push_rules_for_room(
             event.room_id, state_group, current_state_ids, event=event
         )
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 92e3b9c512..70ff5751b6 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        current_state_ids = yield context.get_current_state_ids(self)
+        current_state_ids = yield context.get_current_state_ids()
         result = yield self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
diff --git a/tests/test_state.py b/tests/test_state.py
index 176535947a..e0aae06be4 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -209,7 +209,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
         self.assertEqual(2, len(prev_state_ids))
 
         self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -253,7 +253,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
         self.assertSetEqual(
             {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
         )
@@ -312,7 +312,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_e = context_store["E"]
 
-        prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
+        prev_state_ids = yield ctx_e.get_prev_state_ids()
         self.assertSetEqual(
             {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
         )
@@ -387,7 +387,7 @@ class StateTestCase(unittest.TestCase):
         ctx_b = context_store["B"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
         )
@@ -419,10 +419,10 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
         self.assertCountEqual(
             (e.event_id for e in old_state), current_state_ids.values()
         )
@@ -442,10 +442,10 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
         self.assertCountEqual(
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
@@ -479,7 +479,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(
             set([e.event_id for e in old_state]), set(current_state_ids.values())
@@ -511,7 +511,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         self.assertEqual(
             set([e.event_id for e in old_state]), set(prev_state_ids.values())
@@ -552,7 +552,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -594,7 +594,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -649,7 +649,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
 
@@ -677,7 +677,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])