summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/snapshot.py4
-rw-r--r--synapse/handlers/federation.py24
-rw-r--r--synapse/metrics/metric.py11
-rw-r--r--synapse/replication/slave/storage/events.py4
-rw-r--r--synapse/state.py56
-rw-r--r--synapse/storage/__init__.py1
-rw-r--r--synapse/storage/engines/postgres.py6
-rw-r--r--synapse/storage/engines/sqlite3.py19
-rw-r--r--synapse/storage/events.py43
-rw-r--r--synapse/storage/room.py43
-rw-r--r--synapse/storage/schema/delta/47/state_group_seq.py37
-rw-r--r--synapse/storage/search.py110
-rw-r--r--synapse/storage/state.py196
-rw-r--r--synapse/util/caches/descriptors.py4
-rw-r--r--synapse/util/caches/expiringcache.py6
-rw-r--r--synapse/util/caches/lrucache.py28
-rw-r--r--tests/metrics/test_metric.py12
-rw-r--r--tests/replication/slave/storage/test_events.py4
-rw-r--r--tests/test_state.py154
19 files changed, 499 insertions, 263 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e9a732ff03..87e3fe7b97 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -25,7 +25,9 @@ class EventContext(object):
             The current state map excluding the current event.
             (type, state_key) -> event_id
 
-        state_group (int): state group id
+        state_group (int|None): state group id, if the state has been stored
+            as a state group. This is usually only None if e.g. the event is
+            an outlier.
         rejected (bool|str): A rejection reason if the event was rejected, else
             False
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index cba96111d1..46bcf8b081 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1831,8 +1831,8 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                self._update_context_for_auth_events(
-                    context, auth_events, event_key,
+                yield self._update_context_for_auth_events(
+                    event, context, auth_events, event_key,
                 )
 
         if different_auth and not event.internal_metadata.is_outlier():
@@ -1913,8 +1913,8 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                self._update_context_for_auth_events(
-                    context, auth_events, event_key,
+                yield self._update_context_for_auth_events(
+                    event, context, auth_events, event_key,
                 )
 
         try:
@@ -1923,11 +1923,15 @@ class FederationHandler(BaseHandler):
             logger.warn("Failed auth resolution for %r because %s", event, e)
             raise e
 
-    def _update_context_for_auth_events(self, context, auth_events,
+    @defer.inlineCallbacks
+    def _update_context_for_auth_events(self, event, context, auth_events,
                                         event_key):
-        """Update the state_ids in an event context after auth event resolution
+        """Update the state_ids in an event context after auth event resolution,
+        storing the changes as a new state group.
 
         Args:
+            event (Event): The event we're handling the context for
+
             context (synapse.events.snapshot.EventContext): event context
                 to be updated
 
@@ -1950,7 +1954,13 @@ class FederationHandler(BaseHandler):
         context.prev_state_ids.update({
             k: a.event_id for k, a in auth_events.iteritems()
         })
-        context.state_group = self.store.get_next_state_group()
+        context.state_group = yield self.store.store_state_group(
+            event.event_id,
+            event.room_id,
+            prev_group=context.prev_group,
+            delta_ids=context.delta_ids,
+            current_state_ids=context.current_state_ids,
+        )
 
     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 1e783e5ff4..ff5aa8c0e1 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -193,7 +193,9 @@ class DistributionMetric(object):
 
 
 class CacheMetric(object):
-    __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
+    __slots__ = (
+        "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
+    )
 
     def __init__(self, name, size_callback, cache_name):
         self.name = name
@@ -201,6 +203,7 @@ class CacheMetric(object):
 
         self.hits = 0
         self.misses = 0
+        self.evicted_size = 0
 
         self.size_callback = size_callback
 
@@ -210,6 +213,9 @@ class CacheMetric(object):
     def inc_misses(self):
         self.misses += 1
 
+    def inc_evictions(self, size=1):
+        self.evicted_size += size
+
     def render(self):
         size = self.size_callback()
         hits = self.hits
@@ -219,6 +225,9 @@ class CacheMetric(object):
             """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
             """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
             """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+            """%s:evicted_size{name="%s"} %d""" % (
+                self.name, self.cache_name, self.evicted_size
+            ),
         ]
 
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 29d7296b43..8acb5df0f3 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -19,7 +19,7 @@ from synapse.storage import DataStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.event_push_actions import EventPushActionsStore
 from synapse.storage.roommember import RoomMemberStore
-from synapse.storage.state import StateGroupReadStore
+from synapse.storage.state import StateGroupWorkerStore
 from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from ._base import BaseSlavedStore
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
 # the method descriptor on the DataStore and chuck them into our class.
 
 
-class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
+class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
         super(SlavedEventStore, self).__init__(db_conn, hs)
diff --git a/synapse/state.py b/synapse/state.py
index 273f9911ca..cc93bbcb6b 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -183,8 +183,15 @@ class StateHandler(object):
     def compute_event_context(self, event, old_state=None):
         """Build an EventContext structure for the event.
 
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
+
         Args:
             event (synapse.events.EventBase):
+            old_state (dict|None): The state at the event if it can't be
+                calculated from existing events. This is normally only specified
+                when receiving an event from federation where we don't have the
+                prev events for, e.g. when backfilling.
         Returns:
             synapse.events.snapshot.EventContext:
         """
@@ -208,15 +215,22 @@ class StateHandler(object):
                 context.current_state_ids = {}
                 context.prev_state_ids = {}
             context.prev_state_events = []
-            context.state_group = self.store.get_next_state_group()
+
+            # We don't store state for outliers, so we don't generate a state
+            # froup for it.
+            context.state_group = None
+
             defer.returnValue(context)
 
         if old_state:
+            # We already have the state, so we don't need to calculate it.
+            # Let's just correctly fill out the context and create a
+            # new state group for it.
+
             context = EventContext()
             context.prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
-            context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -229,6 +243,14 @@ class StateHandler(object):
             else:
                 context.current_state_ids = context.prev_state_ids
 
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=None,
+                delta_ids=None,
+                current_state_ids=context.current_state_ids,
+            )
+
             context.prev_state_events = []
             defer.returnValue(context)
 
@@ -242,7 +264,8 @@ class StateHandler(object):
         context = EventContext()
         context.prev_state_ids = curr_state
         if event.is_state():
-            context.state_group = self.store.get_next_state_group()
+            # If this is a state event then we need to create a new state
+            # group for the state after this event.
 
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
@@ -253,23 +276,42 @@ class StateHandler(object):
             context.current_state_ids[key] = event.event_id
 
             if entry.state_group:
+                # If the state at the event has a state group assigned then
+                # we can use that as the prev group
                 context.prev_group = entry.state_group
                 context.delta_ids = {
                     key: event.event_id
                 }
             elif entry.prev_group:
+                # If the state at the event only has a prev group, then we can
+                # use that as a prev group too.
                 context.prev_group = entry.prev_group
                 context.delta_ids = dict(entry.delta_ids)
                 context.delta_ids[key] = event.event_id
+
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=context.prev_group,
+                delta_ids=context.delta_ids,
+                current_state_ids=context.current_state_ids,
+            )
         else:
+            context.current_state_ids = context.prev_state_ids
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+
             if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
+                entry.state_group = yield self.store.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=entry.prev_group,
+                    delta_ids=entry.delta_ids,
+                    current_state_ids=context.current_state_ids,
+                )
                 entry.state_id = entry.state_group
 
             context.state_group = entry.state_group
-            context.current_state_ids = context.prev_state_ids
-            context.prev_group = entry.prev_group
-            context.delta_ids = entry.delta_ids
 
         context.prev_state_events = []
         defer.returnValue(context)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d01d46338a..f8fbd02ceb 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
 
     def lock_table(self, txn, table):
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        txn.execute("SELECT nextval('state_group_id_seq')")
+        return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..60f0fa7fb3 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -16,6 +16,7 @@
 from synapse.storage.prepare_database import prepare_database
 
 import struct
+import threading
 
 
 class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
     def __init__(self, database_module, database_config):
         self.module = database_module
 
+        # The current max state_group, or None if we haven't looked
+        # in the DB yet.
+        self._current_state_group_id = None
+        self._current_state_group_id_lock = threading.Lock()
+
     def check_database(self, txn):
         pass
 
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
     def lock_table(self, txn, table):
         return
 
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._current_state_group_id_lock:
+            if self._current_state_group_id is None:
+                txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+                self._current_state_group_id = txn.fetchone()[0]
+
+            self._current_state_group_id += 1
+            return self._current_state_group_id
+
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index dd28c2efe3..af56f1ee57 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -342,8 +342,20 @@ class EventsStore(SQLBaseStore):
 
                 # NB: Assumes that we are only persisting events for one room
                 # at a time.
+
+                # map room_id->list[event_ids] giving the new forward
+                # extremities in each room
                 new_forward_extremeties = {}
+
+                # map room_id->(type,state_key)->event_id tracking the full
+                # state in each room after adding these events
                 current_state_for_room = {}
+
+                # map room_id->(to_delete, to_insert) where each entry is
+                # a map (type,key)->event_id giving the state delta in each
+                # room
+                state_delta_for_room = {}
+
                 if not backfilled:
                     with Measure(self._clock, "_calculate_state_and_extrem"):
                         # Work out the new "current state" for each room.
@@ -393,11 +405,12 @@ class EventsStore(SQLBaseStore):
                                 ev_ctx_rm, new_latest_event_ids,
                             )
                             if current_state is not None:
+                                current_state_for_room[room_id] = current_state
                                 delta = yield self._calculate_state_delta(
                                     room_id, current_state,
                                 )
                                 if delta is not None:
-                                    current_state_for_room[room_id] = delta
+                                    state_delta_for_room[room_id] = delta
 
                 yield self.runInteraction(
                     "persist_events",
@@ -405,7 +418,7 @@ class EventsStore(SQLBaseStore):
                     events_and_contexts=chunk,
                     backfilled=backfilled,
                     delete_existing=delete_existing,
-                    current_state_for_room=current_state_for_room,
+                    state_delta_for_room=state_delta_for_room,
                     new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))
@@ -422,7 +435,7 @@ class EventsStore(SQLBaseStore):
 
                     event_counter.inc(event.type, origin_type, origin_entity)
 
-                for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+                for room_id, new_state in current_state_for_room.iteritems():
                     self.get_current_state_ids.prefill(
                         (room_id, ), new_state
                     )
@@ -586,10 +599,10 @@ class EventsStore(SQLBaseStore):
         Assumes that we are only persisting events for one room at a time.
 
         Returns:
-            3-tuple (to_delete, to_insert, new_state) where both are state dicts,
+            2-tuple (to_delete, to_insert) where both are state dicts,
             i.e. (type, state_key) -> event_id. `to_delete` are the entries to
             first be deleted from current_state_events, `to_insert` are entries
-            to insert. `new_state` is the full set of state.
+            to insert.
         """
         existing_state = yield self.get_current_state_ids(room_id)
 
@@ -610,7 +623,7 @@ class EventsStore(SQLBaseStore):
             if ev_id in events_to_insert
         }
 
-        defer.returnValue((to_delete, to_insert, current_state))
+        defer.returnValue((to_delete, to_insert))
 
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
@@ -670,7 +683,7 @@ class EventsStore(SQLBaseStore):
 
     @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False, current_state_for_room={},
+                            delete_existing=False, state_delta_for_room={},
                             new_forward_extremeties={}):
         """Insert some number of room events into the necessary database tables.
 
@@ -686,7 +699,7 @@ class EventsStore(SQLBaseStore):
             delete_existing (bool): True to purge existing table rows for the
                 events from the database. This is useful when retrying due to
                 IntegrityError.
-            current_state_for_room (dict[str, (list[str], list[str])]):
+            state_delta_for_room (dict[str, (list[str], list[str])]):
                 The current-state delta for each room. For each room, a tuple
                 (to_delete, to_insert), being a list of event ids to be removed
                 from the current state, and a list of event ids to be added to
@@ -698,7 +711,7 @@ class EventsStore(SQLBaseStore):
         """
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
-        self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+        self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
 
         self._update_forward_extremities_txn(
             txn,
@@ -742,9 +755,8 @@ class EventsStore(SQLBaseStore):
             events_and_contexts=events_and_contexts,
         )
 
-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
+        # Insert into event_to_state_groups.
+        self._store_event_state_mappings_txn(txn, events_and_contexts)
 
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
@@ -764,7 +776,7 @@ class EventsStore(SQLBaseStore):
 
     def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
         for room_id, current_state_tuple in state_delta_by_room.iteritems():
-                to_delete, to_insert, _ = current_state_tuple
+                to_delete, to_insert = current_state_tuple
                 txn.executemany(
                     "DELETE FROM current_state_events WHERE event_id = ?",
                     [(ev_id,) for ev_id in to_delete.itervalues()],
@@ -979,10 +991,9 @@ class EventsStore(SQLBaseStore):
                 # an outlier in the database. We now have some state at that
                 # so we need to update the state_groups table with that state.
 
-                # insert into the state_group, state_groups_state and
-                # event_to_state_groups tables.
+                # insert into event_to_state_groups.
                 try:
-                    self._store_mult_state_groups_txn(txn, ((event, context),))
+                    self._store_event_state_mappings_txn(txn, ((event, context),))
                 except Exception:
                     logger.exception("")
                     raise
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index cf2c4dae39..fff6652e05 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,11 +16,9 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.storage.search import SearchStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
-from ._base import SQLBaseStore
-from .engines import PostgresEngine, Sqlite3Engine
-
 import collections
 import logging
 import ujson as json
@@ -40,7 +38,7 @@ RatelimitOverride = collections.namedtuple(
 )
 
 
-class RoomStore(SQLBaseStore):
+class RoomStore(SearchStore):
 
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
@@ -263,8 +261,8 @@ class RoomStore(SQLBaseStore):
                 },
             )
 
-            self._store_event_search_txn(
-                txn, event, "content.topic", event.content["topic"]
+            self.store_event_search_txn(
+                txn, event, "content.topic", event.content["topic"],
             )
 
     def _store_room_name_txn(self, txn, event):
@@ -279,14 +277,14 @@ class RoomStore(SQLBaseStore):
                 }
             )
 
-            self._store_event_search_txn(
-                txn, event, "content.name", event.content["name"]
+            self.store_event_search_txn(
+                txn, event, "content.name", event.content["name"],
             )
 
     def _store_room_message_txn(self, txn, event):
         if hasattr(event, "content") and "body" in event.content:
-            self._store_event_search_txn(
-                txn, event, "content.body", event.content["body"]
+            self.store_event_search_txn(
+                txn, event, "content.body", event.content["body"],
             )
 
     def _store_history_visibility_txn(self, txn, event):
@@ -308,31 +306,6 @@ class RoomStore(SQLBaseStore):
                 event.content[key]
             ))
 
-    def _store_event_search_txn(self, txn, event, key, value):
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
-            txn.execute(
-                sql,
-                (
-                    event.event_id, event.room_id, key, value,
-                    event.internal_metadata.stream_ordering,
-                    event.origin_server_ts,
-                )
-            )
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            txn.execute(sql, (event.event_id, event.room_id, key, value,))
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
-
     def add_event_report(self, room_id, event_id, user_id, reason, content,
                          received_ts):
         next_id = self._event_reports_id_gen.get_next()
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if isinstance(database_engine, PostgresEngine):
+        # if we already have some state groups, we want to start making new
+        # ones with a higher id.
+        cur.execute("SELECT max(id) FROM state_groups")
+        row = cur.fetchone()
+
+        if row[0] is None:
+            start_val = 1
+        else:
+            start_val = row[0] + 1
+
+        cur.execute(
+            "CREATE SEQUENCE state_group_id_seq START WITH %s",
+            (start_val, ),
+        )
+
+
+def run_upgrade(*args, **kwargs):
+    pass
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 479b04c636..f1ac9ba0fd 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,19 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import namedtuple
+import logging
+import re
+import ujson as json
+
 from twisted.internet import defer
 
 from .background_updates import BackgroundUpdateStore
 from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-import logging
-import re
-import ujson as json
-
 
 logger = logging.getLogger(__name__)
 
+SearchEntry = namedtuple('SearchEntry', [
+    'key', 'value', 'event_id', 'room_id', 'stream_ordering',
+    'origin_server_ts',
+])
+
 
 class SearchStore(BackgroundUpdateStore):
 
@@ -49,16 +55,17 @@ class SearchStore(BackgroundUpdateStore):
 
     @defer.inlineCallbacks
     def _background_reindex_search(self, progress, batch_size):
+        # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
         def reindex_search_txn(txn):
             sql = (
-                "SELECT stream_ordering, event_id, room_id, type, content FROM events"
+                "SELECT stream_ordering, event_id, room_id, type, content, "
+                " origin_server_ts FROM events"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
                 " AND (%s)"
                 " ORDER BY stream_ordering DESC"
@@ -67,6 +74,10 @@ class SearchStore(BackgroundUpdateStore):
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
+            # we could stream straight from the results into
+            # store_search_entries_txn with a generator function, but that
+            # would mean having two cursors open on the database at once.
+            # Instead we just build a list of results.
             rows = self.cursor_to_dict(txn)
             if not rows:
                 return 0
@@ -79,6 +90,8 @@ class SearchStore(BackgroundUpdateStore):
                     event_id = row["event_id"]
                     room_id = row["room_id"]
                     etype = row["type"]
+                    stream_ordering = row["stream_ordering"]
+                    origin_server_ts = row["origin_server_ts"]
                     try:
                         content = json.loads(row["content"])
                     except Exception:
@@ -93,6 +106,8 @@ class SearchStore(BackgroundUpdateStore):
                     elif etype == "m.room.name":
                         key = "content.name"
                         value = content["name"]
+                    else:
+                        raise Exception("unexpected event type %s" % etype)
                 except (KeyError, AttributeError):
                     # If the event is missing a necessary field then
                     # skip over it.
@@ -103,25 +118,16 @@ class SearchStore(BackgroundUpdateStore):
                     # then skip over it
                     continue
 
-                event_search_rows.append((event_id, room_id, key, value))
+                event_search_rows.append(SearchEntry(
+                    key=key,
+                    value=value,
+                    event_id=event_id,
+                    room_id=room_id,
+                    stream_ordering=stream_ordering,
+                    origin_server_ts=origin_server_ts,
+                ))
 
-            if isinstance(self.database_engine, PostgresEngine):
-                sql = (
-                    "INSERT INTO event_search (event_id, room_id, key, vector)"
-                    " VALUES (?,?,?,to_tsvector('english', ?))"
-                )
-            elif isinstance(self.database_engine, Sqlite3Engine):
-                sql = (
-                    "INSERT INTO event_search (event_id, room_id, key, value)"
-                    " VALUES (?,?,?,?)"
-                )
-            else:
-                # This should be unreachable.
-                raise Exception("Unrecognized database engine")
-
-            for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
-                clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            self.store_search_entries_txn(txn, event_search_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -242,6 +248,62 @@ class SearchStore(BackgroundUpdateStore):
 
         defer.returnValue(num_rows)
 
+    def store_event_search_txn(self, txn, event, key, value):
+        """Add event to the search table
+
+        Args:
+            txn (cursor):
+            event (EventBase):
+            key (str):
+            value (str):
+        """
+        self.store_search_entries_txn(
+            txn,
+            (SearchEntry(
+                key=key,
+                value=value,
+                event_id=event.event_id,
+                room_id=event.room_id,
+                stream_ordering=event.internal_metadata.stream_ordering,
+                origin_server_ts=event.origin_server_ts,
+            ),),
+        )
+
+    def store_search_entries_txn(self, txn, entries):
+        """Add entries to the search table
+
+        Args:
+            txn (cursor):
+            entries (iterable[SearchEntry]):
+                entries to be added to the table
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+
+            args = ((
+                entry.event_id, entry.room_id, entry.key, entry.value,
+                entry.stream_ordering, entry.origin_server_ts,
+            ) for entry in entries)
+
+            txn.executemany(sql, args)
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "INSERT INTO event_search (event_id, room_id, key, value)"
+                " VALUES (?,?,?,?)"
+            )
+            args = ((
+                entry.event_id, entry.room_id, entry.key, entry.value,
+            ) for entry in entries)
+
+            txn.executemany(sql, args)
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
     @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 360e3e4355..adb48df73e 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
         return len(self.delta_ids) if self.delta_ids else 0
 
 
-class StateGroupReadStore(SQLBaseStore):
-    """The read-only parts of StateGroupStore
-
-    None of these functions write to the state tables, so are suitable for
-    including in the SlavedStores.
+class StateGroupWorkerStore(SQLBaseStore):
+    """The parts of StateGroupStore that can be called from workers.
     """
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
     def __init__(self, db_conn, hs):
-        super(StateGroupReadStore, self).__init__(db_conn, hs)
+        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
 
         self._state_group_cache = DictionaryCache(
             "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@@ -549,116 +546,66 @@ class StateGroupReadStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        """Store a new set of state, returning a newly assigned state group.
 
-class StateStore(StateGroupReadStore, BackgroundUpdateStore):
-    """ Keeps track of the state at a given event.
-
-    This is done by the concept of `state groups`. Every event is a assigned
-    a state group (identified by an arbitrary string), which references a
-    collection of state events. The current state of an event is then the
-    collection of state events referenced by the event's state group.
-
-    Hence, every change in the current state causes a new state group to be
-    generated. However, if no change happens (e.g., if we get a message event
-    with only one parent it inherits the state group from its parent.)
-
-    There are three tables:
-      * `state_groups`: Stores group name, first event with in the group and
-        room id.
-      * `event_to_state_groups`: Maps events to state groups.
-      * `state_groups_state`: Maps state group to state events.
-    """
-
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
-    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
-            self._background_deduplicate_state,
-        )
-        self.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME,
-            self._background_index_state,
-        )
-        self.register_background_index_update(
-            self.CURRENT_STATE_INDEX_UPDATE_NAME,
-            index_name="current_state_events_member_index",
-            table="current_state_events",
-            columns=["state_key"],
-            where_clause="type='m.room.member'",
-        )
-
-    def _have_persisted_state_group_txn(self, txn, state_group):
-        txn.execute(
-            "SELECT count(*) FROM state_groups WHERE id = ?",
-            (state_group,)
-        )
-        row = txn.fetchone()
-        return row and row[0]
-
-    def _store_mult_state_groups_txn(self, txn, events_and_contexts):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
 
-            if context.current_state_ids is None:
+        Returns:
+            Deferred[int]: The state group ID
+        """
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
                 # AFAIK, this can never happen
-                logger.error(
-                    "Non-outlier event %s had current_state_ids==None",
-                    event.event_id)
-                continue
+                raise Exception("current_state_ids cannot be None")
 
-            # if the event was rejected, just give it the same state as its
-            # predecessor.
-            if context.rejected:
-                state_groups[event.event_id] = context.prev_group
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-            if self._have_persisted_state_group_txn(txn, context.state_group):
-                continue
+            state_group = self.database_engine.get_next_state_group_id(txn)
 
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
-                    "id": context.state_group,
-                    "room_id": event.room_id,
-                    "event_id": event.event_id,
+                    "id": state_group,
+                    "room_id": room_id,
+                    "event_id": event_id,
                 },
             )
 
             # We persist as a delta if we can, while also ensuring the chain
             # of deltas isn't tooo long, as otherwise read performance degrades.
-            if context.prev_group:
+            if prev_group:
                 is_in_db = self._simple_select_one_onecol_txn(
                     txn,
                     table="state_groups",
-                    keyvalues={"id": context.prev_group},
+                    keyvalues={"id": prev_group},
                     retcol="id",
                     allow_none=True,
                 )
                 if not is_in_db:
                     raise Exception(
                         "Trying to persist state with unpersisted prev_group: %r"
-                        % (context.prev_group,)
+                        % (prev_group,)
                     )
 
                 potential_hops = self._count_state_group_hops_txn(
-                    txn, context.prev_group
+                    txn, prev_group
                 )
-            if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                 self._simple_insert_txn(
                     txn,
                     table="state_group_edges",
                     values={
-                        "state_group": context.state_group,
-                        "prev_state_group": context.prev_group,
+                        "state_group": state_group,
+                        "prev_state_group": prev_group,
                     },
                 )
 
@@ -667,13 +614,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
                     table="state_groups_state",
                     values=[
                         {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
+                            "state_group": state_group,
+                            "room_id": room_id,
                             "type": key[0],
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.delta_ids.iteritems()
+                        for key, state_id in delta_ids.iteritems()
                     ],
                 )
             else:
@@ -682,13 +629,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
                     table="state_groups_state",
                     values=[
                         {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
+                            "state_group": state_group,
+                            "room_id": room_id,
                             "type": key[0],
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.current_state_ids.iteritems()
+                        for key, state_id in current_state_ids.iteritems()
                     ],
                 )
 
@@ -699,11 +646,71 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
             txn.call_after(
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
-                key=context.state_group,
-                value=dict(context.current_state_ids),
+                key=state_group,
+                value=dict(current_state_ids),
                 full=True,
             )
 
+            return state_group
+
+        return self.runInteraction("store_state_group", _store_state_group_txn)
+
+
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+    """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is a assigned
+    a state group (identified by an arbitrary string), which references a
+    collection of state events. The current state of an event is then the
+    collection of state events referenced by the event's state group.
+
+    Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g., if we get a message event
+    with only one parent it inherits the state group from its parent.)
+
+    There are three tables:
+      * `state_groups`: Stores group name, first event with in the group and
+        room id.
+      * `event_to_state_groups`: Maps events to state groups.
+      * `state_groups_state`: Maps state group to state events.
+    """
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+    def __init__(self, db_conn, hs):
+        super(StateStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+            self._background_deduplicate_state,
+        )
+        self.register_background_update_handler(
+            self.STATE_GROUP_INDEX_UPDATE_NAME,
+            self._background_index_state,
+        )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
+
+    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
+                continue
+
+            state_groups[event.event_id] = context.state_group
+
         self._simple_insert_many_txn(
             txn,
             table="event_to_state_groups",
@@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
 
             return count
 
-    def get_next_state_group(self):
-        return self._state_groups_id_gen.get_next()
-
     @defer.inlineCallbacks
     def _background_deduplicate_state(self, progress, batch_size):
         """This background update will slowly deduplicate state by reencoding
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index af65bfe7b8..bf3a66eae4 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -75,6 +75,7 @@ class Cache(object):
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
+            evicted_callback=self._on_evicted,
         )
 
         self.name = name
@@ -83,6 +84,9 @@ class Cache(object):
         self.thread = None
         self.metrics = register_cache(name, self.cache)
 
+    def _on_evicted(self, evicted_count):
+        self.metrics.inc_evictions(evicted_count)
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6ad53a6390..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
         while self._max_len and len(self) > self._max_len:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self._size_estimate -= len(value.value)
+                removed_len = len(value.value)
+                self.metrics.inc_evictions(removed_len)
+                self._size_estimate -= removed_len
+            else:
+                self.metrics.inc_evictions()
 
     def __getitem__(self, key):
         try:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..f088dd430e 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+                 evicted_callback=None):
+        """
+        Args:
+            max_size (int):
+
+            keylen (int):
+
+            cache_type (type):
+                type of underlying cache to be used. Typically one of dict
+                or TreeCache.
+
+            size_callback (func(V) -> int | None):
+
+            evicted_callback (func(int)|None):
+                if not None, called on eviction with the size of the evicted
+                entry
+        """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
         def evict():
             while cache_len() > max_size:
                 todelete = list_root.prev_node
-                delete_node(todelete)
+                evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
+                if evicted_callback:
+                    evicted_callback(evicted_len)
 
         def synchronized(f):
             @wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            deleted_len = 1
             if size_callback:
-                cached_cache_len[0] -= size_callback(node.value)
+                deleted_len = size_callback(node.value)
+                cached_cache_len[0] -= deleted_len
 
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()
+            return deleted_len
 
         @synchronized
         def cache_get(key, default=None, callbacks=[]):
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index f85455a5af..39bde6e3f8 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -141,6 +141,7 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 0',
             'cache:total{name="cache_name"} 0',
             'cache:size{name="cache_name"} 0',
+            'cache:evicted_size{name="cache_name"} 0',
         ])
 
         metric.inc_misses()
@@ -150,6 +151,7 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 0',
             'cache:total{name="cache_name"} 1',
             'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 0',
         ])
 
         metric.inc_hits()
@@ -158,4 +160,14 @@ class CacheMetricTestCase(unittest.TestCase):
             'cache:hits{name="cache_name"} 1',
             'cache:total{name="cache_name"} 2',
             'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 0',
+        ])
+
+        metric.inc_evictions(2)
+
+        self.assertEquals(metric.render(), [
+            'cache:hits{name="cache_name"} 1',
+            'cache:total{name="cache_name"} 2',
+            'cache:size{name="cache_name"} 1',
+            'cache:evicted_size{name="cache_name"} 2',
         ])
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 105e1228bb..f430cce931 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -226,11 +226,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             context = EventContext()
             context.current_state_ids = state_ids
             context.prev_state_ids = state_ids
-        elif not backfill:
+        else:
             state_handler = self.hs.get_state_handler()
             context = yield state_handler.compute_event_context(event)
-        else:
-            context = EventContext()
 
         context.push_actions = push_actions
 
diff --git a/tests/test_state.py b/tests/test_state.py
index d16e1b3b8b..a5c5e55951 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -80,14 +80,14 @@ class StateGroupStore(object):
 
         return defer.succeed(groups)
 
-    def store_state_groups(self, event, context):
-        if context.current_state_ids is None:
-            return
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        state_group = self._next_group
+        self._next_group += 1
 
-        state_events = dict(context.current_state_ids)
+        self._group_to_state[state_group] = dict(current_state_ids)
 
-        self._group_to_state[context.state_group] = state_events
-        self._event_to_state_group[event.event_id] = context.state_group
+        return state_group
 
     def get_events(self, event_ids, **kwargs):
         return {
@@ -95,10 +95,19 @@ class StateGroupStore(object):
             if e_id in self._event_id_to_event
         }
 
+    def get_state_group_delta(self, name):
+        return (None, None)
+
     def register_events(self, events):
         for e in events:
             self._event_id_to_event[e.event_id] = e
 
+    def register_event_context(self, event, context):
+        self._event_to_state_group[event.event_id] = context.state_group
+
+    def register_event_id_state_group(self, event_id, state_group):
+        self._event_to_state_group[event_id] = state_group
+
 
 class DictObj(dict):
     def __init__(self, **kwargs):
@@ -137,15 +146,7 @@ class Graph(object):
 
 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.store = Mock(
-            spec_set=[
-                "get_state_groups_ids",
-                "add_event_hashes",
-                "get_events",
-                "get_next_state_group",
-                "get_state_group_delta",
-            ]
-        )
+        self.store = StateGroupStore()
         hs = Mock(spec_set=[
             "get_datastore", "get_auth", "get_state_handler", "get_clock",
             "get_state_resolution_handler",
@@ -156,9 +157,6 @@ class StateTestCase(unittest.TestCase):
         hs.get_auth.return_value = Auth(hs)
         hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
 
-        self.store.get_next_state_group.side_effect = Mock
-        self.store.get_state_group_delta.return_value = (None, None)
-
         self.state = StateHandler(hs)
         self.event_id = 0
 
@@ -197,14 +195,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertEqual(2, len(context_store["D"].prev_state_ids))
@@ -249,16 +246,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -315,16 +309,13 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -398,16 +389,13 @@ class StateTestCase(unittest.TestCase):
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)
 
-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())
 
         context_store = {}
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         self.assertSetEqual(
@@ -467,7 +455,11 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
-        event = create_event(type="test_message", name="event")
+        prev_event_id = "prev_event_id"
+        event = create_event(
+            type="test_message", name="event2",
+            prev_events=[(prev_event_id, {})],
+        )
 
         old_state = [
             create_event(type="test1", state_key="1"),
@@ -475,11 +467,11 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = "group_name_1"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
-        }
+        group_name = self.store.store_state_group(
+            prev_event_id, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state},
+        )
+        self.store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield self.state.compute_event_context(event)
 
@@ -492,7 +484,11 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_trivial_annotate_state(self):
-        event = create_event(type="state", state_key="", name="event")
+        prev_event_id = "prev_event_id"
+        event = create_event(
+            type="state", state_key="", name="event2",
+            prev_events=[(prev_event_id, {})],
+        )
 
         old_state = [
             create_event(type="test1", state_key="1"),
@@ -500,11 +496,11 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = "group_name_1"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
-        }
+        group_name = self.store.store_state_group(
+            prev_event_id, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state},
+        )
+        self.store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield self.state.compute_event_context(event)
 
@@ -517,7 +513,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
-        event = create_event(type="test_message", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test_message", name="event3",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         creation = create_event(
             type=EventTypes.Create, state_key=""
@@ -537,12 +538,12 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        store = StateGroupStore()
-        store.register_events(old_state_1)
-        store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.store.register_events(old_state_1)
+        self.store.register_events(old_state_2)
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(len(context.current_state_ids), 6)
 
@@ -550,7 +551,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
-        event = create_event(type="test4", state_key="", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test4", state_key="", name="event",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         creation = create_event(
             type=EventTypes.Create, state_key=""
@@ -575,7 +581,9 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_2)
         self.store.get_events = store.get_events
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(len(context.current_state_ids), 6)
 
@@ -583,7 +591,12 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_standard_depth_conflict(self):
-        event = create_event(type="test4", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test4", name="event",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )
 
         member_event = create_event(
             type=EventTypes.Member,
@@ -615,7 +628,9 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_2)
         self.store.get_events = store.get_events
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(
             old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
@@ -639,19 +654,26 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_1)
         store.register_events(old_state_2)
 
-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )
 
         self.assertEqual(
             old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
         )
 
-    def _get_context(self, event, old_state_1, old_state_2):
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
+    def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
+                     old_state_2):
+        sg1 = self.store.store_state_group(
+            prev_event_id_1, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state_1},
+        )
+        self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        self.store.get_state_groups_ids.return_value = {
-            group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
-            group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
-        }
+        sg2 = self.store.store_state_group(
+            prev_event_id_2, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state_2},
+        )
+        self.store.register_event_id_state_group(prev_event_id_2, sg2)
 
         return self.state.compute_event_context(event)