summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py21
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/data_stores/__init__.py12
-rw-r--r--synapse/storage/data_stores/main/__init__.py9
-rw-r--r--synapse/storage/data_stores/main/devices.py109
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py24
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py2
-rw-r--r--synapse/storage/data_stores/main/events.py738
-rw-r--r--synapse/storage/data_stores/main/events_bg_updates.py2
-rw-r--r--synapse/storage/data_stores/main/group_server.py4
-rw-r--r--synapse/storage/data_stores/main/monthly_active_users.py2
-rw-r--r--synapse/storage/data_stores/main/push_rule.py2
-rw-r--r--synapse/storage/data_stores/main/pusher.py2
-rw-r--r--synapse/storage/data_stores/main/registration.py2
-rw-r--r--synapse/storage/data_stores/main/roommember.py2
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite42
-rw-r--r--synapse/storage/data_stores/main/search.py4
-rw-r--r--synapse/storage/data_stores/main/state.py20
-rw-r--r--synapse/storage/data_stores/main/stats.py4
-rw-r--r--synapse/storage/persist_events.py649
-rw-r--r--synapse/storage/state.py233
-rw-r--r--synapse/storage/util/id_generators.py2
22 files changed, 1188 insertions, 703 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a249ecd219..0a1a8cc1e5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,7 +27,26 @@ data stores associated with them (e.g. the schema version tables), which are
 stored in `synapse.storage.schema`.
 """
 
-from synapse.storage.data_stores.main import DataStore  # noqa: F401
+from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main import DataStore
+from synapse.storage.persist_events import EventsPersistenceStorage
+from synapse.storage.state import StateGroupStorage
+
+__all__ = ["DataStores", "DataStore"]
+
+
+class Storage(object):
+    """The high level interfaces for talking to various storage layers.
+    """
+
+    def __init__(self, hs, stores: DataStores):
+        # We include the main data store here mainly so that we don't have to
+        # rewrite all the existing code to split it into high vs low level
+        # interfaces.
+        self.main = stores.main
+
+        self.persistence = EventsPersistenceStorage(hs, stores)
+        self.state = StateGroupStorage(hs, stores)
 
 
 def are_all_users_on_domain(txn, database_engine, domain):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f5906fcd54..1a2b7ebe25 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -494,7 +494,7 @@ class SQLBaseStore(object):
         exception_callbacks = []
 
         if LoggingContext.current_context() == LoggingContext.sentinel:
-            logger.warn("Starting db txn '%s' from sentinel context", desc)
+            logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
             result = yield self.runWithConnection(
@@ -532,7 +532,7 @@ class SQLBaseStore(object):
         """
         parent_context = LoggingContext.current_context()
         if parent_context == LoggingContext.sentinel:
-            logger.warn(
+            logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )
             parent_context = None
@@ -719,7 +719,7 @@ class SQLBaseStore(object):
                     raise
 
                 # presumably we raced with another transaction: let's retry.
-                logger.warn(
+                logger.warning(
                     "IntegrityError when upserting into %s; retrying: %s", table, e
                 )
 
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index 56094078ed..cb184a98cc 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -12,3 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
+class DataStores(object):
+    """The various data stores.
+
+    These are low level interfaces to physical databases.
+    """
+
+    def __init__(self, main_store, db_conn, hs):
+        # Note we pass in the main store here as workers use a different main
+        # store.
+        self.main = main_store
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index b185ba0b3e..10c940df1e 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -139,7 +139,10 @@ class DataStore(
             db_conn, "public_room_list_stream", "stream_id"
         )
         self._device_list_id_gen = StreamIdGenerator(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[("user_signature_stream", "stream_id")],
         )
         self._cross_signing_id_gen = StreamIdGenerator(
             db_conn, "e2e_cross_signing_keys", "stream_id"
@@ -317,7 +320,7 @@ class DataStore(
             ) u
         """
         txn.execute(sql, (time_from,))
-        count, = txn.fetchone()
+        (count,) = txn.fetchone()
         return count
 
     def count_r30_users(self):
@@ -396,7 +399,7 @@ class DataStore(
 
             txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
 
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             results["all"] = count
 
             return results
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a3542348..71f62036c0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
     make_in_list_sql_clause,
 )
 from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @trace
     @defer.inlineCallbacks
-    def get_devices_by_remote(self, destination, from_stream_id, limit):
-        """Get stream of updates to send to remote servers
+    def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+        """Get a stream of device updates to send to the given remote server.
 
+        Args:
+            destination (str): The host the device updates are intended for
+            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+            limit (int): Maximum number of device updates to return
         Returns:
-            Deferred[tuple[int, list[dict]]]:
+            Deferred[tuple[int, list[tuple[string,dict]]]]:
                 current stream id (ie, the stream id of the last update included in the
-                response), and the list of updates
+                response), and the list of updates, where each update is a pair of EDU
+                type and EDU contents
         """
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
         # stream_id; the rationale being that such a large device list update
         # is likely an error.
         updates = yield self.runInteraction(
-            "get_devices_by_remote",
-            self._get_devices_by_remote_txn,
+            "get_device_updates_by_remote",
+            self._get_device_updates_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
         if not updates:
             return now_stream_id, []
 
+        # get the cross-signing keys of the users in the list, so that we can
+        # determine which of the device changes were cross-signing keys
+        users = set(r[0] for r in updates)
+        master_key_by_user = {}
+        self_signing_key_by_user = {}
+        for user in users:
+            cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                # verify_key is a VerifyKey from signedjson, which uses
+                # .version to denote the portion of the key ID after the
+                # algorithm and colon, which is the device ID
+                master_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "device_id": verify_key.version,
+                }
+
+            cross_signing_key = yield self.get_e2e_cross_signing_key(
+                user, "self_signing"
+            )
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                self_signing_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "device_id": verify_key.version,
+                }
+
         # if we have exceeded the limit, we need to exclude any results with the
         # same stream_id as the last row.
         if len(updates) > limit:
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
         # context which created the Edu.
 
         query_map = {}
-        for update in updates:
-            if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+        cross_signing_keys_by_user = {}
+        for user_id, device_id, update_stream_id, update_context in updates:
+            if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
                 # Stop processing updates
                 break
 
-            key = (update[0], update[1])
-
-            update_context = update[3]
-            update_stream_id = update[2]
+            if (
+                user_id in master_key_by_user
+                and device_id == master_key_by_user[user_id]["device_id"]
+            ):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["master_key"] = master_key_by_user[user_id]["key_info"]
+            elif (
+                user_id in self_signing_key_by_user
+                and device_id == self_signing_key_by_user[user_id]["device_id"]
+            ):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["self_signing_key"] = self_signing_key_by_user[user_id][
+                    "key_info"
+                ]
+            else:
+                key = (user_id, device_id)
 
-            previous_update_stream_id, _ = query_map.get(key, (0, None))
+                previous_update_stream_id, _ = query_map.get(key, (0, None))
 
-            if update_stream_id > previous_update_stream_id:
-                query_map[key] = (update_stream_id, update_context)
+                if update_stream_id > previous_update_stream_id:
+                    query_map[key] = (update_stream_id, update_context)
 
         # If we didn't find any updates with a stream_id lower than the cutoff, it
         # means that there are more than limit updates all of which have the same
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
         # devices, in which case E2E isn't going to work well anyway. We'll just
         # skip that stream_id and return an empty list, and continue with the next
         # stream_id next time.
-        if not query_map:
+        if not query_map and not cross_signing_keys_by_user:
             return stream_id_cutoff, []
 
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
 
+        # add the updated cross-signing keys to the results list
+        for user_id, result in iteritems(cross_signing_keys_by_user):
+            result["user_id"] = user_id
+            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            results.append(("org.matrix.signing_key_update", result))
+
         return now_stream_id, results
 
-    def _get_devices_by_remote_txn(
+    def _get_device_updates_by_remote_txn(
         self, txn, destination, from_stream_id, now_stream_id, limit
     ):
         """Return device update information for a given remote destination
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             List: List of device updates
         """
+        # get the list of device updates that need to be sent
         sql = """
             SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
             List[Dict]: List of objects representing an device update EDU
 
         """
-        devices = yield self.runInteraction(
-            "_get_e2e_device_keys_txn",
-            self._get_e2e_device_keys_txn,
-            query_map.keys(),
-            include_all_devices=True,
-            include_deleted_devices=True,
+        devices = (
+            yield self.runInteraction(
+                "_get_e2e_device_keys_txn",
+                self._get_e2e_device_keys_txn,
+                query_map.keys(),
+                include_all_devices=True,
+                include_deleted_devices=True,
+            )
+            if query_map
+            else {}
         )
 
         results = []
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 else:
                     result["deleted"] = True
 
-                results.append(result)
+                results.append(("m.device_list_update", result))
 
         return results
 
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index a0bc6f2d18..073412a78d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             from_user_id,
         )
 
+    def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+        """Return a list of changes from the user signature stream to notify remotes.
+        Note that the user signature stream represents when a user signs their
+        device with their user-signing key, which is not published to other
+        users or servers, so no `destination` is needed in the returned
+        list. However, this is needed to poke workers.
+
+        Args:
+            from_key (int): the stream ID to start at (exclusive)
+            to_key (int): the stream ID to end at (inclusive)
+
+        Returns:
+            Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
+        """
+        sql = """
+            SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+            FROM user_signature_stream
+            WHERE ? < stream_id AND stream_id <= ?
+            GROUP BY user_id
+        """
+        return self._execute(
+            "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+        )
+
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 22025effbc..04ce21ac66 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
         stream_row = txn.fetchone()
         if stream_row:
-            offset_stream_ordering, = stream_row
+            (offset_stream_ordering,) = stream_row
             rotate_to_stream_ordering = min(
                 self.stream_ordering_day_ago, offset_stream_ordering
             )
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 1045c7fa2e..301f8ea128 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,14 +17,14 @@
 
 import itertools
 import logging
-from collections import Counter as c_counter, OrderedDict, deque, namedtuple
+from collections import Counter as c_counter, OrderedDict, namedtuple
 from functools import wraps
 
 from six import iteritems, text_type
 from six.moves import range
 
 from canonicaljson import json
-from prometheus_client import Counter, Histogram
+from prometheus_client import Counter
 
 from twisted.internet import defer
 
@@ -34,11 +34,9 @@ from synapse.api.errors import SynapseError
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.events.utils import prune_event_dict
-from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.logging.utils import log_function
 from synapse.metrics import BucketCollector
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
 from synapse.storage._base import make_in_list_sql_clause
 from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.event_federation import EventFederationStore
@@ -46,10 +44,8 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util import batch_iter
-from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
 
@@ -60,37 +56,6 @@ event_counter = Counter(
     ["type", "origin_type", "origin_entity"],
 )
 
-# The number of times we are recalculating the current state
-state_delta_counter = Counter("synapse_storage_events_state_delta", "")
-
-# The number of times we are recalculating state when there is only a
-# single forward extremity
-state_delta_single_event_counter = Counter(
-    "synapse_storage_events_state_delta_single_event", ""
-)
-
-# The number of times we are reculating state when we could have resonably
-# calculated the delta when we calculated the state for an event we were
-# persisting.
-state_delta_reuse_delta_counter = Counter(
-    "synapse_storage_events_state_delta_reuse_delta", ""
-)
-
-# The number of forward extremities for each new event.
-forward_extremities_counter = Histogram(
-    "synapse_storage_events_forward_extremities_persisted",
-    "Number of forward extremities for each new event",
-    buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
-)
-
-# The number of stale forward extremities for each new event. Stale extremities
-# are those that were in the previous set of extremities as well as the new.
-stale_forward_extremities_counter = Histogram(
-    "synapse_storage_events_stale_forward_extremities_persisted",
-    "Number of unchanged forward extremities for each new event",
-    buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
-)
-
 
 def encode_json(json_object):
     """
@@ -102,110 +67,6 @@ def encode_json(json_object):
     return out
 
 
-class _EventPeristenceQueue(object):
-    """Queues up events so that they can be persisted in bulk with only one
-    concurrent transaction per room.
-    """
-
-    _EventPersistQueueItem = namedtuple(
-        "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
-    )
-
-    def __init__(self):
-        self._event_persist_queues = {}
-        self._currently_persisting_rooms = set()
-
-    def add_to_queue(self, room_id, events_and_contexts, backfilled):
-        """Add events to the queue, with the given persist_event options.
-
-        NB: due to the normal usage pattern of this method, it does *not*
-        follow the synapse logcontext rules, and leaves the logcontext in
-        place whether or not the returned deferred is ready.
-
-        Args:
-            room_id (str):
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
-
-        Returns:
-            defer.Deferred: a deferred which will resolve once the events are
-                persisted. Runs its callbacks *without* a logcontext.
-        """
-        queue = self._event_persist_queues.setdefault(room_id, deque())
-        if queue:
-            # if the last item in the queue has the same `backfilled` setting,
-            # we can just add these new events to that item.
-            end_item = queue[-1]
-            if end_item.backfilled == backfilled:
-                end_item.events_and_contexts.extend(events_and_contexts)
-                return end_item.deferred.observe()
-
-        deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-
-        queue.append(
-            self._EventPersistQueueItem(
-                events_and_contexts=events_and_contexts,
-                backfilled=backfilled,
-                deferred=deferred,
-            )
-        )
-
-        return deferred.observe()
-
-    def handle_queue(self, room_id, per_item_callback):
-        """Attempts to handle the queue for a room if not already being handled.
-
-        The given callback will be invoked with for each item in the queue,
-        of type _EventPersistQueueItem. The per_item_callback will continuously
-        be called with new items, unless the queue becomnes empty. The return
-        value of the function will be given to the deferreds waiting on the item,
-        exceptions will be passed to the deferreds as well.
-
-        This function should therefore be called whenever anything is added
-        to the queue.
-
-        If another callback is currently handling the queue then it will not be
-        invoked.
-        """
-
-        if room_id in self._currently_persisting_rooms:
-            return
-
-        self._currently_persisting_rooms.add(room_id)
-
-        @defer.inlineCallbacks
-        def handle_queue_loop():
-            try:
-                queue = self._get_drainining_queue(room_id)
-                for item in queue:
-                    try:
-                        ret = yield per_item_callback(item)
-                    except Exception:
-                        with PreserveLoggingContext():
-                            item.deferred.errback()
-                    else:
-                        with PreserveLoggingContext():
-                            item.deferred.callback(ret)
-            finally:
-                queue = self._event_persist_queues.pop(room_id, None)
-                if queue:
-                    self._event_persist_queues[room_id] = queue
-                self._currently_persisting_rooms.discard(room_id)
-
-        # set handle_queue_loop off in the background
-        run_as_background_process("persist_events", handle_queue_loop)
-
-    def _get_drainining_queue(self, room_id):
-        queue = self._event_persist_queues.setdefault(room_id, deque())
-
-        try:
-            while True:
-                yield queue.popleft()
-        except IndexError:
-            # Queue has been drained.
-            pass
-
-
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
@@ -221,7 +82,7 @@ def _retry_on_integrity_error(func):
     @defer.inlineCallbacks
     def f(self, *args, **kwargs):
         try:
-            res = yield func(self, *args, **kwargs)
+            res = yield func(self, *args, delete_existing=False, **kwargs)
         except self.database_engine.module.IntegrityError:
             logger.exception("IntegrityError, retrying.")
             res = yield func(self, *args, delete_existing=True, **kwargs)
@@ -241,9 +102,6 @@ class EventsStore(
     def __init__(self, db_conn, hs):
         super(EventsStore, self).__init__(db_conn, hs)
 
-        self._event_persist_queue = _EventPeristenceQueue()
-        self._state_resolution_handler = hs.get_state_resolution_handler()
-
         # Collect metrics on the number of forward extremities that exist.
         # Counter of number of extremities to count
         self._current_forward_extremities_amount = c_counter()
@@ -286,340 +144,106 @@ class EventsStore(
         res = yield self.runInteraction("read_forward_extremities", fetch)
         self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
 
-    @defer.inlineCallbacks
-    def persist_events(self, events_and_contexts, backfilled=False):
-        """
-        Write events to the database
-        Args:
-            events_and_contexts: list of tuples of (event, context)
-            backfilled (bool): Whether the results are retrieved from federation
-                via backfill or not. Used to determine if they're "new" events
-                which might update the current state etc.
-
-        Returns:
-            Deferred[int]: the stream ordering of the latest persisted event
-        """
-        partitioned = {}
-        for event, ctx in events_and_contexts:
-            partitioned.setdefault(event.room_id, []).append((event, ctx))
-
-        deferreds = []
-        for room_id, evs_ctxs in iteritems(partitioned):
-            d = self._event_persist_queue.add_to_queue(
-                room_id, evs_ctxs, backfilled=backfilled
-            )
-            deferreds.append(d)
-
-        for room_id in partitioned:
-            self._maybe_start_persisting(room_id)
-
-        yield make_deferred_yieldable(
-            defer.gatherResults(deferreds, consumeErrors=True)
-        )
-
-        max_persisted_id = yield self._stream_id_gen.get_current_token()
-
-        return max_persisted_id
-
-    @defer.inlineCallbacks
-    @log_function
-    def persist_event(self, event, context, backfilled=False):
-        """
-
-        Args:
-            event (EventBase):
-            context (EventContext):
-            backfilled (bool):
-
-        Returns:
-            Deferred: resolves to (int, int): the stream ordering of ``event``,
-            and the stream ordering of the latest persisted event
-        """
-        deferred = self._event_persist_queue.add_to_queue(
-            event.room_id, [(event, context)], backfilled=backfilled
-        )
-
-        self._maybe_start_persisting(event.room_id)
-
-        yield make_deferred_yieldable(deferred)
-
-        max_persisted_id = yield self._stream_id_gen.get_current_token()
-        return (event.internal_metadata.stream_ordering, max_persisted_id)
-
-    def _maybe_start_persisting(self, room_id):
-        @defer.inlineCallbacks
-        def persisting_queue(item):
-            with Measure(self._clock, "persist_events"):
-                yield self._persist_events(
-                    item.events_and_contexts, backfilled=item.backfilled
-                )
-
-        self._event_persist_queue.handle_queue(room_id, persisting_queue)
-
     @_retry_on_integrity_error
     @defer.inlineCallbacks
-    def _persist_events(
-        self, events_and_contexts, backfilled=False, delete_existing=False
+    def _persist_events_and_state_updates(
+        self,
+        events_and_contexts,
+        current_state_for_room,
+        state_delta_for_room,
+        new_forward_extremeties,
+        backfilled=False,
+        delete_existing=False,
     ):
-        """Persist events to db
+        """Persist a set of events alongside updates to the current state and
+        forward extremities tables.
 
         Args:
             events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
+            current_state_for_room (dict[str, dict]): Map from room_id to the
+                current state of the room based on forward extremities
+            state_delta_for_room (dict[str, tuple]): Map from room_id to tuple
+                of `(to_delete, to_insert)` where to_delete is a list
+                of type/state keys to remove from current state, and to_insert
+                is a map (type,key)->event_id giving the state delta in each
+                room.
+            new_forward_extremities (dict[str, list[str]]): Map from room_id
+                to list of event IDs that are the new forward extremities of
+                the room.
+            backfilled (bool)
             delete_existing (bool):
 
         Returns:
             Deferred: resolves when the events have been persisted
         """
-        if not events_and_contexts:
-            return
 
-        chunks = [
-            events_and_contexts[x : x + 100]
-            for x in range(0, len(events_and_contexts), 100)
-        ]
-
-        for chunk in chunks:
-            # We can't easily parallelize these since different chunks
-            # might contain the same event. :(
-
-            # NB: Assumes that we are only persisting events for one room
-            # at a time.
-
-            # map room_id->list[event_ids] giving the new forward
-            # extremities in each room
-            new_forward_extremeties = {}
+        # We want to calculate the stream orderings as late as possible, as
+        # we only notify after all events with a lesser stream ordering have
+        # been persisted. I.e. if we spend 10s inside the with block then
+        # that will delay all subsequent events from being notified about.
+        # Hence why we do it down here rather than wrapping the entire
+        # function.
+        #
+        # Its safe to do this after calculating the state deltas etc as we
+        # only need to protect the *persistence* of the events. This is to
+        # ensure that queries of the form "fetch events since X" don't
+        # return events and stream positions after events that are still in
+        # flight, as otherwise subsequent requests "fetch event since Y"
+        # will not return those events.
+        #
+        # Note: Multiple instances of this function cannot be in flight at
+        # the same time for the same room.
+        if backfilled:
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                len(events_and_contexts)
+            )
+        else:
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
+                len(events_and_contexts)
+            )
 
-            # map room_id->(type,state_key)->event_id tracking the full
-            # state in each room after adding these events.
-            # This is simply used to prefill the get_current_state_ids
-            # cache
-            current_state_for_room = {}
+        with stream_ordering_manager as stream_orderings:
+            for (event, context), stream in zip(events_and_contexts, stream_orderings):
+                event.internal_metadata.stream_ordering = stream
 
-            # map room_id->(to_delete, to_insert) where to_delete is a list
-            # of type/state keys to remove from current state, and to_insert
-            # is a map (type,key)->event_id giving the state delta in each
-            # room
-            state_delta_for_room = {}
+            yield self.runInteraction(
+                "persist_events",
+                self._persist_events_txn,
+                events_and_contexts=events_and_contexts,
+                backfilled=backfilled,
+                delete_existing=delete_existing,
+                state_delta_for_room=state_delta_for_room,
+                new_forward_extremeties=new_forward_extremeties,
+            )
+            persist_event_counter.inc(len(events_and_contexts))
 
             if not backfilled:
-                with Measure(self._clock, "_calculate_state_and_extrem"):
-                    # Work out the new "current state" for each room.
-                    # We do this by working out what the new extremities are and then
-                    # calculating the state from that.
-                    events_by_room = {}
-                    for event, context in chunk:
-                        events_by_room.setdefault(event.room_id, []).append(
-                            (event, context)
-                        )
-
-                    for room_id, ev_ctx_rm in iteritems(events_by_room):
-                        latest_event_ids = yield self.get_latest_event_ids_in_room(
-                            room_id
-                        )
-                        new_latest_event_ids = yield self._calculate_new_extremities(
-                            room_id, ev_ctx_rm, latest_event_ids
-                        )
-
-                        latest_event_ids = set(latest_event_ids)
-                        if new_latest_event_ids == latest_event_ids:
-                            # No change in extremities, so no change in state
-                            continue
-
-                        # there should always be at least one forward extremity.
-                        # (except during the initial persistence of the send_join
-                        # results, in which case there will be no existing
-                        # extremities, so we'll `continue` above and skip this bit.)
-                        assert new_latest_event_ids, "No forward extremities left!"
-
-                        new_forward_extremeties[room_id] = new_latest_event_ids
-
-                        len_1 = (
-                            len(latest_event_ids) == 1
-                            and len(new_latest_event_ids) == 1
-                        )
-                        if len_1:
-                            all_single_prev_not_state = all(
-                                len(event.prev_event_ids()) == 1
-                                and not event.is_state()
-                                for event, ctx in ev_ctx_rm
-                            )
-                            # Don't bother calculating state if they're just
-                            # a long chain of single ancestor non-state events.
-                            if all_single_prev_not_state:
-                                continue
-
-                        state_delta_counter.inc()
-                        if len(new_latest_event_ids) == 1:
-                            state_delta_single_event_counter.inc()
-
-                            # This is a fairly handwavey check to see if we could
-                            # have guessed what the delta would have been when
-                            # processing one of these events.
-                            # What we're interested in is if the latest extremities
-                            # were the same when we created the event as they are
-                            # now. When this server creates a new event (as opposed
-                            # to receiving it over federation) it will use the
-                            # forward extremities as the prev_events, so we can
-                            # guess this by looking at the prev_events and checking
-                            # if they match the current forward extremities.
-                            for ev, _ in ev_ctx_rm:
-                                prev_event_ids = set(ev.prev_event_ids())
-                                if latest_event_ids == prev_event_ids:
-                                    state_delta_reuse_delta_counter.inc()
-                                    break
-
-                        logger.info("Calculating state delta for room %s", room_id)
-                        with Measure(
-                            self._clock, "persist_events.get_new_state_after_events"
-                        ):
-                            res = yield self._get_new_state_after_events(
-                                room_id,
-                                ev_ctx_rm,
-                                latest_event_ids,
-                                new_latest_event_ids,
-                            )
-                            current_state, delta_ids = res
-
-                        # If either are not None then there has been a change,
-                        # and we need to work out the delta (or use that
-                        # given)
-                        if delta_ids is not None:
-                            # If there is a delta we know that we've
-                            # only added or replaced state, never
-                            # removed keys entirely.
-                            state_delta_for_room[room_id] = ([], delta_ids)
-                        elif current_state is not None:
-                            with Measure(
-                                self._clock, "persist_events.calculate_state_delta"
-                            ):
-                                delta = yield self._calculate_state_delta(
-                                    room_id, current_state
-                                )
-                            state_delta_for_room[room_id] = delta
-
-                        # If we have the current_state then lets prefill
-                        # the cache with it.
-                        if current_state is not None:
-                            current_state_for_room[room_id] = current_state
-
-            # We want to calculate the stream orderings as late as possible, as
-            # we only notify after all events with a lesser stream ordering have
-            # been persisted. I.e. if we spend 10s inside the with block then
-            # that will delay all subsequent events from being notified about.
-            # Hence why we do it down here rather than wrapping the entire
-            # function.
-            #
-            # Its safe to do this after calculating the state deltas etc as we
-            # only need to protect the *persistence* of the events. This is to
-            # ensure that queries of the form "fetch events since X" don't
-            # return events and stream positions after events that are still in
-            # flight, as otherwise subsequent requests "fetch event since Y"
-            # will not return those events.
-            #
-            # Note: Multiple instances of this function cannot be in flight at
-            # the same time for the same room.
-            if backfilled:
-                stream_ordering_manager = self._backfill_id_gen.get_next_mult(
-                    len(chunk)
+                # backfilled events have negative stream orderings, so we don't
+                # want to set the event_persisted_position to that.
+                synapse.metrics.event_persisted_position.set(
+                    events_and_contexts[-1][0].internal_metadata.stream_ordering
                 )
-            else:
-                stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk))
-
-            with stream_ordering_manager as stream_orderings:
-                for (event, context), stream in zip(chunk, stream_orderings):
-                    event.internal_metadata.stream_ordering = stream
-
-                yield self.runInteraction(
-                    "persist_events",
-                    self._persist_events_txn,
-                    events_and_contexts=chunk,
-                    backfilled=backfilled,
-                    delete_existing=delete_existing,
-                    state_delta_for_room=state_delta_for_room,
-                    new_forward_extremeties=new_forward_extremeties,
-                )
-                persist_event_counter.inc(len(chunk))
-
-                if not backfilled:
-                    # backfilled events have negative stream orderings, so we don't
-                    # want to set the event_persisted_position to that.
-                    synapse.metrics.event_persisted_position.set(
-                        chunk[-1][0].internal_metadata.stream_ordering
-                    )
 
-                for event, context in chunk:
-                    if context.app_service:
-                        origin_type = "local"
-                        origin_entity = context.app_service.id
-                    elif self.hs.is_mine_id(event.sender):
-                        origin_type = "local"
-                        origin_entity = "*client*"
-                    else:
-                        origin_type = "remote"
-                        origin_entity = get_domain_from_id(event.sender)
-
-                    event_counter.labels(event.type, origin_type, origin_entity).inc()
-
-                for room_id, new_state in iteritems(current_state_for_room):
-                    self.get_current_state_ids.prefill((room_id,), new_state)
-
-                for room_id, latest_event_ids in iteritems(new_forward_extremeties):
-                    self.get_latest_event_ids_in_room.prefill(
-                        (room_id,), list(latest_event_ids)
-                    )
-
-    @defer.inlineCallbacks
-    def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
-        """Calculates the new forward extremities for a room given events to
-        persist.
-
-        Assumes that we are only persisting events for one room at a time.
-        """
-
-        # we're only interested in new events which aren't outliers and which aren't
-        # being rejected.
-        new_events = [
-            event
-            for event, ctx in event_contexts
-            if not event.internal_metadata.is_outlier()
-            and not ctx.rejected
-            and not event.internal_metadata.is_soft_failed()
-        ]
-
-        latest_event_ids = set(latest_event_ids)
-
-        # start with the existing forward extremities
-        result = set(latest_event_ids)
-
-        # add all the new events to the list
-        result.update(event.event_id for event in new_events)
-
-        # Now remove all events which are prev_events of any of the new events
-        result.difference_update(
-            e_id for event in new_events for e_id in event.prev_event_ids()
-        )
+            for event, context in events_and_contexts:
+                if context.app_service:
+                    origin_type = "local"
+                    origin_entity = context.app_service.id
+                elif self.hs.is_mine_id(event.sender):
+                    origin_type = "local"
+                    origin_entity = "*client*"
+                else:
+                    origin_type = "remote"
+                    origin_entity = get_domain_from_id(event.sender)
 
-        # Remove any events which are prev_events of any existing events.
-        existing_prevs = yield self._get_events_which_are_prevs(result)
-        result.difference_update(existing_prevs)
+                event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-        # Finally handle the case where the new events have soft-failed prev
-        # events. If they do we need to remove them and their prev events,
-        # otherwise we end up with dangling extremities.
-        existing_prevs = yield self._get_prevs_before_rejected(
-            e_id for event in new_events for e_id in event.prev_event_ids()
-        )
-        result.difference_update(existing_prevs)
+            for room_id, new_state in iteritems(current_state_for_room):
+                self.get_current_state_ids.prefill((room_id,), new_state)
 
-        # We only update metrics for events that change forward extremities
-        # (e.g. we ignore backfill/outliers/etc)
-        if result != latest_event_ids:
-            forward_extremities_counter.observe(len(result))
-            stale = latest_event_ids & result
-            stale_forward_extremities_counter.observe(len(stale))
-
-        return result
+            for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+                self.get_latest_event_ids_in_room.prefill(
+                    (room_id,), list(latest_event_ids)
+                )
 
     @defer.inlineCallbacks
     def _get_events_which_are_prevs(self, event_ids):
@@ -725,188 +349,6 @@ class EventsStore(
 
         return existing_prevs
 
-    @defer.inlineCallbacks
-    def _get_new_state_after_events(
-        self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
-    ):
-        """Calculate the current state dict after adding some new events to
-        a room
-
-        Args:
-            room_id (str):
-                room to which the events are being added. Used for logging etc
-
-            events_context (list[(EventBase, EventContext)]):
-                events and contexts which are being added to the room
-
-            old_latest_event_ids (iterable[str]):
-                the old forward extremities for the room.
-
-            new_latest_event_ids (iterable[str]):
-                the new forward extremities for the room.
-
-        Returns:
-            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
-            Returns a tuple of two state maps, the first being the full new current
-            state and the second being the delta to the existing current state.
-            If both are None then there has been no change.
-
-            If there has been a change then we only return the delta if its
-            already been calculated. Conversely if we do know the delta then
-            the new current state is only returned if we've already calculated
-            it.
-        """
-        # map from state_group to ((type, key) -> event_id) state map
-        state_groups_map = {}
-
-        # Map from (prev state group, new state group) -> delta state dict
-        state_group_deltas = {}
-
-        for ev, ctx in events_context:
-            if ctx.state_group is None:
-                # This should only happen for outlier events.
-                if not ev.internal_metadata.is_outlier():
-                    raise Exception(
-                        "Context for new event %s has no state "
-                        "group" % (ev.event_id,)
-                    )
-                continue
-
-            if ctx.state_group in state_groups_map:
-                continue
-
-            # We're only interested in pulling out state that has already
-            # been cached in the context. We'll pull stuff out of the DB later
-            # if necessary.
-            current_state_ids = ctx.get_cached_current_state_ids()
-            if current_state_ids is not None:
-                state_groups_map[ctx.state_group] = current_state_ids
-
-            if ctx.prev_group:
-                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
-
-        # We need to map the event_ids to their state groups. First, let's
-        # check if the event is one we're persisting, in which case we can
-        # pull the state group from its context.
-        # Otherwise we need to pull the state group from the database.
-
-        # Set of events we need to fetch groups for. (We know none of the old
-        # extremities are going to be in events_context).
-        missing_event_ids = set(old_latest_event_ids)
-
-        event_id_to_state_group = {}
-        for event_id in new_latest_event_ids:
-            # First search in the list of new events we're adding.
-            for ev, ctx in events_context:
-                if event_id == ev.event_id and ctx.state_group is not None:
-                    event_id_to_state_group[event_id] = ctx.state_group
-                    break
-            else:
-                # If we couldn't find it, then we'll need to pull
-                # the state from the database
-                missing_event_ids.add(event_id)
-
-        if missing_event_ids:
-            # Now pull out the state groups for any missing events from DB
-            event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
-            event_id_to_state_group.update(event_to_groups)
-
-        # State groups of old_latest_event_ids
-        old_state_groups = set(
-            event_id_to_state_group[evid] for evid in old_latest_event_ids
-        )
-
-        # State groups of new_latest_event_ids
-        new_state_groups = set(
-            event_id_to_state_group[evid] for evid in new_latest_event_ids
-        )
-
-        # If they old and new groups are the same then we don't need to do
-        # anything.
-        if old_state_groups == new_state_groups:
-            return None, None
-
-        if len(new_state_groups) == 1 and len(old_state_groups) == 1:
-            # If we're going from one state group to another, lets check if
-            # we have a delta for that transition. If we do then we can just
-            # return that.
-
-            new_state_group = next(iter(new_state_groups))
-            old_state_group = next(iter(old_state_groups))
-
-            delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
-            if delta_ids is not None:
-                # We have a delta from the existing to new current state,
-                # so lets just return that. If we happen to already have
-                # the current state in memory then lets also return that,
-                # but it doesn't matter if we don't.
-                new_state = state_groups_map.get(new_state_group)
-                return new_state, delta_ids
-
-        # Now that we have calculated new_state_groups we need to get
-        # their state IDs so we can resolve to a single state set.
-        missing_state = new_state_groups - set(state_groups_map)
-        if missing_state:
-            group_to_state = yield self._get_state_for_groups(missing_state)
-            state_groups_map.update(group_to_state)
-
-        if len(new_state_groups) == 1:
-            # If there is only one state group, then we know what the current
-            # state is.
-            return state_groups_map[new_state_groups.pop()], None
-
-        # Ok, we need to defer to the state handler to resolve our state sets.
-
-        state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
-
-        events_map = {ev.event_id: ev for ev, _ in events_context}
-
-        # We need to get the room version, which is in the create event.
-        # Normally that'd be in the database, but its also possible that we're
-        # currently trying to persist it.
-        room_version = None
-        for ev, _ in events_context:
-            if ev.type == EventTypes.Create and ev.state_key == "":
-                room_version = ev.content.get("room_version", "1")
-                break
-
-        if not room_version:
-            room_version = yield self.get_room_version(room_id)
-
-        logger.debug("calling resolve_state_groups from preserve_events")
-        res = yield self._state_resolution_handler.resolve_state_groups(
-            room_id,
-            room_version,
-            state_groups,
-            events_map,
-            state_res_store=StateResolutionStore(self),
-        )
-
-        return res.state, None
-
-    @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, current_state):
-        """Calculate the new state deltas for a room.
-
-        Assumes that we are only persisting events for one room at a time.
-
-        Returns:
-            tuple[list, dict] (to_delete, to_insert): where to_delete are the
-            type/state_keys to remove from current_state_events and `to_insert`
-            are the updates to current_state_events.
-        """
-        existing_state = yield self.get_current_state_ids(room_id)
-
-        to_delete = [key for key in existing_state if key not in current_state]
-
-        to_insert = {
-            key: ev_id
-            for key, ev_id in iteritems(current_state)
-            if ev_id != existing_state.get(key)
-        }
-
-        return to_delete, to_insert
-
     @log_function
     def _persist_events_txn(
         self,
@@ -1690,7 +1132,7 @@ class EventsStore(
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             return count
 
         ret = yield self.runInteraction("count_messages", _count_messages)
@@ -1711,7 +1153,7 @@ class EventsStore(
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             return count
 
         ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
@@ -1726,7 +1168,7 @@ class EventsStore(
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             return count
 
         ret = yield self.runInteraction("count_daily_active_rooms", _count)
@@ -2211,7 +1653,7 @@ class EventsStore(
         """,
             (room_id,),
         )
-        min_depth, = txn.fetchone()
+        (min_depth,) = txn.fetchone()
 
         logger.info("[purge] updating room_depth to %d", min_depth)
 
@@ -2403,7 +1845,6 @@ class EventsStore(
             "room_stats_earliest_token",
             "rooms",
             "stream_ordering_to_exterm",
-            "topics",
             "users_in_public_rooms",
             "users_who_share_private_rooms",
             # no useful index, but let's clear them anyway
@@ -2446,12 +1887,11 @@ class EventsStore(
 
         logger.info("[purge] done")
 
-    @defer.inlineCallbacks
-    def is_event_after(self, event_id1, event_id2):
+    async def is_event_after(self, event_id1, event_id2):
         """Returns True if event_id1 is after event_id2 in the stream
         """
-        to_1, so_1 = yield self._get_event_ordering(event_id1)
-        to_2, so_2 = yield self._get_event_ordering(event_id2)
+        to_1, so_1 = await self._get_event_ordering(event_id1)
+        to_2, so_2 = await self._get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
     @cachedInlineCallbacks(max_entries=5000)
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 31ea6f917f..51352b9966 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -438,7 +438,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
             if not rows:
                 return 0
 
-            upper_event_id, = rows[-1]
+            (upper_event_id,) = rows[-1]
 
             # Update the redactions with the received_ts.
             #
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index aeae5a2b28..b3a2771f1b 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore):
                 WHERE group_id = ? AND category_id = ?
             """
             txn.execute(sql, (group_id, category_id))
-            order, = txn.fetchone()
+            (order,) = txn.fetchone()
 
         if existing:
             to_update = {}
@@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore):
                 WHERE group_id = ? AND role_id = ?
             """
             txn.execute(sql, (group_id, role_id))
-            order, = txn.fetchone()
+            (order,) = txn.fetchone()
 
         if existing:
             to_update = {}
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index e6ee1e4aaa..b41c3d317a 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
 
             txn.execute(sql)
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             return count
 
         return self.runInteraction("count_users", _count_users)
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index cd95f1ce60..b520062d84 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -143,7 +143,7 @@ class PushRulesWorkerStore(
                     " WHERE user_id = ? AND ? < stream_id"
                 )
                 txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
+                (count,) = txn.fetchone()
                 return bool(count)
 
             return self.runInteraction(
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index f005c1ae0a..d76861cdc0 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -44,7 +44,7 @@ class PusherWorkerStore(SQLBaseStore):
 
                 r["data"] = json.loads(dataJson)
             except Exception as e:
-                logger.warn(
+                logger.warning(
                     "Invalid JSON in data for pusher %d: %s, %s",
                     r["id"],
                     dataJson,
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 6c5b29288a..f70d41ecab 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 WHERE appservice_id IS NULL
             """
             )
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             return count
 
         ret = yield self.runInteraction("count_users", _count_users)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index bc04bfd7d4..2af24a20b7 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
                 if not row or not row[0]:
                     return processed, True
 
-                next_room, = row
+                (next_room,) = row
 
                 sql = """
                     UPDATE current_state_events
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
new file mode 100644
index 0000000000..e8b1fd35d8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
@@ -0,0 +1,42 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Change the hidden column from a default value of FALSE to a default value of
+ * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
+ * string 'FALSE', which is truthy.
+ *
+ * Since sqlite doesn't allow us to just change the default value, we have to
+ * recreate the table, copy the data, fix the rows that have incorrect data, and
+ * replace the old table with the new table.
+ */
+
+CREATE TABLE IF NOT EXISTS devices2 (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    display_name TEXT,
+    last_seen BIGINT,
+    ip TEXT,
+    user_agent TEXT,
+    hidden BOOLEAN DEFAULT 0,
+    CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
+
+INSERT INTO devices2 SELECT * FROM devices;
+
+UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
+
+DROP TABLE devices;
+
+ALTER TABLE devices2 RENAME TO devices;
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 0e08497452..d1d7c6863d 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -196,7 +196,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                         " ON event_search USING GIN (vector)"
                     )
                 except psycopg2.ProgrammingError as e:
-                    logger.warn(
+                    logger.warning(
                         "Ignoring error %r when trying to switch from GIST to GIN", e
                     )
 
@@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore):
                     )
                 )
                 txn.execute(query, (value, search_query))
-                headline, = txn.fetchall()[0]
+                (headline,) = txn.fetchall()[0]
 
                 # Now we need to pick the possible highlights out of the haedline
                 # result.
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 9b2207075b..3132848034 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -725,16 +725,18 @@ class StateGroupWorkerStore(
         member_filter, non_member_filter = state_filter.get_member_split()
 
         # Now we look them up in the member and non-member caches
-        non_member_state, incomplete_groups_nm, = (
-            yield self._get_state_for_groups_using_cache(
-                groups, self._state_group_cache, state_filter=non_member_filter
-            )
+        (
+            non_member_state,
+            incomplete_groups_nm,
+        ) = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_cache, state_filter=non_member_filter
         )
 
-        member_state, incomplete_groups_m, = (
-            yield self._get_state_for_groups_using_cache(
-                groups, self._state_group_members_cache, state_filter=member_filter
-            )
+        (
+            member_state,
+            incomplete_groups_m,
+        ) = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_members_cache, state_filter=member_filter
         )
 
         state = dict(non_member_state)
@@ -1076,7 +1078,7 @@ class StateBackgroundUpdateStore(
                     " WHERE id < ? AND room_id = ?",
                     (state_group, room_id),
                 )
-                prev_group, = txn.fetchone()
+                (prev_group,) = txn.fetchone()
                 new_last_state_group = state_group
 
                 if prev_group:
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 4d59b7833f..45b3de7d56 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore):
                 (room_id,),
             )
 
-            current_state_events_count, = txn.fetchone()
+            (current_state_events_count,) = txn.fetchone()
 
             users_in_room = self.get_users_in_room_txn(txn, room_id)
 
@@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore):
                 """,
                 (user_id,),
             )
-            count, = txn.fetchone()
+            (count,) = txn.fetchone()
             return count, pos
 
         joined_rooms, pos = yield self.runInteraction(
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
new file mode 100644
index 0000000000..fa03ca9ff7
--- /dev/null
+++ b/synapse/storage/persist_events.py
@@ -0,0 +1,649 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import deque, namedtuple
+
+from six import iteritems
+from six.moves import range
+
+from prometheus_client import Counter, Histogram
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateResolutionStore
+from synapse.storage.data_stores import DataStores
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+    "synapse_storage_events_state_delta_single_event", ""
+)
+
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+    "synapse_storage_events_state_delta_reuse_delta", ""
+)
+
+# The number of forward extremities for each new event.
+forward_extremities_counter = Histogram(
+    "synapse_storage_events_forward_extremities_persisted",
+    "Number of forward extremities for each new event",
+    buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+# The number of stale forward extremities for each new event. Stale extremities
+# are those that were in the previous set of extremities as well as the new.
+stale_forward_extremities_counter = Histogram(
+    "synapse_storage_events_stale_forward_extremities_persisted",
+    "Number of unchanged forward extremities for each new event",
+    buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+
+class _EventPeristenceQueue(object):
+    """Queues up events so that they can be persisted in bulk with only one
+    concurrent transaction per room.
+    """
+
+    _EventPersistQueueItem = namedtuple(
+        "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+    )
+
+    def __init__(self):
+        self._event_persist_queues = {}
+        self._currently_persisting_rooms = set()
+
+    def add_to_queue(self, room_id, events_and_contexts, backfilled):
+        """Add events to the queue, with the given persist_event options.
+
+        NB: due to the normal usage pattern of this method, it does *not*
+        follow the synapse logcontext rules, and leaves the logcontext in
+        place whether or not the returned deferred is ready.
+
+        Args:
+            room_id (str):
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+
+        Returns:
+            defer.Deferred: a deferred which will resolve once the events are
+                persisted. Runs its callbacks *without* a logcontext.
+        """
+        queue = self._event_persist_queues.setdefault(room_id, deque())
+        if queue:
+            # if the last item in the queue has the same `backfilled` setting,
+            # we can just add these new events to that item.
+            end_item = queue[-1]
+            if end_item.backfilled == backfilled:
+                end_item.events_and_contexts.extend(events_and_contexts)
+                return end_item.deferred.observe()
+
+        deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+
+        queue.append(
+            self._EventPersistQueueItem(
+                events_and_contexts=events_and_contexts,
+                backfilled=backfilled,
+                deferred=deferred,
+            )
+        )
+
+        return deferred.observe()
+
+    def handle_queue(self, room_id, per_item_callback):
+        """Attempts to handle the queue for a room if not already being handled.
+
+        The given callback will be invoked with for each item in the queue,
+        of type _EventPersistQueueItem. The per_item_callback will continuously
+        be called with new items, unless the queue becomnes empty. The return
+        value of the function will be given to the deferreds waiting on the item,
+        exceptions will be passed to the deferreds as well.
+
+        This function should therefore be called whenever anything is added
+        to the queue.
+
+        If another callback is currently handling the queue then it will not be
+        invoked.
+        """
+
+        if room_id in self._currently_persisting_rooms:
+            return
+
+        self._currently_persisting_rooms.add(room_id)
+
+        @defer.inlineCallbacks
+        def handle_queue_loop():
+            try:
+                queue = self._get_drainining_queue(room_id)
+                for item in queue:
+                    try:
+                        ret = yield per_item_callback(item)
+                    except Exception:
+                        with PreserveLoggingContext():
+                            item.deferred.errback()
+                    else:
+                        with PreserveLoggingContext():
+                            item.deferred.callback(ret)
+            finally:
+                queue = self._event_persist_queues.pop(room_id, None)
+                if queue:
+                    self._event_persist_queues[room_id] = queue
+                self._currently_persisting_rooms.discard(room_id)
+
+        # set handle_queue_loop off in the background
+        run_as_background_process("persist_events", handle_queue_loop)
+
+    def _get_drainining_queue(self, room_id):
+        queue = self._event_persist_queues.setdefault(room_id, deque())
+
+        try:
+            while True:
+                yield queue.popleft()
+        except IndexError:
+            # Queue has been drained.
+            pass
+
+
+class EventsPersistenceStorage(object):
+    """High level interface for handling persisting newly received events.
+
+    Takes care of batching up events by room, and calculating the necessary
+    current state and forward extremity changes.
+    """
+
+    def __init__(self, hs, stores: DataStores):
+        # We ultimately want to split out the state store from the main store,
+        # so we use separate variables here even though they point to the same
+        # store for now.
+        self.main_store = stores.main
+        self.state_store = stores.main
+
+        self._clock = hs.get_clock()
+        self.is_mine_id = hs.is_mine_id
+        self._event_persist_queue = _EventPeristenceQueue()
+        self._state_resolution_handler = hs.get_state_resolution_handler()
+
+    @defer.inlineCallbacks
+    def persist_events(self, events_and_contexts, backfilled=False):
+        """
+        Write events to the database
+        Args:
+            events_and_contexts: list of tuples of (event, context)
+            backfilled (bool): Whether the results are retrieved from federation
+                via backfill or not. Used to determine if they're "new" events
+                which might update the current state etc.
+
+        Returns:
+            Deferred[int]: the stream ordering of the latest persisted event
+        """
+        partitioned = {}
+        for event, ctx in events_and_contexts:
+            partitioned.setdefault(event.room_id, []).append((event, ctx))
+
+        deferreds = []
+        for room_id, evs_ctxs in iteritems(partitioned):
+            d = self._event_persist_queue.add_to_queue(
+                room_id, evs_ctxs, backfilled=backfilled
+            )
+            deferreds.append(d)
+
+        for room_id in partitioned:
+            self._maybe_start_persisting(room_id)
+
+        yield make_deferred_yieldable(
+            defer.gatherResults(deferreds, consumeErrors=True)
+        )
+
+        max_persisted_id = yield self.main_store.get_current_events_token()
+
+        return max_persisted_id
+
+    @defer.inlineCallbacks
+    def persist_event(self, event, context, backfilled=False):
+        """
+
+        Args:
+            event (EventBase):
+            context (EventContext):
+            backfilled (bool):
+
+        Returns:
+            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            and the stream ordering of the latest persisted event
+        """
+        deferred = self._event_persist_queue.add_to_queue(
+            event.room_id, [(event, context)], backfilled=backfilled
+        )
+
+        self._maybe_start_persisting(event.room_id)
+
+        yield make_deferred_yieldable(deferred)
+
+        max_persisted_id = yield self.main_store.get_current_events_token()
+        return (event.internal_metadata.stream_ordering, max_persisted_id)
+
+    def _maybe_start_persisting(self, room_id):
+        @defer.inlineCallbacks
+        def persisting_queue(item):
+            with Measure(self._clock, "persist_events"):
+                yield self._persist_events(
+                    item.events_and_contexts, backfilled=item.backfilled
+                )
+
+        self._event_persist_queue.handle_queue(room_id, persisting_queue)
+
+    @defer.inlineCallbacks
+    def _persist_events(self, events_and_contexts, backfilled=False):
+        """Calculates the change to current state and forward extremities, and
+        persists the given events and with those updates.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+            delete_existing (bool):
+
+        Returns:
+            Deferred: resolves when the events have been persisted
+        """
+        if not events_and_contexts:
+            return
+
+        chunks = [
+            events_and_contexts[x : x + 100]
+            for x in range(0, len(events_and_contexts), 100)
+        ]
+
+        for chunk in chunks:
+            # We can't easily parallelize these since different chunks
+            # might contain the same event. :(
+
+            # NB: Assumes that we are only persisting events for one room
+            # at a time.
+
+            # map room_id->list[event_ids] giving the new forward
+            # extremities in each room
+            new_forward_extremeties = {}
+
+            # map room_id->(type,state_key)->event_id tracking the full
+            # state in each room after adding these events.
+            # This is simply used to prefill the get_current_state_ids
+            # cache
+            current_state_for_room = {}
+
+            # map room_id->(to_delete, to_insert) where to_delete is a list
+            # of type/state keys to remove from current state, and to_insert
+            # is a map (type,key)->event_id giving the state delta in each
+            # room
+            state_delta_for_room = {}
+
+            if not backfilled:
+                with Measure(self._clock, "_calculate_state_and_extrem"):
+                    # Work out the new "current state" for each room.
+                    # We do this by working out what the new extremities are and then
+                    # calculating the state from that.
+                    events_by_room = {}
+                    for event, context in chunk:
+                        events_by_room.setdefault(event.room_id, []).append(
+                            (event, context)
+                        )
+
+                    for room_id, ev_ctx_rm in iteritems(events_by_room):
+                        latest_event_ids = yield self.main_store.get_latest_event_ids_in_room(
+                            room_id
+                        )
+                        new_latest_event_ids = yield self._calculate_new_extremities(
+                            room_id, ev_ctx_rm, latest_event_ids
+                        )
+
+                        latest_event_ids = set(latest_event_ids)
+                        if new_latest_event_ids == latest_event_ids:
+                            # No change in extremities, so no change in state
+                            continue
+
+                        # there should always be at least one forward extremity.
+                        # (except during the initial persistence of the send_join
+                        # results, in which case there will be no existing
+                        # extremities, so we'll `continue` above and skip this bit.)
+                        assert new_latest_event_ids, "No forward extremities left!"
+
+                        new_forward_extremeties[room_id] = new_latest_event_ids
+
+                        len_1 = (
+                            len(latest_event_ids) == 1
+                            and len(new_latest_event_ids) == 1
+                        )
+                        if len_1:
+                            all_single_prev_not_state = all(
+                                len(event.prev_event_ids()) == 1
+                                and not event.is_state()
+                                for event, ctx in ev_ctx_rm
+                            )
+                            # Don't bother calculating state if they're just
+                            # a long chain of single ancestor non-state events.
+                            if all_single_prev_not_state:
+                                continue
+
+                        state_delta_counter.inc()
+                        if len(new_latest_event_ids) == 1:
+                            state_delta_single_event_counter.inc()
+
+                            # This is a fairly handwavey check to see if we could
+                            # have guessed what the delta would have been when
+                            # processing one of these events.
+                            # What we're interested in is if the latest extremities
+                            # were the same when we created the event as they are
+                            # now. When this server creates a new event (as opposed
+                            # to receiving it over federation) it will use the
+                            # forward extremities as the prev_events, so we can
+                            # guess this by looking at the prev_events and checking
+                            # if they match the current forward extremities.
+                            for ev, _ in ev_ctx_rm:
+                                prev_event_ids = set(ev.prev_event_ids())
+                                if latest_event_ids == prev_event_ids:
+                                    state_delta_reuse_delta_counter.inc()
+                                    break
+
+                        logger.info("Calculating state delta for room %s", room_id)
+                        with Measure(
+                            self._clock, "persist_events.get_new_state_after_events"
+                        ):
+                            res = yield self._get_new_state_after_events(
+                                room_id,
+                                ev_ctx_rm,
+                                latest_event_ids,
+                                new_latest_event_ids,
+                            )
+                            current_state, delta_ids = res
+
+                        # If either are not None then there has been a change,
+                        # and we need to work out the delta (or use that
+                        # given)
+                        if delta_ids is not None:
+                            # If there is a delta we know that we've
+                            # only added or replaced state, never
+                            # removed keys entirely.
+                            state_delta_for_room[room_id] = ([], delta_ids)
+                        elif current_state is not None:
+                            with Measure(
+                                self._clock, "persist_events.calculate_state_delta"
+                            ):
+                                delta = yield self._calculate_state_delta(
+                                    room_id, current_state
+                                )
+                            state_delta_for_room[room_id] = delta
+
+                        # If we have the current_state then lets prefill
+                        # the cache with it.
+                        if current_state is not None:
+                            current_state_for_room[room_id] = current_state
+
+            yield self.main_store._persist_events_and_state_updates(
+                chunk,
+                current_state_for_room=current_state_for_room,
+                state_delta_for_room=state_delta_for_room,
+                new_forward_extremeties=new_forward_extremeties,
+                backfilled=backfilled,
+            )
+
+    @defer.inlineCallbacks
+    def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+        """Calculates the new forward extremities for a room given events to
+        persist.
+
+        Assumes that we are only persisting events for one room at a time.
+        """
+
+        # we're only interested in new events which aren't outliers and which aren't
+        # being rejected.
+        new_events = [
+            event
+            for event, ctx in event_contexts
+            if not event.internal_metadata.is_outlier()
+            and not ctx.rejected
+            and not event.internal_metadata.is_soft_failed()
+        ]
+
+        latest_event_ids = set(latest_event_ids)
+
+        # start with the existing forward extremities
+        result = set(latest_event_ids)
+
+        # add all the new events to the list
+        result.update(event.event_id for event in new_events)
+
+        # Now remove all events which are prev_events of any of the new events
+        result.difference_update(
+            e_id for event in new_events for e_id in event.prev_event_ids()
+        )
+
+        # Remove any events which are prev_events of any existing events.
+        existing_prevs = yield self.main_store._get_events_which_are_prevs(result)
+        result.difference_update(existing_prevs)
+
+        # Finally handle the case where the new events have soft-failed prev
+        # events. If they do we need to remove them and their prev events,
+        # otherwise we end up with dangling extremities.
+        existing_prevs = yield self.main_store._get_prevs_before_rejected(
+            e_id for event in new_events for e_id in event.prev_event_ids()
+        )
+        result.difference_update(existing_prevs)
+
+        # We only update metrics for events that change forward extremities
+        # (e.g. we ignore backfill/outliers/etc)
+        if result != latest_event_ids:
+            forward_extremities_counter.observe(len(result))
+            stale = latest_event_ids & result
+            stale_forward_extremities_counter.observe(len(stale))
+
+        return result
+
+    @defer.inlineCallbacks
+    def _get_new_state_after_events(
+        self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
+    ):
+        """Calculate the current state dict after adding some new events to
+        a room
+
+        Args:
+            room_id (str):
+                room to which the events are being added. Used for logging etc
+
+            events_context (list[(EventBase, EventContext)]):
+                events and contexts which are being added to the room
+
+            old_latest_event_ids (iterable[str]):
+                the old forward extremities for the room.
+
+            new_latest_event_ids (iterable[str]):
+                the new forward extremities for the room.
+
+        Returns:
+            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
+            Returns a tuple of two state maps, the first being the full new current
+            state and the second being the delta to the existing current state.
+            If both are None then there has been no change.
+
+            If there has been a change then we only return the delta if its
+            already been calculated. Conversely if we do know the delta then
+            the new current state is only returned if we've already calculated
+            it.
+        """
+        # map from state_group to ((type, key) -> event_id) state map
+        state_groups_map = {}
+
+        # Map from (prev state group, new state group) -> delta state dict
+        state_group_deltas = {}
+
+        for ev, ctx in events_context:
+            if ctx.state_group is None:
+                # This should only happen for outlier events.
+                if not ev.internal_metadata.is_outlier():
+                    raise Exception(
+                        "Context for new event %s has no state "
+                        "group" % (ev.event_id,)
+                    )
+                continue
+
+            if ctx.state_group in state_groups_map:
+                continue
+
+            # We're only interested in pulling out state that has already
+            # been cached in the context. We'll pull stuff out of the DB later
+            # if necessary.
+            current_state_ids = ctx.get_cached_current_state_ids()
+            if current_state_ids is not None:
+                state_groups_map[ctx.state_group] = current_state_ids
+
+            if ctx.prev_group:
+                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+        # We need to map the event_ids to their state groups. First, let's
+        # check if the event is one we're persisting, in which case we can
+        # pull the state group from its context.
+        # Otherwise we need to pull the state group from the database.
+
+        # Set of events we need to fetch groups for. (We know none of the old
+        # extremities are going to be in events_context).
+        missing_event_ids = set(old_latest_event_ids)
+
+        event_id_to_state_group = {}
+        for event_id in new_latest_event_ids:
+            # First search in the list of new events we're adding.
+            for ev, ctx in events_context:
+                if event_id == ev.event_id and ctx.state_group is not None:
+                    event_id_to_state_group[event_id] = ctx.state_group
+                    break
+            else:
+                # If we couldn't find it, then we'll need to pull
+                # the state from the database
+                missing_event_ids.add(event_id)
+
+        if missing_event_ids:
+            # Now pull out the state groups for any missing events from DB
+            event_to_groups = yield self.main_store._get_state_group_for_events(
+                missing_event_ids
+            )
+            event_id_to_state_group.update(event_to_groups)
+
+        # State groups of old_latest_event_ids
+        old_state_groups = set(
+            event_id_to_state_group[evid] for evid in old_latest_event_ids
+        )
+
+        # State groups of new_latest_event_ids
+        new_state_groups = set(
+            event_id_to_state_group[evid] for evid in new_latest_event_ids
+        )
+
+        # If they old and new groups are the same then we don't need to do
+        # anything.
+        if old_state_groups == new_state_groups:
+            return None, None
+
+        if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+            # If we're going from one state group to another, lets check if
+            # we have a delta for that transition. If we do then we can just
+            # return that.
+
+            new_state_group = next(iter(new_state_groups))
+            old_state_group = next(iter(old_state_groups))
+
+            delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
+            if delta_ids is not None:
+                # We have a delta from the existing to new current state,
+                # so lets just return that. If we happen to already have
+                # the current state in memory then lets also return that,
+                # but it doesn't matter if we don't.
+                new_state = state_groups_map.get(new_state_group)
+                return new_state, delta_ids
+
+        # Now that we have calculated new_state_groups we need to get
+        # their state IDs so we can resolve to a single state set.
+        missing_state = new_state_groups - set(state_groups_map)
+        if missing_state:
+            group_to_state = yield self.state_store._get_state_for_groups(missing_state)
+            state_groups_map.update(group_to_state)
+
+        if len(new_state_groups) == 1:
+            # If there is only one state group, then we know what the current
+            # state is.
+            return state_groups_map[new_state_groups.pop()], None
+
+        # Ok, we need to defer to the state handler to resolve our state sets.
+
+        state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
+
+        events_map = {ev.event_id: ev for ev, _ in events_context}
+
+        # We need to get the room version, which is in the create event.
+        # Normally that'd be in the database, but its also possible that we're
+        # currently trying to persist it.
+        room_version = None
+        for ev, _ in events_context:
+            if ev.type == EventTypes.Create and ev.state_key == "":
+                room_version = ev.content.get("room_version", "1")
+                break
+
+        if not room_version:
+            room_version = yield self.main_store.get_room_version(room_id)
+
+        logger.debug("calling resolve_state_groups from preserve_events")
+        res = yield self._state_resolution_handler.resolve_state_groups(
+            room_id,
+            room_version,
+            state_groups,
+            events_map,
+            state_res_store=StateResolutionStore(self.main_store),
+        )
+
+        return res.state, None
+
+    @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, current_state):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            tuple[list, dict] (to_delete, to_insert): where to_delete are the
+            type/state_keys to remove from current_state_events and `to_insert`
+            are the updates to current_state_events.
+        """
+        existing_state = yield self.main_store.get_current_state_ids(room_id)
+
+        to_delete = [key for key in existing_state if key not in current_state]
+
+        to_insert = {
+            key: ev_id
+            for key, ev_id in iteritems(current_state)
+            if ev_id != existing_state.get(key)
+        }
+
+        return to_delete, to_insert
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index a2df8fa827..3735846899 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -19,6 +19,8 @@ from six import iteritems, itervalues
 
 import attr
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes
 
 logger = logging.getLogger(__name__)
@@ -322,3 +324,234 @@ class StateFilter(object):
         )
 
         return member_filter, non_member_filter
+
+
+class StateGroupStorage(object):
+    """High level interface to fetching state for event.
+    """
+
+    def __init__(self, hs, stores):
+        self.stores = stores
+
+    def get_state_group_delta(self, state_group):
+        """Given a state group try to return a previous group and a delta between
+        the old and the new.
+
+        Returns:
+            Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+                (prev_group, delta_ids)
+        """
+
+        return self.stores.main.get_state_group_delta(state_group)
+
+    @defer.inlineCallbacks
+    def get_state_groups_ids(self, _room_id, event_ids):
+        """Get the event IDs of all the state for the state groups for the given events
+
+        Args:
+            _room_id (str): id of the room for these events
+            event_ids (iterable[str]): ids of the events
+
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+        if not event_ids:
+            return {}
+
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self.stores.main._get_state_for_groups(groups)
+
+        return group_to_state
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_group(self, state_group):
+        """Get the event IDs of all the state in the given state group
+
+        Args:
+            state_group (int)
+
+        Returns:
+            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+        """
+        group_to_state = yield self._get_state_for_groups((state_group,))
+
+        return group_to_state[state_group]
+
+    @defer.inlineCallbacks
+    def get_state_groups(self, room_id, event_ids):
+        """ Get the state groups for the given list of event_ids
+        Returns:
+            Deferred[dict[int, list[EventBase]]]:
+                dict of state_group_id -> list of state events.
+        """
+        if not event_ids:
+            return {}
+
+        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+        state_event_map = yield self.stores.main.get_events(
+            [
+                ev_id
+                for group_ids in itervalues(group_to_ids)
+                for ev_id in itervalues(group_ids)
+            ],
+            get_prev_content=False,
+        )
+
+        return {
+            group: [
+                state_event_map[v]
+                for v in itervalues(event_id_map)
+                if v in state_event_map
+            ]
+            for group, event_id_map in iteritems(group_to_ids)
+        }
+
+    def _get_state_groups_from_groups(self, groups, state_filter):
+        """Returns the state groups for a given set of groups, filtering on
+        types of state events.
+
+        Args:
+            groups(list[int]): list of state group IDs to query
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+
+        return self.stores.main._get_state_groups_from_groups(groups, state_filter)
+
+    @defer.inlineCallbacks
+    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+        """Given a list of event_ids and type tuples, return a list of state
+        dicts for each event.
+        Args:
+            event_ids (list[string])
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+        """
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self.stores.main._get_state_for_groups(
+            groups, state_filter
+        )
+
+        state_event_map = yield self.stores.main.get_events(
+            [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
+            get_prev_content=False,
+        )
+
+        event_to_state = {
+            event_id: {
+                k: state_event_map[v]
+                for k, v in iteritems(group_to_state[group])
+                if v in state_event_map
+            }
+            for event_id, group in iteritems(event_to_groups)
+        }
+
+        return {event: event_to_state[event] for event in event_ids}
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+        """
+        Get the state dicts corresponding to a list of events, containing the event_ids
+        of the state events (as opposed to the events themselves)
+
+        Args:
+            event_ids(list(str)): events whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from event_id -> (type, state_key) -> event_id
+        """
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self.stores.main._get_state_for_groups(
+            groups, state_filter
+        )
+
+        event_to_state = {
+            event_id: group_to_state[group]
+            for event_id, group in iteritems(event_to_groups)
+        }
+
+        return {event: event_to_state[event] for event in event_ids}
+
+    @defer.inlineCallbacks
+    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+        """
+        Get the state dict corresponding to a particular event
+
+        Args:
+            event_id(str): event whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from (type, state_key) -> state_event
+        """
+        state_map = yield self.get_state_for_events([event_id], state_filter)
+        return state_map[event_id]
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+        """
+        Get the state dict corresponding to a particular event
+
+        Args:
+            event_id(str): event whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from (type, state_key) -> state_event
+        """
+        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+        return state_map[event_id]
+
+    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+        return self.stores.main._get_state_for_groups(groups, state_filter)
+
+    def store_state_group(
+        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+    ):
+        """Store a new set of state, returning a newly assigned state group.
+
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
+
+        Returns:
+            Deferred[int]: The state group ID
+        """
+        return self.stores.main.store_state_group(
+            event_id, room_id, prev_group, delta_ids, current_state_ids
+        )
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index cbb0a4810a..9d851beaa5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1):
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
     else:
         cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
-    val, = cur.fetchone()
+    (val,) = cur.fetchone()
     cur.close()
     current_id = int(val) if val else step
     return (max if step > 0 else min)(current_id, step)