summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/12012.misc1
-rw-r--r--synapse/events/snapshot.py9
-rw-r--r--synapse/handlers/federation.py11
-rw-r--r--synapse/handlers/federation_event.py13
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/state/__init__.py31
-rw-r--r--synapse/storage/databases/main/events.py25
-rw-r--r--synapse/storage/databases/main/events_worker.py28
-rw-r--r--synapse/storage/databases/main/room.py37
-rw-r--r--synapse/storage/schema/main/delta/68/04partial_state_rooms.sql41
-rw-r--r--synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py72
-rw-r--r--tests/test_state.py59
12 files changed, 297 insertions, 32 deletions
diff --git a/changelog.d/12012.misc b/changelog.d/12012.misc
new file mode 100644
index 0000000000..a473f41e78
--- /dev/null
+++ b/changelog.d/12012.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 5833fee25f..46042b2bf7 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -101,6 +101,9 @@ class EventContext:
 
             As with _current_state_ids, this is a private attribute. It should be
             accessed via get_prev_state_ids.
+
+        partial_state: if True, we may be storing this event with a temporary,
+            incomplete state.
     """
 
     rejected: Union[bool, str] = False
@@ -113,12 +116,15 @@ class EventContext:
     _current_state_ids: Optional[StateMap[str]] = None
     _prev_state_ids: Optional[StateMap[str]] = None
 
+    partial_state: bool = False
+
     @staticmethod
     def with_state(
         state_group: Optional[int],
         state_group_before_event: Optional[int],
         current_state_ids: Optional[StateMap[str]],
         prev_state_ids: Optional[StateMap[str]],
+        partial_state: bool,
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ) -> "EventContext":
@@ -129,6 +135,7 @@ class EventContext:
             state_group_before_event=state_group_before_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
+            partial_state=partial_state,
         )
 
     @staticmethod
@@ -170,6 +177,7 @@ class EventContext:
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
+            "partial_state": self.partial_state,
         }
 
     @staticmethod
@@ -196,6 +204,7 @@ class EventContext:
             prev_group=input["prev_group"],
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
+            partial_state=input.get("partial_state", False),
         )
 
         app_service_id = input["app_service_id"]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c055c26eca..eb03a5accb 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -519,8 +519,17 @@ class FederationHandler:
                 state_events=state,
             )
 
+            if ret.partial_state:
+                await self.store.store_partial_state_room(room_id, ret.servers_in_room)
+
             max_stream_id = await self._federation_event_handler.process_remote_join(
-                origin, room_id, auth_chain, state, event, room_version_obj
+                origin,
+                room_id,
+                auth_chain,
+                state,
+                event,
+                room_version_obj,
+                partial_state=ret.partial_state,
             )
 
             # We wait here until this instance has seen the events come down
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 09d0de1ead..4bd87709f3 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -397,6 +397,7 @@ class FederationEventHandler:
         state: List[EventBase],
         event: EventBase,
         room_version: RoomVersion,
+        partial_state: bool,
     ) -> int:
         """Persists the events returned by a send_join
 
@@ -412,6 +413,7 @@ class FederationEventHandler:
             event
             room_version: The room version we expect this room to have, and
                 will raise if it doesn't match the version in the create event.
+            partial_state: True if the state omits non-critical membership events
 
         Returns:
             The stream ID after which all events have been persisted.
@@ -453,10 +455,14 @@ class FederationEventHandler:
         )
 
         # and now persist the join event itself.
-        logger.info("Peristing join-via-remote %s", event)
+        logger.info(
+            "Peristing join-via-remote %s (partial_state: %s)", event, partial_state
+        )
         with nested_logging_context(suffix=event.event_id):
             context = await self._state_handler.compute_event_context(
-                event, old_state=state
+                event,
+                old_state=state,
+                partial_state=partial_state,
             )
 
             context = await self._check_event_auth(origin, event, context)
@@ -698,6 +704,8 @@ class FederationEventHandler:
 
         try:
             state = await self._resolve_state_at_missing_prevs(origin, event)
+            # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
+            #   not return partial state
             await self._process_received_pdu(
                 origin, event, state=state, backfilled=backfilled
             )
@@ -1791,6 +1799,7 @@ class FederationEventHandler:
             prev_state_ids=prev_state_ids,
             prev_group=prev_group,
             delta_ids=state_updates,
+            partial_state=context.partial_state,
         )
 
     async def _run_push_actions_and_persist_event(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ce1fa3c78e..61cb133ef2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -992,6 +992,8 @@ class EventCreationHandler:
             and full_state_ids_at_event
             and builder.internal_metadata.is_historical()
         ):
+            # TODO(faster_joins): figure out how this works, and make sure that the
+            #   old state is complete.
             old_state = await self.store.get_events_as_list(full_state_ids_at_event)
             context = await self.state.compute_event_context(event, old_state=old_state)
         else:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fcc24ad129..6babd5963c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -258,7 +258,10 @@ class StateHandler:
         return await self.store.get_joined_hosts(room_id, entry)
 
     async def compute_event_context(
-        self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
+        self,
+        event: EventBase,
+        old_state: Optional[Iterable[EventBase]] = None,
+        partial_state: bool = False,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.
 
@@ -273,6 +276,8 @@ class StateHandler:
                 calculated from existing events. This is normally only specified
                 when receiving an event from federation where we don't have the
                 prev events for, e.g. when backfilling.
+            partial_state: True if `old_state` is partial and omits non-critical
+                membership events
         Returns:
             The event context.
         """
@@ -295,8 +300,28 @@ class StateHandler:
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
-            logger.debug("calling resolve_state_groups from compute_event_context")
 
+            # partial_state should not be set explicitly in this case:
+            # we work it out dynamically
+            assert not partial_state
+
+            # if any of the prev-events have partial state, so do we.
+            # (This is slightly racy - the prev-events might get fixed up before we use
+            # their states - but I don't think that really matters; it just means we
+            # might redundantly recalculate the state for this event later.)
+            prev_event_ids = event.prev_event_ids()
+            incomplete_prev_events = await self.store.get_partial_state_events(
+                prev_event_ids
+            )
+            if any(incomplete_prev_events.values()):
+                logger.debug(
+                    "New/incoming event %s refers to prev_events %s with partial state",
+                    event.event_id,
+                    [k for (k, v) in incomplete_prev_events.items() if v],
+                )
+                partial_state = True
+
+            logger.debug("calling resolve_state_groups from compute_event_context")
             entry = await self.resolve_state_groups_for_events(
                 event.room_id, event.prev_event_ids()
             )
@@ -342,6 +367,7 @@ class StateHandler:
                 prev_state_ids=state_ids_before_event,
                 prev_group=state_group_before_event_prev_group,
                 delta_ids=deltas_to_state_group_before_event,
+                partial_state=partial_state,
             )
 
         #
@@ -373,6 +399,7 @@ class StateHandler:
             prev_state_ids=state_ids_before_event,
             prev_group=state_group_before_event,
             delta_ids=delta_ids,
+            partial_state=partial_state,
         )
 
     @measure_func()
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 23fa089bca..ca2a9ba9d1 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2145,6 +2145,14 @@ class PersistEventsStore:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
+                # double-check that we don't have any events that claim to be outliers
+                # *and* have partial state (which is meaningless: we should have no
+                # state at all for an outlier)
+                if context.partial_state:
+                    raise ValueError(
+                        "Outlier event %s claims to have partial state", event.event_id
+                    )
+
                 continue
 
             # if the event was rejected, just give it the same state as its
@@ -2155,6 +2163,23 @@ class PersistEventsStore:
 
             state_groups[event.event_id] = context.state_group
 
+        # if we have partial state for these events, record the fact. (This happens
+        # here rather than in _store_event_txn because it also needs to happen when
+        # we de-outlier an event.)
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="partial_state_events",
+            keys=("room_id", "event_id"),
+            values=[
+                (
+                    event.room_id,
+                    event.event_id,
+                )
+                for event, ctx in events_and_contexts
+                if ctx.partial_state
+            ],
+        )
+
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_to_state_groups",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 2a255d1031..26784f755e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore):
             "get_event_id_for_timestamp_txn",
             get_event_id_for_timestamp_txn,
         )
+
+    @cachedList("is_partial_state_event", list_name="event_ids")
+    async def get_partial_state_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, bool]:
+        """Checks which of the given events have partial state"""
+        result = await self.db_pool.simple_select_many_batch(
+            table="partial_state_events",
+            column="event_id",
+            iterable=event_ids,
+            retcols=["event_id"],
+            desc="get_partial_state_events",
+        )
+        # convert the result to a dict, to make @cachedList work
+        partial = {r["event_id"] for r in result}
+        return {e_id: e_id in partial for e_id in event_ids}
+
+    @cached()
+    async def is_partial_state_event(self, event_id: str) -> bool:
+        """Checks if the given event has partial state"""
+        result = await self.db_pool.simple_select_one_onecol(
+            table="partial_state_events",
+            keyvalues={"event_id": event_id},
+            retcol="1",
+            allow_none=True,
+            desc="is_partial_state_event",
+        )
+        return result is not None
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0416df64ce..94068940b9 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -20,6 +20,7 @@ from typing import (
     TYPE_CHECKING,
     Any,
     Awaitable,
+    Collection,
     Dict,
     List,
     Optional,
@@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             lock=False,
         )
 
+    async def store_partial_state_room(
+        self,
+        room_id: str,
+        servers: Collection[str],
+    ) -> None:
+        """Mark the given room as containing events with partial state
+
+        Args:
+            room_id: the ID of the room
+            servers: other servers known to be in the room
+        """
+        await self.db_pool.runInteraction(
+            "store_partial_state_room",
+            self._store_partial_state_room_txn,
+            room_id,
+            servers,
+        )
+
+    @staticmethod
+    def _store_partial_state_room_txn(
+        txn: LoggingTransaction, room_id: str, servers: Collection[str]
+    ) -> None:
+        DatabasePool.simple_insert_txn(
+            txn,
+            table="partial_state_rooms",
+            values={
+                "room_id": room_id,
+            },
+        )
+        DatabasePool.simple_insert_many_txn(
+            txn,
+            table="partial_state_rooms_servers",
+            keys=("room_id", "server_name"),
+            values=((room_id, s) for s in servers),
+        )
+
     async def maybe_store_room_on_outlier_membership(
         self, room_id: str, room_version: RoomVersion
     ) -> None:
diff --git a/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql
new file mode 100644
index 0000000000..815c0cc390
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql
@@ -0,0 +1,41 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- rooms which we have done a partial-state-style join to
+CREATE TABLE IF NOT EXISTS partial_state_rooms (
+    room_id TEXT PRIMARY KEY,
+    FOREIGN KEY(room_id) REFERENCES rooms(room_id)
+);
+
+-- a list of remote servers we believe are in the room
+CREATE TABLE IF NOT EXISTS partial_state_rooms_servers (
+    room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
+    server_name TEXT NOT NULL,
+    UNIQUE(room_id, server_name)
+);
+
+-- a list of events with partial state. We can't store this in the `events` table
+-- itself, because `events` is meant to be append-only.
+CREATE TABLE IF NOT EXISTS partial_state_events (
+    -- the room_id is denormalised for efficient indexing (the canonical source is `events`)
+    room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
+    event_id TEXT NOT NULL REFERENCES events(event_id),
+    UNIQUE(event_id)
+);
+
+CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx
+     ON partial_state_events (room_id);
+
+
diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py
new file mode 100644
index 0000000000..a2ec4fc26e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py
@@ -0,0 +1,72 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This migration adds triggers to the partial_state_events tables to enforce uniqueness
+
+Triggers cannot be expressed in .sql files, so we have to use a separate file.
+"""
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    # complain if the room_id in partial_state_events doesn't match
+    # that in `events`. We already have a fk constraint which ensures that the event
+    # exists in `events`, so all we have to do is raise if there is a row with a
+    # matching stream_ordering but not a matching room_id.
+    if isinstance(database_engine, Sqlite3Engine):
+        cur.execute(
+            """
+            CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id
+            BEFORE INSERT ON partial_state_events
+            FOR EACH ROW
+            BEGIN
+                SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events')
+                WHERE EXISTS (
+                    SELECT 1 FROM events
+                    WHERE events.event_id = NEW.event_id
+                       AND events.room_id != NEW.room_id
+                );
+            END;
+            """
+        )
+    elif isinstance(database_engine, PostgresEngine):
+        cur.execute(
+            """
+            CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$
+            BEGIN
+                IF EXISTS (
+                    SELECT 1 FROM events
+                    WHERE events.event_id = NEW.event_id
+                       AND events.room_id != NEW.room_id
+                ) THEN
+                    RAISE EXCEPTION 'Incorrect room_id in partial_state_events';
+                END IF;
+                RETURN NEW;
+            END;
+            $BODY$ LANGUAGE plpgsql;
+            """
+        )
+
+        cur.execute(
+            """
+            CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events
+            FOR EACH ROW
+            EXECUTE PROCEDURE check_partial_state_events()
+            """
+        )
+    else:
+        raise NotImplementedError("Unknown database engine")
diff --git a/tests/test_state.py b/tests/test_state.py
index 90800421fb..e4baa69137 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import Collection, Dict, List, Optional
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -70,7 +70,7 @@ def create_event(
     return event
 
 
-class StateGroupStore:
+class _DummyStore:
     def __init__(self):
         self._event_to_state_group = {}
         self._group_to_state = {}
@@ -105,6 +105,11 @@ class StateGroupStore:
             if e_id in self._event_id_to_event
         }
 
+    async def get_partial_state_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, bool]:
+        return {e: False for e in event_ids}
+
     async def get_state_group_delta(self, name):
         return None, None
 
@@ -157,8 +162,8 @@ class Graph:
 
 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.store = StateGroupStore()
-        storage = Mock(main=self.store, state=self.store)
+        self.dummy_store = _DummyStore()
+        storage = Mock(main=self.dummy_store, state=self.dummy_store)
         hs = Mock(
             spec_set=[
                 "config",
@@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase):
             ]
         )
         hs.config = default_config("tesths", True)
-        hs.get_datastores.return_value = Mock(main=self.store)
+        hs.get_datastores.return_value = Mock(main=self.dummy_store)
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
@@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store: dict[str, EventContext] = {}
 
@@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         ctx_c = context_store["C"]
@@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # C ends up winning the resolution between B and C
@@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # C ends up winning the resolution between C and D because bans win over other
@@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase):
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # B ends up winning the resolution between B and C because power levels
@@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase):
         ]
 
         group_name = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id,
                 event.room_id,
                 None,
@@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id, group_name)
+        self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
@@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase):
         ]
 
         group_name = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id,
                 event.room_id,
                 None,
@@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id, group_name)
+        self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
@@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        self.store.register_events(old_state_1)
-        self.store.register_events(old_state_2)
+        self.dummy_store.register_events(old_state_1)
+        self.dummy_store.register_events(old_state_2)
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        store = StateGroupStore()
+        store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=2),
         ]
 
-        store = StateGroupStore()
+        store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase):
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
         sg1 = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id_1,
                 event.room_id,
                 None,
@@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state_1},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id_1, sg1)
+        self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)
 
         sg2 = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id_2,
                 event.room_id,
                 None,
@@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state_2},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id_2, sg2)
+        self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2)
 
         result = yield defer.ensureDeferred(self.state.compute_event_context(event))
         return result