summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py56
-rw-r--r--synapse/storage/_base.py26
-rw-r--r--synapse/storage/account_data.py115
-rw-r--r--synapse/storage/appservice.py101
-rw-r--r--synapse/storage/background_updates.py2
-rw-r--r--synapse/storage/deviceinbox.py12
-rw-r--r--synapse/storage/devices.py2
-rw-r--r--synapse/storage/directory.py50
-rw-r--r--synapse/storage/end_to_end_keys.py2
-rw-r--r--synapse/storage/event_federation.py264
-rw-r--r--synapse/storage/event_push_actions.py465
-rw-r--r--synapse/storage/events.py460
-rw-r--r--synapse/storage/events_worker.py395
-rw-r--r--synapse/storage/group_server.py2
-rw-r--r--synapse/storage/profile.py50
-rw-r--r--synapse/storage/push_rule.py80
-rw-r--r--synapse/storage/pusher.py23
-rw-r--r--synapse/storage/receipts.py123
-rw-r--r--synapse/storage/registration.py122
-rw-r--r--synapse/storage/room.py271
-rw-r--r--synapse/storage/roommember.py390
-rw-r--r--synapse/storage/schema/delta/14/upgrade_appservice_db.py3
-rw-r--r--synapse/storage/schema/delta/25/fts.py4
-rw-r--r--synapse/storage/schema/delta/27/ts.py4
-rw-r--r--synapse/storage/schema/delta/31/search_update.py4
-rw-r--r--synapse/storage/schema/delta/33/event_fields.py4
-rw-r--r--synapse/storage/search.py2
-rw-r--r--synapse/storage/signatures.py12
-rw-r--r--synapse/storage/state.py34
-rw-r--r--synapse/storage/stream.py310
-rw-r--r--synapse/storage/tags.py18
-rw-r--r--synapse/storage/transactions.py2
-rw-r--r--synapse/storage/user_directory.py4
33 files changed, 1868 insertions, 1544 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f8fbd02ceb..de00cae447 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +20,6 @@ from synapse.storage.devices import DeviceStore
 from .appservice import (
     ApplicationServiceStore, ApplicationServiceTransactionStore
 )
-from ._base import LoggingTransaction
 from .directory import DirectoryStore
 from .events import EventsStore
 from .presence import PresenceStore, UserPresenceState
@@ -104,12 +104,6 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "events", "stream_ordering", step=-1,
             extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
         )
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn, "account_data_max_stream_id", "stream_id"
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -146,27 +140,6 @@ class DataStore(RoomMemberStore, RoomStore,
         else:
             self._cache_id_gen = None
 
-        events_max = self._stream_id_gen.get_current_token()
-        event_cache_prefill, min_event_val = self._get_cache_dict(
-            db_conn, "events",
-            entity_column="room_id",
-            stream_column="stream_ordering",
-            max_value=events_max,
-        )
-        self._events_stream_cache = StreamChangeCache(
-            "EventsRoomStreamChangeCache", min_event_val,
-            prefilled_cache=event_cache_prefill,
-        )
-
-        self._membership_stream_cache = StreamChangeCache(
-            "MembershipStreamChangeCache", events_max,
-        )
-
-        account_max = self._account_data_id_gen.get_current_token()
-        self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache", account_max,
-        )
-
         self._presence_on_startup = self._get_active_presence(db_conn)
 
         presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -180,18 +153,6 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=presence_cache_prefill
         )
 
-        push_rules_prefill, push_rules_id = self._get_cache_dict(
-            db_conn, "push_rules_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
-        )
-
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache", push_rules_id,
-            prefilled_cache=push_rules_prefill,
-        )
-
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
         device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
             db_conn, "device_inbox",
@@ -226,6 +187,7 @@ class DataStore(RoomMemberStore, RoomStore,
             "DeviceListFederationStreamChangeCache", device_list_max,
         )
 
+        events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
             db_conn, "current_state_delta_stream",
             entity_column="room_id",
@@ -250,20 +212,6 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=_group_updates_prefill,
         )
 
-        cur = LoggingTransaction(
-            db_conn.cursor(),
-            name="_find_stream_orderings_for_times_txn",
-            database_engine=self.database_engine,
-            after_callbacks=[],
-            final_callbacks=[],
-        )
-        self._find_stream_orderings_for_times_txn(cur)
-        cur.close()
-
-        self.find_stream_orderings_looping_call = self._clock.looping_call(
-            self._find_stream_orderings_for_times, 10 * 60 * 1000
-        )
-
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 68125006eb..2fbebd4907 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -48,16 +48,16 @@ class LoggingTransaction(object):
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
     __slots__ = [
-        "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+        "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
     ]
 
     def __init__(self, txn, name, database_engine, after_callbacks,
-                 final_callbacks):
+                 exception_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
         object.__setattr__(self, "after_callbacks", after_callbacks)
-        object.__setattr__(self, "final_callbacks", final_callbacks)
+        object.__setattr__(self, "exception_callbacks", exception_callbacks)
 
     def call_after(self, callback, *args, **kwargs):
         """Call the given callback on the main twisted thread after the
@@ -66,8 +66,8 @@ class LoggingTransaction(object):
         """
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_finally(self, callback, *args, **kwargs):
-        self.final_callbacks.append((callback, args, kwargs))
+    def call_on_exception(self, callback, *args, **kwargs):
+        self.exception_callbacks.append((callback, args, kwargs))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -215,7 +215,7 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
-    def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
+    def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
                          logging_context, func, *args, **kwargs):
         start = time.time() * 1000
         txn_id = self._TXN_ID
@@ -236,7 +236,7 @@ class SQLBaseStore(object):
                     txn = conn.cursor()
                     txn = LoggingTransaction(
                         txn, name, self.database_engine, after_callbacks,
-                        final_callbacks,
+                        exception_callbacks,
                     )
                     r = func(txn, *args, **kwargs)
                     conn.commit()
@@ -308,11 +308,11 @@ class SQLBaseStore(object):
         current_context = LoggingContext.current_context()
 
         after_callbacks = []
-        final_callbacks = []
+        exception_callbacks = []
 
         def inner_func(conn, *args, **kwargs):
             return self._new_transaction(
-                conn, desc, after_callbacks, final_callbacks, current_context,
+                conn, desc, after_callbacks, exception_callbacks, current_context,
                 func, *args, **kwargs
             )
 
@@ -321,9 +321,10 @@ class SQLBaseStore(object):
 
             for after_callback, after_args, after_kwargs in after_callbacks:
                 after_callback(*after_args, **after_kwargs)
-        finally:
-            for after_callback, after_args, after_kwargs in final_callbacks:
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            for after_callback, after_args, after_kwargs in exception_callbacks:
                 after_callback(*after_args, **after_kwargs)
+            raise
 
         defer.returnValue(result)
 
@@ -1000,7 +1001,8 @@ class SQLBaseStore(object):
             # __exit__ called after the transaction finishes.
             ctx = self._cache_id_gen.get_next()
             stream_id = ctx.__enter__()
-            txn.call_finally(ctx.__exit__, None, None, None)
+            txn.call_on_exception(ctx.__exit__, None, None, None)
+            txn.call_after(ctx.__exit__, None, None, None)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             self._simple_insert_txn(
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 56a0bde549..f83ff0454a 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 
-import ujson as json
+import abc
+import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_account_data_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        account_max = self.get_max_account_data_stream_id()
+        self._account_data_stream_cache = StreamChangeCache(
+            "AccountDataAndTagsChangeCache", account_max,
+        )
+
+        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+    @abc.abstractmethod
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream ID for account data stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
 
     @cached()
     def get_account_data_for_user(self, user_id):
@@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore):
             for row in rows
         })
 
+    @cached(num_args=2)
     def get_account_data_for_room(self, user_id, room_id):
         """Get all the client account_data for a user for a room.
 
@@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
+    @cached(num_args=3, max_entries=5000)
+    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+        """Get the client account_data of given type for a user for a room.
+
+        Args:
+            user_id(str): The user to get the account_data for.
+            room_id(str): The room to get the account_data for.
+            account_data_type (str): The account data type to get.
+        Returns:
+            A deferred of the room account_data for that type, or None if
+            there isn't any set.
+        """
+        def get_account_data_for_room_and_type_txn(txn):
+            content_json = self._simple_select_one_onecol_txn(
+                txn,
+                table="room_account_data",
+                keyvalues={
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "account_data_type": account_data_type,
+                },
+                retcol="content",
+                allow_none=True
+            )
+
+            return json.loads(content_json) if content_json else None
+
+        return self.runInteraction(
+            "get_account_data_for_room_and_type",
+            get_account_data_for_room_and_type_txn,
+        )
+
     def get_all_updated_account_data(self, last_global_id, last_room_id,
                                      current_id, limit):
         """Get all the client account_data that has changed on the server
@@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
+    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+            "m.ignored_user_list", ignorer_user_id,
+            on_invalidate=cache_context.invalidate,
+        )
+        if not ignored_account_data:
+            defer.returnValue(False)
+
+        defer.returnValue(
+            ignored_user_id in ignored_account_data.get("ignored_users", {})
+        )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    def __init__(self, db_conn, hs):
+        self._account_data_id_gen = StreamIdGenerator(
+            db_conn, "account_data_max_stream_id", "stream_id"
+        )
+
+        super(AccountDataStore, self).__init__(db_conn, hs)
+
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream id for the private user data stream
+
+        Returns:
+            A deferred int.
+        """
+        return self._account_data_id_gen.get_current_token()
+
     @defer.inlineCallbacks
     def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
         """Add some account_data to a room for a user.
@@ -251,6 +343,10 @@ class AccountDataStore(SQLBaseStore):
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
+            self.get_account_data_for_room.invalidate((user_id, room_id,))
+            self.get_account_data_for_room_and_type.prefill(
+                (user_id, room_id, account_data_type,), content,
+            )
 
         result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
@@ -321,16 +417,3 @@ class AccountDataStore(SQLBaseStore):
             "update_account_data_max_stream_id",
             _update,
         )
-
-    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
-    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
-        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
-            "m.ignored_user_list", ignorer_user_id,
-            on_invalidate=cache_context.invalidate,
-        )
-        if not ignored_account_data:
-            defer.returnValue(False)
-
-        defer.returnValue(
-            ignored_user_id in ignored_account_data.get("ignored_users", {})
-        )
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 79673b4273..12ea8a158c 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,10 +18,9 @@ import re
 import simplejson as json
 from twisted.internet import defer
 
-from synapse.api.constants import Membership
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.events import EventsWorkerStore
 from ._base import SQLBaseStore
 
 
@@ -46,17 +46,16 @@ def _make_exclusive_regex(services_cache):
     return exclusive_user_regex
 
 
-class ApplicationServiceStore(SQLBaseStore):
-
+class ApplicationServiceWorkerStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
-        super(ApplicationServiceStore, self).__init__(db_conn, hs)
-        self.hostname = hs.hostname
         self.services_cache = load_appservices(
             hs.hostname,
             hs.config.app_service_config_files
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
+        super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+
     def get_app_services(self):
         return self.services_cache
 
@@ -112,83 +111,17 @@ class ApplicationServiceStore(SQLBaseStore):
                 return service
         return None
 
-    def get_app_service_rooms(self, service):
-        """Get a list of RoomsForUser for this application service.
-
-        Application services may be "interested" in lots of rooms depending on
-        the room ID, the room aliases, or the members in the room. This function
-        takes all of these into account and returns a list of RoomsForUser which
-        represent the entire list of room IDs that this application service
-        wants to know about.
-
-        Args:
-            service: The application service to get a room list for.
-        Returns:
-            A list of RoomsForUser.
-        """
-        return self.runInteraction(
-            "get_app_service_rooms",
-            self._get_app_service_rooms_txn,
-            service,
-        )
-
-    def _get_app_service_rooms_txn(self, txn, service):
-        # get all rooms matching the room ID regex.
-        room_entries = self._simple_select_list_txn(
-            txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
-        )
-        matching_room_list = set([
-            r["room_id"] for r in room_entries if
-            service.is_interested_in_room(r["room_id"])
-        ])
-
-        # resolve room IDs for matching room alias regex.
-        room_alias_mappings = self._simple_select_list_txn(
-            txn=txn, table="room_aliases", keyvalues=None,
-            retcols=["room_id", "room_alias"]
-        )
-        matching_room_list |= set([
-            r["room_id"] for r in room_alias_mappings if
-            service.is_interested_in_alias(r["room_alias"])
-        ])
-
-        # get all rooms for every user for this AS. This is scoped to users on
-        # this HS only.
-        user_list = self._simple_select_list_txn(
-            txn=txn, table="users", keyvalues=None, retcols=["name"]
-        )
-        user_list = [
-            u["name"] for u in user_list if
-            service.is_interested_in_user(u["name"])
-        ]
-        rooms_for_user_matching_user_id = set()  # RoomsForUser list
-        for user_id in user_list:
-            # FIXME: This assumes this store is linked with RoomMemberStore :(
-            rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
-                txn=txn,
-                user_id=user_id,
-                membership_list=[Membership.JOIN]
-            )
-            rooms_for_user_matching_user_id |= set(rooms_for_user)
-
-        # make RoomsForUser tuples for room ids and aliases which are not in the
-        # main rooms_for_user_list - e.g. they are rooms which do not have AS
-        # registered users in it.
-        known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
-        missing_rooms_for_user = [
-            RoomsForUser(r, service.sender, "join") for r in
-            matching_room_list if r not in known_room_ids
-        ]
-        rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
 
-        return rooms_for_user_matching_user_id
+class ApplicationServiceStore(ApplicationServiceWorkerStore):
+    # This is currently empty due to there not being any AS storage functions
+    # that can't be run on the workers. Since this may change in future, and
+    # to keep consistency with the other stores, we keep this empty class for
+    # now.
+    pass
 
 
-class ApplicationServiceTransactionStore(SQLBaseStore):
-
-    def __init__(self, db_conn, hs):
-        super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs)
-
+class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
+                                               EventsWorkerStore):
     @defer.inlineCallbacks
     def get_appservices_by_state(self, state):
         """Get a list of application services based on their state.
@@ -433,3 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
         events = yield self._get_events(event_ids)
 
         defer.returnValue((upper_bound, events))
+
+
+class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
+    # This is currently empty due to there not being any AS storage functions
+    # that can't be run on the workers. Since this may change in future, and
+    # to keep consistency with the other stores, we keep this empty class for
+    # now.
+    pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index c88759bf2c..8af325a9f5 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -19,7 +19,7 @@ from . import engines
 
 from twisted.internet import defer
 
-import ujson as json
+import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 548e795daf..a879e5bfc1 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-import ujson
+import simplejson
 
 from twisted.internet import defer
 
@@ -85,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             )
             rows = []
             for destination, edu in remote_messages_by_destination.items():
-                edu_json = ujson.dumps(edu)
+                edu_json = simplejson.dumps(edu)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
@@ -177,7 +177,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                     " WHERE user_id = ?"
                 )
                 txn.execute(sql, (user_id,))
-                message_json = ujson.dumps(messages_by_device["*"])
+                message_json = simplejson.dumps(messages_by_device["*"])
                 for row in txn:
                     # Add the message for all devices for this user on this
                     # server.
@@ -199,7 +199,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                     # Only insert into the local inbox if the device exists on
                     # this server
                     device = row[0]
-                    message_json = ujson.dumps(messages_by_device[device])
+                    message_json = simplejson.dumps(messages_by_device[device])
                     messages_json_for_user[device] = message_json
 
             if messages_json_for_user:
@@ -253,7 +253,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             messages = []
             for row in txn:
                 stream_pos = row[0]
-                messages.append(ujson.loads(row[1]))
+                messages.append(simplejson.loads(row[1]))
             if len(messages) < limit:
                 stream_pos = current_stream_id
             return (messages, stream_pos)
@@ -389,7 +389,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             messages = []
             for row in txn:
                 stream_pos = row[0]
-                messages.append(ujson.loads(row[1]))
+                messages.append(simplejson.loads(row[1]))
             if len(messages) < limit:
                 stream_pos = current_stream_id
             return (messages, stream_pos)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index bd2effdf34..712106b83a 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import ujson as json
+import simplejson as json
 
 from twisted.internet import defer
 
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 79e7c540ad..d0c0059757 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -29,8 +29,7 @@ RoomAliasMapping = namedtuple(
 )
 
 
-class DirectoryStore(SQLBaseStore):
-
+class DirectoryWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_association_from_room_alias(self, room_alias):
         """ Get's the room_id and server list for a given room_alias
@@ -69,6 +68,28 @@ class DirectoryStore(SQLBaseStore):
             RoomAliasMapping(room_id, room_alias.to_string(), servers)
         )
 
+    def get_room_alias_creator(self, room_alias):
+        return self._simple_select_one_onecol(
+            table="room_aliases",
+            keyvalues={
+                "room_alias": room_alias,
+            },
+            retcol="creator",
+            desc="get_room_alias_creator",
+            allow_none=True
+        )
+
+    @cached(max_entries=5000)
+    def get_aliases_for_room(self, room_id):
+        return self._simple_select_onecol(
+            "room_aliases",
+            {"room_id": room_id},
+            "room_alias",
+            desc="get_aliases_for_room",
+        )
+
+
+class DirectoryStore(DirectoryWorkerStore):
     @defer.inlineCallbacks
     def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
         """ Creates an associatin between  a room alias and room_id/servers
@@ -116,17 +137,6 @@ class DirectoryStore(SQLBaseStore):
             )
         defer.returnValue(ret)
 
-    def get_room_alias_creator(self, room_alias):
-        return self._simple_select_one_onecol(
-            table="room_aliases",
-            keyvalues={
-                "room_alias": room_alias,
-            },
-            retcol="creator",
-            desc="get_room_alias_creator",
-            allow_none=True
-        )
-
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
         room_id = yield self.runInteraction(
@@ -135,7 +145,6 @@ class DirectoryStore(SQLBaseStore):
             room_alias,
         )
 
-        self.get_aliases_for_room.invalidate((room_id,))
         defer.returnValue(room_id)
 
     def _delete_room_alias_txn(self, txn, room_alias):
@@ -160,17 +169,12 @@ class DirectoryStore(SQLBaseStore):
             (room_alias.to_string(),)
         )
 
-        return room_id
-
-    @cached(max_entries=5000)
-    def get_aliases_for_room(self, room_id):
-        return self._simple_select_onecol(
-            "room_aliases",
-            {"room_id": room_id},
-            "room_alias",
-            desc="get_aliases_for_room",
+        self._invalidate_cache_and_stream(
+            txn, self.get_aliases_for_room, (room_id,)
         )
 
+        return room_id
+
     def update_aliases_for_room(self, old_room_id, new_room_id, creator):
         def _update_aliases_for_room_txn(txn):
             sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2cebb203c6..ff8538ddf8 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 from synapse.util.caches.descriptors import cached
 
 from canonicaljson import encode_canonical_json
-import ujson as json
+import simplejson as json
 
 from ._base import SQLBaseStore
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 55a05c59d5..00ee82d300 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -15,7 +15,10 @@
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+
 from synapse.api.errors import StoreError
 from synapse.util.caches.descriptors import cached
 from unpaddedbase64 import encode_base64
@@ -27,30 +30,8 @@ from Queue import PriorityQueue, Empty
 logger = logging.getLogger(__name__)
 
 
-class EventFederationStore(SQLBaseStore):
-    """ Responsible for storing and serving up the various graphs associated
-    with an event. Including the main event graph and the auth chains for an
-    event.
-
-    Also has methods for getting the front (latest) and back (oldest) edges
-    of the event graphs. These are used to generate the parents for new events
-    and backfilling from another server respectively.
-    """
-
-    EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
-
-    def __init__(self, db_conn, hs):
-        super(EventFederationStore, self).__init__(db_conn, hs)
-
-        self.register_background_update_handler(
-            self.EVENT_AUTH_STATE_ONLY,
-            self._background_delete_non_state_event_auth,
-        )
-
-        hs.get_clock().looping_call(
-            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
-        )
-
+class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
+                                 SQLBaseStore):
     def get_auth_chain(self, event_ids, include_given=False):
         """Get auth events for given event_ids. The events *must* be state events.
 
@@ -228,88 +209,6 @@ class EventFederationStore(SQLBaseStore):
 
         return int(min_depth) if min_depth is not None else None
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
-        min_depth = self._get_min_depth_interaction(txn, room_id)
-
-        if min_depth and depth >= min_depth:
-            return
-
-        self._simple_upsert_txn(
-            txn,
-            table="room_depth",
-            keyvalues={
-                "room_id": room_id,
-            },
-            values={
-                "min_depth": depth,
-            },
-        )
-
-    def _handle_mult_prev_events(self, txn, events):
-        """
-        For the given event, update the event edges table and forward and
-        backward extremities tables.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="event_edges",
-            values=[
-                {
-                    "event_id": ev.event_id,
-                    "prev_event_id": e_id,
-                    "room_id": ev.room_id,
-                    "is_state": False,
-                }
-                for ev in events
-                for e_id, _ in ev.prev_events
-            ],
-        )
-
-        self._update_backward_extremeties(txn, events)
-
-    def _update_backward_extremeties(self, txn, events):
-        """Updates the event_backward_extremities tables based on the new/updated
-        events being persisted.
-
-        This is called for new events *and* for events that were outliers, but
-        are now being persisted as non-outliers.
-
-        Forward extremities are handled when we first start persisting the events.
-        """
-        events_by_room = {}
-        for ev in events:
-            events_by_room.setdefault(ev.room_id, []).append(ev)
-
-        query = (
-            "INSERT INTO event_backward_extremities (event_id, room_id)"
-            " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-            " )"
-            " AND NOT EXISTS ("
-            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            " AND outlier = ?"
-            " )"
-        )
-
-        txn.executemany(query, [
-            (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
-            for ev in events for e_id, _ in ev.prev_events
-            if not ev.internal_metadata.is_outlier()
-        ])
-
-        query = (
-            "DELETE FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-        )
-        txn.executemany(
-            query,
-            [
-                (ev.event_id, ev.room_id) for ev in events
-                if not ev.internal_metadata.is_outlier()
-            ]
-        )
-
     def get_forward_extremeties_for_room(self, room_id, stream_ordering):
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
@@ -371,28 +270,6 @@ class EventFederationStore(SQLBaseStore):
             get_forward_extremeties_for_room_txn
         )
 
-    def _delete_old_forward_extrem_cache(self):
-        def _delete_old_forward_extrem_cache_txn(txn):
-            # Delete entries older than a month, while making sure we don't delete
-            # the only entries for a room.
-            sql = ("""
-                DELETE FROM stream_ordering_to_exterm
-                WHERE
-                room_id IN (
-                    SELECT room_id
-                    FROM stream_ordering_to_exterm
-                    WHERE stream_ordering > ?
-                ) AND stream_ordering < ?
-            """)
-            txn.execute(
-                sql,
-                (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
-            )
-        return self.runInteraction(
-            "_delete_old_forward_extrem_cache",
-            _delete_old_forward_extrem_cache_txn
-        )
-
     def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
@@ -522,6 +399,135 @@ class EventFederationStore(SQLBaseStore):
 
         return event_results
 
+
+class EventFederationStore(EventFederationWorkerStore):
+    """ Responsible for storing and serving up the various graphs associated
+    with an event. Including the main event graph and the auth chains for an
+    event.
+
+    Also has methods for getting the front (latest) and back (oldest) edges
+    of the event graphs. These are used to generate the parents for new events
+    and backfilling from another server respectively.
+    """
+
+    EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
+    def __init__(self, db_conn, hs):
+        super(EventFederationStore, self).__init__(db_conn, hs)
+
+        self.register_background_update_handler(
+            self.EVENT_AUTH_STATE_ONLY,
+            self._background_delete_non_state_event_auth,
+        )
+
+        hs.get_clock().looping_call(
+            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+        )
+
+    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+        min_depth = self._get_min_depth_interaction(txn, room_id)
+
+        if min_depth and depth >= min_depth:
+            return
+
+        self._simple_upsert_txn(
+            txn,
+            table="room_depth",
+            keyvalues={
+                "room_id": room_id,
+            },
+            values={
+                "min_depth": depth,
+            },
+        )
+
+    def _handle_mult_prev_events(self, txn, events):
+        """
+        For the given event, update the event edges table and forward and
+        backward extremities tables.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="event_edges",
+            values=[
+                {
+                    "event_id": ev.event_id,
+                    "prev_event_id": e_id,
+                    "room_id": ev.room_id,
+                    "is_state": False,
+                }
+                for ev in events
+                for e_id, _ in ev.prev_events
+            ],
+        )
+
+        self._update_backward_extremeties(txn, events)
+
+    def _update_backward_extremeties(self, txn, events):
+        """Updates the event_backward_extremities tables based on the new/updated
+        events being persisted.
+
+        This is called for new events *and* for events that were outliers, but
+        are now being persisted as non-outliers.
+
+        Forward extremities are handled when we first start persisting the events.
+        """
+        events_by_room = {}
+        for ev in events:
+            events_by_room.setdefault(ev.room_id, []).append(ev)
+
+        query = (
+            "INSERT INTO event_backward_extremities (event_id, room_id)"
+            " SELECT ?, ? WHERE NOT EXISTS ("
+            " SELECT 1 FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+            " )"
+            " AND NOT EXISTS ("
+            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+            " AND outlier = ?"
+            " )"
+        )
+
+        txn.executemany(query, [
+            (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+            for ev in events for e_id, _ in ev.prev_events
+            if not ev.internal_metadata.is_outlier()
+        ])
+
+        query = (
+            "DELETE FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+        )
+        txn.executemany(
+            query,
+            [
+                (ev.event_id, ev.room_id) for ev in events
+                if not ev.internal_metadata.is_outlier()
+            ]
+        )
+
+    def _delete_old_forward_extrem_cache(self):
+        def _delete_old_forward_extrem_cache_txn(txn):
+            # Delete entries older than a month, while making sure we don't delete
+            # the only entries for a room.
+            sql = ("""
+                DELETE FROM stream_ordering_to_exterm
+                WHERE
+                room_id IN (
+                    SELECT room_id
+                    FROM stream_ordering_to_exterm
+                    WHERE stream_ordering > ?
+                ) AND stream_ordering < ?
+            """)
+            txn.execute(
+                sql,
+                (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
+            )
+        return self.runInteraction(
+            "_delete_old_forward_extrem_cache",
+            _delete_old_forward_extrem_cache_txn
+        )
+
     def clean_room_for_join(self, room_id):
         return self.runInteraction(
             "clean_room_for_join",
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index f787431b7a..e78f8d0114 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, LoggingTransaction
 from twisted.internet import defer
 from synapse.util.async import sleep
 from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -21,7 +22,7 @@ from synapse.types import RoomStreamToken
 from .stream import lower_bound
 
 import logging
-import ujson as json
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
@@ -62,77 +63,28 @@ def _deserialize_action(actions, is_highlight):
         return DEFAULT_NOTIF_ACTION
 
 
-class EventPushActionsStore(SQLBaseStore):
-    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
+class EventPushActionsWorkerStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(db_conn, hs)
-
-        self.register_background_index_update(
-            self.EPA_HIGHLIGHT_INDEX,
-            index_name="event_push_actions_u_highlight",
-            table="event_push_actions",
-            columns=["user_id", "stream_ordering"],
-        )
-
-        self.register_background_index_update(
-            "event_push_actions_highlights_index",
-            index_name="event_push_actions_highlights_index",
-            table="event_push_actions",
-            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
-            where_clause="highlight=1"
-        )
-
-        self._doing_notif_rotation = False
-        self._rotate_notif_loop = self._clock.looping_call(
-            self._rotate_notifs, 30 * 60 * 1000
-        )
-
-    def _set_push_actions_for_event_and_users_txn(self, txn, event):
-        """
-        Args:
-            event: the event set actions for
-            tuples: list of tuples of (user_id, actions)
-        """
-
-        sql = """
-            INSERT INTO event_push_actions (
-                room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
-            )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
-            FROM event_push_actions_staging
-            WHERE event_id = ?
-        """
+        super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
 
-        txn.execute(sql, (
-            event.room_id, event.internal_metadata.stream_ordering,
-            event.depth, event.event_id,
-        ))
+        # These get correctly set by _find_stream_orderings_for_times_txn
+        self.stream_ordering_month_ago = None
+        self.stream_ordering_day_ago = None
 
-        user_ids = self._simple_select_onecol_txn(
-            txn,
-            table="event_push_actions_staging",
-            keyvalues={
-                "event_id": event.event_id,
-            },
-            retcol="user_id",
+        cur = LoggingTransaction(
+            db_conn.cursor(),
+            name="_find_stream_orderings_for_times_txn",
+            database_engine=self.database_engine,
+            after_callbacks=[],
+            exception_callbacks=[],
         )
+        self._find_stream_orderings_for_times_txn(cur)
+        cur.close()
 
-        self._simple_delete_txn(
-            txn,
-            table="event_push_actions_staging",
-            keyvalues={
-                "event_id": event.event_id,
-            },
+        self.find_stream_orderings_looping_call = self._clock.looping_call(
+            self._find_stream_orderings_for_times, 10 * 60 * 1000
         )
 
-        for uid in user_ids:
-            txn.call_after(
-                self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                (event.room_id, uid,)
-            )
-
     @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
             self, room_id, user_id, last_read_event_id
@@ -449,6 +401,280 @@ class EventPushActionsStore(SQLBaseStore):
         # Now return the first `limit`
         defer.returnValue(notifs[:limit])
 
+    def add_push_actions_to_staging(self, event_id, user_id_actions):
+        """Add the push actions for the event to the push action staging area.
+
+        Args:
+            event_id (str)
+            user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
+                user_id to list of push actions, where an action can either be
+                a string or dict.
+
+        Returns:
+            Deferred
+        """
+
+        if not user_id_actions:
+            return
+
+        # This is a helper function for generating the necessary tuple that
+        # can be used to inert into the `event_push_actions_staging` table.
+        def _gen_entry(user_id, actions):
+            is_highlight = 1 if _action_has_highlight(actions) else 0
+            return (
+                event_id,  # event_id column
+                user_id,  # user_id column
+                _serialize_action(actions, is_highlight),  # actions column
+                1,  # notif column
+                is_highlight,  # highlight column
+            )
+
+        def _add_push_actions_to_staging_txn(txn):
+            # We don't use _simple_insert_many here to avoid the overhead
+            # of generating lists of dicts.
+
+            sql = """
+                INSERT INTO event_push_actions_staging
+                    (event_id, user_id, actions, notif, highlight)
+                VALUES (?, ?, ?, ?, ?)
+            """
+
+            txn.executemany(sql, (
+                _gen_entry(user_id, actions)
+                for user_id, actions in user_id_actions.iteritems()
+            ))
+
+        return self.runInteraction(
+            "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+        )
+
+    def remove_push_actions_from_staging(self, event_id):
+        """Called if we failed to persist the event to ensure that stale push
+        actions don't build up in the DB
+
+        Args:
+            event_id (str)
+        """
+
+        return self._simple_delete(
+            table="event_push_actions_staging",
+            keyvalues={
+                "event_id": event_id,
+            },
+            desc="remove_push_actions_from_staging",
+        )
+
+    @defer.inlineCallbacks
+    def _find_stream_orderings_for_times(self):
+        yield self.runInteraction(
+            "_find_stream_orderings_for_times",
+            self._find_stream_orderings_for_times_txn
+        )
+
+    def _find_stream_orderings_for_times_txn(self, txn):
+        logger.info("Searching for stream ordering 1 month ago")
+        self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
+            txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
+        )
+        logger.info(
+            "Found stream ordering 1 month ago: it's %d",
+            self.stream_ordering_month_ago
+        )
+        logger.info("Searching for stream ordering 1 day ago")
+        self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
+            txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
+        )
+        logger.info(
+            "Found stream ordering 1 day ago: it's %d",
+            self.stream_ordering_day_ago
+        )
+
+    def find_first_stream_ordering_after_ts(self, ts):
+        """Gets the stream ordering corresponding to a given timestamp.
+
+        Specifically, finds the stream_ordering of the first event that was
+        received on or after the timestamp. This is done by a binary search on
+        the events table, since there is no index on received_ts, so is
+        relatively slow.
+
+        Args:
+            ts (int): timestamp in millis
+
+        Returns:
+            Deferred[int]: stream ordering of the first event received on/after
+                the timestamp
+        """
+        return self.runInteraction(
+            "_find_first_stream_ordering_after_ts_txn",
+            self._find_first_stream_ordering_after_ts_txn,
+            ts,
+        )
+
+    @staticmethod
+    def _find_first_stream_ordering_after_ts_txn(txn, ts):
+        """
+        Find the stream_ordering of the first event that was received on or
+        after a given timestamp. This is relatively slow as there is no index
+        on received_ts but we can then use this to delete push actions before
+        this.
+
+        received_ts must necessarily be in the same order as stream_ordering
+        and stream_ordering is indexed, so we manually binary search using
+        stream_ordering
+
+        Args:
+            txn (twisted.enterprise.adbapi.Transaction):
+            ts (int): timestamp to search for
+
+        Returns:
+            int: stream ordering
+        """
+        txn.execute("SELECT MAX(stream_ordering) FROM events")
+        max_stream_ordering = txn.fetchone()[0]
+
+        if max_stream_ordering is None:
+            return 0
+
+        # We want the first stream_ordering in which received_ts is greater
+        # than or equal to ts. Call this point X.
+        #
+        # We maintain the invariants:
+        #
+        #   range_start <= X <= range_end
+        #
+        range_start = 0
+        range_end = max_stream_ordering + 1
+
+        # Given a stream_ordering, look up the timestamp at that
+        # stream_ordering.
+        #
+        # The array may be sparse (we may be missing some stream_orderings).
+        # We treat the gaps as the same as having the same value as the
+        # preceding entry, because we will pick the lowest stream_ordering
+        # which satisfies our requirement of received_ts >= ts.
+        #
+        # For example, if our array of events indexed by stream_ordering is
+        # [10, <none>, 20], we should treat this as being equivalent to
+        # [10, 10, 20].
+        #
+        sql = (
+            "SELECT received_ts FROM events"
+            " WHERE stream_ordering <= ?"
+            " ORDER BY stream_ordering DESC"
+            " LIMIT 1"
+        )
+
+        while range_end - range_start > 0:
+            middle = (range_end + range_start) // 2
+            txn.execute(sql, (middle,))
+            row = txn.fetchone()
+            if row is None:
+                # no rows with stream_ordering<=middle
+                range_start = middle + 1
+                continue
+
+            middle_ts = row[0]
+            if ts > middle_ts:
+                # we got a timestamp lower than the one we were looking for.
+                # definitely need to look higher: X > middle.
+                range_start = middle + 1
+            else:
+                # we got a timestamp higher than (or the same as) the one we
+                # were looking for. We aren't yet sure about the point we
+                # looked up, but we can be sure that X <= middle.
+                range_end = middle
+
+        return range_end
+
+
+class EventPushActionsStore(EventPushActionsWorkerStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+    def __init__(self, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(db_conn, hs)
+
+        self.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
+        self.register_background_index_update(
+            "event_push_actions_highlights_index",
+            index_name="event_push_actions_highlights_index",
+            table="event_push_actions",
+            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+            where_clause="highlight=1"
+        )
+
+        self._doing_notif_rotation = False
+        self._rotate_notif_loop = self._clock.looping_call(
+            self._rotate_notifs, 30 * 60 * 1000
+        )
+
+    def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
+                                                  all_events_and_contexts):
+        """Handles moving push actions from staging table to main
+        event_push_actions table for all events in `events_and_contexts`.
+
+        Also ensures that all events in `all_events_and_contexts` are removed
+        from the push action staging area.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc, that wouldn't appear in
+                events_and_context.
+        """
+
+        sql = """
+            INSERT INTO event_push_actions (
+                room_id, event_id, user_id, actions, stream_ordering,
+                topological_ordering, notif, highlight
+            )
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            FROM event_push_actions_staging
+            WHERE event_id = ?
+        """
+
+        if events_and_contexts:
+            txn.executemany(sql, (
+                (
+                    event.room_id, event.internal_metadata.stream_ordering,
+                    event.depth, event.event_id,
+                )
+                for event, _ in events_and_contexts
+            ))
+
+        for event, _ in events_and_contexts:
+            user_ids = self._simple_select_onecol_txn(
+                txn,
+                table="event_push_actions_staging",
+                keyvalues={
+                    "event_id": event.event_id,
+                },
+                retcol="user_id",
+            )
+
+            for uid in user_ids:
+                txn.call_after(
+                    self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                    (event.room_id, uid,)
+                )
+
+        # Now we delete the staging area for *all* events that were being
+        # persisted.
+        txn.executemany(
+            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+            (
+                (event.event_id,)
+                for event, _ in all_events_and_contexts
+            )
+        )
+
     @defer.inlineCallbacks
     def get_push_actions_for_user(self, user_id, before=None, limit=50,
                                   only_highlight=False):
@@ -568,69 +794,6 @@ class EventPushActionsStore(SQLBaseStore):
         """, (room_id, user_id, stream_ordering))
 
     @defer.inlineCallbacks
-    def _find_stream_orderings_for_times(self):
-        yield self.runInteraction(
-            "_find_stream_orderings_for_times",
-            self._find_stream_orderings_for_times_txn
-        )
-
-    def _find_stream_orderings_for_times_txn(self, txn):
-        logger.info("Searching for stream ordering 1 month ago")
-        self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
-            txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
-        )
-        logger.info(
-            "Found stream ordering 1 month ago: it's %d",
-            self.stream_ordering_month_ago
-        )
-        logger.info("Searching for stream ordering 1 day ago")
-        self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
-            txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
-        )
-        logger.info(
-            "Found stream ordering 1 day ago: it's %d",
-            self.stream_ordering_day_ago
-        )
-
-    def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
-        """
-        Find the stream_ordering of the first event that was received after
-        a given timestamp. This is relatively slow as there is no index on
-        received_ts but we can then use this to delete push actions before
-        this.
-
-        received_ts must necessarily be in the same order as stream_ordering
-        and stream_ordering is indexed, so we manually binary search using
-        stream_ordering
-        """
-        txn.execute("SELECT MAX(stream_ordering) FROM events")
-        max_stream_ordering = txn.fetchone()[0]
-
-        if max_stream_ordering is None:
-            return 0
-
-        range_start = 0
-        range_end = max_stream_ordering
-
-        sql = (
-            "SELECT received_ts FROM events"
-            " WHERE stream_ordering > ?"
-            " ORDER BY stream_ordering"
-            " LIMIT 1"
-        )
-
-        while range_end - range_start > 1:
-            middle = int((range_end + range_start) / 2)
-            txn.execute(sql, (middle,))
-            middle_ts = txn.fetchone()[0]
-            if ts > middle_ts:
-                range_start = middle
-            else:
-                range_end = middle
-
-        return range_end
-
-    @defer.inlineCallbacks
     def _rotate_notifs(self):
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
             return
@@ -755,50 +918,6 @@ class EventPushActionsStore(SQLBaseStore):
             (rotate_to_stream_ordering,)
         )
 
-    def add_push_actions_to_staging(self, event_id, user_id, actions):
-        """Add the push actions for the user and event to the push
-        action staging area.
-
-        Args:
-            event_id (str)
-            user_id (str)
-            actions (list[dict|str]): An action can either be a string or
-                dict.
-
-        Returns:
-            Deferred
-        """
-
-        is_highlight = 1 if _action_has_highlight(actions) else 0
-
-        return self._simple_insert(
-            table="event_push_actions_staging",
-            values={
-                "event_id": event_id,
-                "user_id": user_id,
-                "actions": _serialize_action(actions, is_highlight),
-                "notif": 1,
-                "highlight": is_highlight,
-            },
-            desc="add_push_actions_to_staging",
-        )
-
-    def remove_push_actions_from_staging(self, event_id):
-        """Called if we failed to persist the event to ensure that stale push
-        actions don't build up in the DB
-
-        Args:
-            event_id (str)
-        """
-
-        return self._simple_delete(
-            table="event_push_actions_staging",
-            keyvalues={
-                "event_id": event_id,
-            },
-            desc="remove_push_actions_from_staging",
-        )
-
 
 def _action_has_highlight(actions):
     for action in actions:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 73177e0bc2..ece5e6c41f 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,33 +13,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import SQLBaseStore
 
-from twisted.internet import defer, reactor
+from collections import OrderedDict, deque, namedtuple
+from functools import wraps
+import logging
+
+import simplejson as json
+from twisted.internet import defer
 
-from synapse.events import FrozenEvent, USE_FROZEN_DICTS
-from synapse.events.utils import prune_event
 
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.util.async import ObservableDeferred
+from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.logcontext import (
-    preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+    PreserveLoggingContext, make_deferred_yieldable,
 )
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.types import get_domain_from_id
-
-from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple, OrderedDict
-from functools import wraps
-
 import synapse.metrics
 
-import logging
-import ujson as json
-
 # these are only included to make the type annotations work
 from synapse.events import EventBase    # noqa: F401
 from synapse.events.snapshot import EventContext   # noqa: F401
@@ -52,23 +49,25 @@ event_counter = metrics.register_counter(
     "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
 )
 
-
-def encode_json(json_object):
-    if USE_FROZEN_DICTS:
-        # ujson doesn't like frozen_dicts
-        return encode_canonical_json(json_object)
-    else:
-        return json.dumps(json_object, ensure_ascii=False)
+# The number of times we are recalculating the current state
+state_delta_counter = metrics.register_counter(
+    "state_delta",
+)
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = metrics.register_counter(
+    "state_delta_single_event",
+)
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = metrics.register_counter(
+    "state_delta_reuse_delta",
+)
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
+def encode_json(json_object):
+    return frozendict_json_encoder.encode(json_object)
 
 
 class _EventPeristenceQueue(object):
@@ -199,13 +198,12 @@ def _retry_on_integrity_error(func):
     return f
 
 
-class EventsStore(SQLBaseStore):
+class EventsStore(EventsWorkerStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
     def __init__(self, db_conn, hs):
         super(EventsStore, self).__init__(db_conn, hs)
-        self._clock = hs.get_clock()
         self.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
@@ -293,10 +291,11 @@ class EventsStore(SQLBaseStore):
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
         def persisting_queue(item):
-            yield self._persist_events(
-                item.events_and_contexts,
-                backfilled=item.backfilled,
-            )
+            with Measure(self._clock, "persist_events"):
+                yield self._persist_events(
+                    item.events_and_contexts,
+                    backfilled=item.backfilled,
+                )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
@@ -378,7 +377,8 @@ class EventsStore(SQLBaseStore):
                                 room_id, ev_ctx_rm, latest_event_ids
                             )
 
-                            if new_latest_event_ids == set(latest_event_ids):
+                            latest_event_ids = set(latest_event_ids)
+                            if new_latest_event_ids == latest_event_ids:
                                 # No change in extremities, so no change in state
                                 continue
 
@@ -399,6 +399,26 @@ class EventsStore(SQLBaseStore):
                                 if all_single_prev_not_state:
                                     continue
 
+                            state_delta_counter.inc()
+                            if len(new_latest_event_ids) == 1:
+                                state_delta_single_event_counter.inc()
+
+                                # This is a fairly handwavey check to see if we could
+                                # have guessed what the delta would have been when
+                                # processing one of these events.
+                                # What we're interested in is if the latest extremities
+                                # were the same when we created the event as they are
+                                # now. When this server creates a new event (as opposed
+                                # to receiving it over federation) it will use the
+                                # forward extremities as the prev_events, so we can
+                                # guess this by looking at the prev_events and checking
+                                # if they match the current forward extremities.
+                                for ev, _ in ev_ctx_rm:
+                                    prev_event_ids = set(e for e, _ in ev.prev_events)
+                                    if latest_event_ids == prev_event_ids:
+                                        state_delta_reuse_delta_counter.inc()
+                                        break
+
                             logger.info(
                                 "Calculating state delta for room %s", room_id,
                             )
@@ -609,62 +629,6 @@ class EventsStore(SQLBaseStore):
 
         defer.returnValue((to_delete, to_insert))
 
-    @defer.inlineCallbacks
-    def get_event(self, event_id, check_redacted=True,
-                  get_prev_content=False, allow_rejected=False,
-                  allow_none=False):
-        """Get an event from the database by event_id.
-
-        Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
-                include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
-                False throw an exception.
-
-        Returns:
-            Deferred : A FrozenEvent.
-        """
-        events = yield self._get_events(
-            [event_id],
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
-        )
-
-        if not events and not allow_none:
-            raise SynapseError(404, "Could not find event %s" % (event_id,))
-
-        defer.returnValue(events[0] if events else None)
-
-    @defer.inlineCallbacks
-    def get_events(self, event_ids, check_redacted=True,
-                   get_prev_content=False, allow_rejected=False):
-        """Get events from the database
-
-        Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
-                include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-
-        Returns:
-            Deferred : Dict from event_id to event.
-        """
-        events = yield self._get_events(
-            event_ids,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
-        )
-
-        defer.returnValue({e.event_id: e for e in events})
-
     @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
                             delete_existing=False, state_delta_for_room={},
@@ -693,6 +657,8 @@ class EventsStore(SQLBaseStore):
                 list of the event ids which are the forward extremities.
 
         """
+        all_events_and_contexts = events_and_contexts
+
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
         self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
@@ -755,6 +721,7 @@ class EventsStore(SQLBaseStore):
         self._update_metadata_tables_txn(
             txn,
             events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
             backfilled=backfilled,
         )
 
@@ -817,7 +784,7 @@ class EventsStore(SQLBaseStore):
 
                 for member in members_changed:
                     self._invalidate_cache_and_stream(
-                        txn, self.get_rooms_for_user, (member,)
+                        txn, self.get_rooms_for_user_with_stream_ordering, (member,)
                     )
 
                 for host in set(get_domain_from_id(u) for u in members_changed):
@@ -1152,26 +1119,33 @@ class EventsStore(SQLBaseStore):
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
-    def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+    def _update_metadata_tables_txn(self, txn, events_and_contexts,
+                                    all_events_and_contexts, backfilled):
         """Update all the miscellaneous tables for new events
 
         Args:
             txn (twisted.enterprise.adbapi.Connection): db connection
             events_and_contexts (list[(EventBase, EventContext)]): events
                 we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc, that wouldn't appear in
+                events_and_context.
             backfilled (bool): True if the events were backfilled
         """
 
+        # Insert all the push actions into the event_push_actions table.
+        self._set_push_actions_for_event_and_users_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+        )
+
         if not events_and_contexts:
             # nothing to do here
             return
 
         for event, context in events_and_contexts:
-            # Insert all the push actions into the event_push_actions table.
-            self._set_push_actions_for_event_and_users_txn(
-                txn, event,
-            )
-
             if event.type == EventTypes.Redaction and event.redacts is not None:
                 # Remove the entries in the event_push_actions table for the
                 # redacted event.
@@ -1376,292 +1350,6 @@ class EventsStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False, allow_rejected=False):
-        if not event_ids:
-            defer.returnValue([])
-
-        event_id_list = event_ids
-        event_ids = set(event_ids)
-
-        event_entry_map = self._get_events_from_cache(
-            event_ids,
-            allow_rejected=allow_rejected,
-        )
-
-        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
-        if missing_events_ids:
-            missing_events = yield self._enqueue_events(
-                missing_events_ids,
-                check_redacted=check_redacted,
-                allow_rejected=allow_rejected,
-            )
-
-            event_entry_map.update(missing_events)
-
-        events = []
-        for event_id in event_id_list:
-            entry = event_entry_map.get(event_id, None)
-            if not entry:
-                continue
-
-            if allow_rejected or not entry.event.rejected_reason:
-                if check_redacted and entry.redacted_event:
-                    event = entry.redacted_event
-                else:
-                    event = entry.event
-
-                events.append(event)
-
-                if get_prev_content:
-                    if "replaces_state" in event.unsigned:
-                        prev = yield self.get_event(
-                            event.unsigned["replaces_state"],
-                            get_prev_content=False,
-                            allow_none=True,
-                        )
-                        if prev:
-                            event.unsigned = dict(event.unsigned)
-                            event.unsigned["prev_content"] = prev.content
-                            event.unsigned["prev_sender"] = prev.sender
-
-        defer.returnValue(events)
-
-    def _invalidate_get_event_cache(self, event_id):
-            self._get_event_cache.invalidate((event_id,))
-
-    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
-        """Fetch events from the caches
-
-        Args:
-            events (list(str)): list of event_ids to fetch
-            allow_rejected (bool): Whether to teturn events that were rejected
-            update_metrics (bool): Whether to update the cache hit ratio metrics
-
-        Returns:
-            dict of event_id -> _EventCacheEntry for each event_id in cache. If
-            allow_rejected is `False` then there will still be an entry but it
-            will be `None`
-        """
-        event_map = {}
-
-        for event_id in events:
-            ret = self._get_event_cache.get(
-                (event_id,), None,
-                update_metrics=update_metrics,
-            )
-            if not ret:
-                continue
-
-            if allow_rejected or not ret.event.rejected_reason:
-                event_map[event_id] = ret
-            else:
-                event_map[event_id] = None
-
-        return event_map
-
-    def _do_fetch(self, conn):
-        """Takes a database connection and waits for requests for events from
-        the _event_fetch_list queue.
-        """
-        event_list = []
-        i = 0
-        while True:
-            try:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
-                            self._event_fetch_ongoing -= 1
-                            return
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
-
-                event_id_lists = zip(*event_list)[0]
-                event_ids = [
-                    item for sublist in event_id_lists for item in sublist
-                ]
-
-                rows = self._new_transaction(
-                    conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
-                )
-
-                row_dict = {
-                    r["event_id"]: r
-                    for r in rows
-                }
-
-                # We only want to resolve deferreds from the main thread
-                def fire(lst, res):
-                    for ids, d in lst:
-                        if not d.called:
-                            try:
-                                with PreserveLoggingContext():
-                                    d.callback([
-                                        res[i]
-                                        for i in ids
-                                        if i in res
-                                    ])
-                            except Exception:
-                                logger.exception("Failed to callback")
-                with PreserveLoggingContext():
-                    reactor.callFromThread(fire, event_list, row_dict)
-            except Exception as e:
-                logger.exception("do_fetch")
-
-                # We only want to resolve deferreds from the main thread
-                def fire(evs):
-                    for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(e)
-
-                if event_list:
-                    with PreserveLoggingContext():
-                        reactor.callFromThread(fire, event_list)
-
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
-        """Fetches events from the database using the _event_fetch_list. This
-        allows batch and bulk fetching of events - it allows us to fetch events
-        without having to create a new transaction for each request for events.
-        """
-        if not events:
-            defer.returnValue({})
-
-        events_d = defer.Deferred()
-        with self._event_fetch_lock:
-            self._event_fetch_list.append(
-                (events, events_d)
-            )
-
-            self._event_fetch_lock.notify()
-
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            with PreserveLoggingContext():
-                self.runWithConnection(
-                    self._do_fetch
-                )
-
-        logger.debug("Loading %d events", len(events))
-        with PreserveLoggingContext():
-            rows = yield events_d
-        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
-        if not allow_rejected:
-            rows[:] = [r for r in rows if not r["rejects"]]
-
-        res = yield make_deferred_yieldable(defer.gatherResults(
-            [
-                preserve_fn(self._get_event_from_row)(
-                    row["internal_metadata"], row["json"], row["redacts"],
-                    rejected_reason=row["rejects"],
-                )
-                for row in rows
-            ],
-            consumeErrors=True
-        ))
-
-        defer.returnValue({
-            e.event.event_id: e
-            for e in res if e
-        })
-
-    def _fetch_event_rows(self, txn, events):
-        rows = []
-        N = 200
-        for i in range(1 + len(events) / N):
-            evs = events[i * N:(i + 1) * N]
-            if not evs:
-                break
-
-            sql = (
-                "SELECT "
-                " e.event_id as event_id, "
-                " e.internal_metadata,"
-                " e.json,"
-                " r.redacts as redacts,"
-                " rej.event_id as rejects "
-                " FROM event_json as e"
-                " LEFT JOIN rejections as rej USING (event_id)"
-                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
-                " WHERE e.event_id IN (%s)"
-            ) % (",".join(["?"] * len(evs)),)
-
-            txn.execute(sql, evs)
-            rows.extend(self.cursor_to_dict(txn))
-
-        return rows
-
-    @defer.inlineCallbacks
-    def _get_event_from_row(self, internal_metadata, js, redacted,
-                            rejected_reason=None):
-        with Measure(self._clock, "_get_event_from_row"):
-            d = json.loads(js)
-            internal_metadata = json.loads(internal_metadata)
-
-            if rejected_reason:
-                rejected_reason = yield self._simple_select_one_onecol(
-                    table="rejections",
-                    keyvalues={"event_id": rejected_reason},
-                    retcol="reason",
-                    desc="_get_event_from_row_rejected_reason",
-                )
-
-            original_ev = FrozenEvent(
-                d,
-                internal_metadata_dict=internal_metadata,
-                rejected_reason=rejected_reason,
-            )
-
-            redacted_event = None
-            if redacted:
-                redacted_event = prune_event(original_ev)
-
-                redaction_id = yield self._simple_select_one_onecol(
-                    table="redactions",
-                    keyvalues={"redacts": redacted_event.event_id},
-                    retcol="event_id",
-                    desc="_get_event_from_row_redactions",
-                )
-
-                redacted_event.unsigned["redacted_by"] = redaction_id
-                # Get the redaction event.
-
-                because = yield self.get_event(
-                    redaction_id,
-                    check_redacted=False,
-                    allow_none=True,
-                )
-
-                if because:
-                    # It's fine to do add the event directly, since get_pdu_json
-                    # will serialise this field correctly
-                    redacted_event.unsigned["redacted_because"] = because
-
-            cache_entry = _EventCacheEntry(
-                event=original_ev,
-                redacted_event=redacted_event,
-            )
-
-            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
-        defer.returnValue(cache_entry)
-
-    @defer.inlineCallbacks
     def count_daily_messages(self):
         """
         Returns an estimate of the number of messages sent in the last day.
@@ -2375,7 +2063,7 @@ class EventsStore(SQLBaseStore):
         to_2, so_2 = yield self._get_event_ordering(event_id2)
         defer.returnValue((to_1, so_1) > (to_2, so_2))
 
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks(max_entries=5000)
     def _get_event_ordering(self, event_id):
         res = yield self._simple_select_one(
             table="events",
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 0000000000..2e23dd78ba
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,395 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+from twisted.internet import defer, reactor
+
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import (
+    preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+)
+from synapse.util.metrics import Measure
+from synapse.api.errors import SynapseError
+
+from collections import namedtuple
+
+import logging
+import simplejson as json
+
+# these are only included to make the type annotations work
+from synapse.events import EventBase    # noqa: F401
+from synapse.events.snapshot import EventContext   # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thing air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+
+    @defer.inlineCallbacks
+    def get_event(self, event_id, check_redacted=True,
+                  get_prev_content=False, allow_rejected=False,
+                  allow_none=False):
+        """Get an event from the database by event_id.
+
+        Args:
+            event_id (str): The event_id of the event to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+            allow_none (bool): If True, return None if no event found, if
+                False throw an exception.
+
+        Returns:
+            Deferred : A FrozenEvent.
+        """
+        events = yield self._get_events(
+            [event_id],
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        if not events and not allow_none:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        defer.returnValue(events[0] if events else None)
+
+    @defer.inlineCallbacks
+    def get_events(self, event_ids, check_redacted=True,
+                   get_prev_content=False, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred : Dict from event_id to event.
+        """
+        events = yield self._get_events(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        defer.returnValue({e.event_id: e for e in events})
+
+    @defer.inlineCallbacks
+    def _get_events(self, event_ids, check_redacted=True,
+                    get_prev_content=False, allow_rejected=False):
+        if not event_ids:
+            defer.returnValue([])
+
+        event_id_list = event_ids
+        event_ids = set(event_ids)
+
+        event_entry_map = self._get_events_from_cache(
+            event_ids,
+            allow_rejected=allow_rejected,
+        )
+
+        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+        if missing_events_ids:
+            missing_events = yield self._enqueue_events(
+                missing_events_ids,
+                check_redacted=check_redacted,
+                allow_rejected=allow_rejected,
+            )
+
+            event_entry_map.update(missing_events)
+
+        events = []
+        for event_id in event_id_list:
+            entry = event_entry_map.get(event_id, None)
+            if not entry:
+                continue
+
+            if allow_rejected or not entry.event.rejected_reason:
+                if check_redacted and entry.redacted_event:
+                    event = entry.redacted_event
+                else:
+                    event = entry.event
+
+                events.append(event)
+
+                if get_prev_content:
+                    if "replaces_state" in event.unsigned:
+                        prev = yield self.get_event(
+                            event.unsigned["replaces_state"],
+                            get_prev_content=False,
+                            allow_none=True,
+                        )
+                        if prev:
+                            event.unsigned = dict(event.unsigned)
+                            event.unsigned["prev_content"] = prev.content
+                            event.unsigned["prev_sender"] = prev.sender
+
+        defer.returnValue(events)
+
+    def _invalidate_get_event_cache(self, event_id):
+            self._get_event_cache.invalidate((event_id,))
+
+    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+        """Fetch events from the caches
+
+        Args:
+            events (list(str)): list of event_ids to fetch
+            allow_rejected (bool): Whether to teturn events that were rejected
+            update_metrics (bool): Whether to update the cache hit ratio metrics
+
+        Returns:
+            dict of event_id -> _EventCacheEntry for each event_id in cache. If
+            allow_rejected is `False` then there will still be an entry but it
+            will be `None`
+        """
+        event_map = {}
+
+        for event_id in events:
+            ret = self._get_event_cache.get(
+                (event_id,), None,
+                update_metrics=update_metrics,
+            )
+            if not ret:
+                continue
+
+            if allow_rejected or not ret.event.rejected_reason:
+                event_map[event_id] = ret
+            else:
+                event_map[event_id] = None
+
+        return event_map
+
+    def _do_fetch(self, conn):
+        """Takes a database connection and waits for requests for events from
+        the _event_fetch_list queue.
+        """
+        event_list = []
+        i = 0
+        while True:
+            try:
+                with self._event_fetch_lock:
+                    event_list = self._event_fetch_list
+                    self._event_fetch_list = []
+
+                    if not event_list:
+                        single_threaded = self.database_engine.single_threaded
+                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+                            self._event_fetch_ongoing -= 1
+                            return
+                        else:
+                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                            i += 1
+                            continue
+                    i = 0
+
+                event_id_lists = zip(*event_list)[0]
+                event_ids = [
+                    item for sublist in event_id_lists for item in sublist
+                ]
+
+                rows = self._new_transaction(
+                    conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
+                )
+
+                row_dict = {
+                    r["event_id"]: r
+                    for r in rows
+                }
+
+                # We only want to resolve deferreds from the main thread
+                def fire(lst, res):
+                    for ids, d in lst:
+                        if not d.called:
+                            try:
+                                with PreserveLoggingContext():
+                                    d.callback([
+                                        res[i]
+                                        for i in ids
+                                        if i in res
+                                    ])
+                            except Exception:
+                                logger.exception("Failed to callback")
+                with PreserveLoggingContext():
+                    reactor.callFromThread(fire, event_list, row_dict)
+            except Exception as e:
+                logger.exception("do_fetch")
+
+                # We only want to resolve deferreds from the main thread
+                def fire(evs):
+                    for _, d in evs:
+                        if not d.called:
+                            with PreserveLoggingContext():
+                                d.errback(e)
+
+                if event_list:
+                    with PreserveLoggingContext():
+                        reactor.callFromThread(fire, event_list)
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+        """Fetches events from the database using the _event_fetch_list. This
+        allows batch and bulk fetching of events - it allows us to fetch events
+        without having to create a new transaction for each request for events.
+        """
+        if not events:
+            defer.returnValue({})
+
+        events_d = defer.Deferred()
+        with self._event_fetch_lock:
+            self._event_fetch_list.append(
+                (events, events_d)
+            )
+
+            self._event_fetch_lock.notify()
+
+            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+                self._event_fetch_ongoing += 1
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            with PreserveLoggingContext():
+                self.runWithConnection(
+                    self._do_fetch
+                )
+
+        logger.debug("Loading %d events", len(events))
+        with PreserveLoggingContext():
+            rows = yield events_d
+        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+
+        if not allow_rejected:
+            rows[:] = [r for r in rows if not r["rejects"]]
+
+        res = yield make_deferred_yieldable(defer.gatherResults(
+            [
+                preserve_fn(self._get_event_from_row)(
+                    row["internal_metadata"], row["json"], row["redacts"],
+                    rejected_reason=row["rejects"],
+                )
+                for row in rows
+            ],
+            consumeErrors=True
+        ))
+
+        defer.returnValue({
+            e.event.event_id: e
+            for e in res if e
+        })
+
+    def _fetch_event_rows(self, txn, events):
+        rows = []
+        N = 200
+        for i in range(1 + len(events) / N):
+            evs = events[i * N:(i + 1) * N]
+            if not evs:
+                break
+
+            sql = (
+                "SELECT "
+                " e.event_id as event_id, "
+                " e.internal_metadata,"
+                " e.json,"
+                " r.redacts as redacts,"
+                " rej.event_id as rejects "
+                " FROM event_json as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE e.event_id IN (%s)"
+            ) % (",".join(["?"] * len(evs)),)
+
+            txn.execute(sql, evs)
+            rows.extend(self.cursor_to_dict(txn))
+
+        return rows
+
+    @defer.inlineCallbacks
+    def _get_event_from_row(self, internal_metadata, js, redacted,
+                            rejected_reason=None):
+        with Measure(self._clock, "_get_event_from_row"):
+            d = json.loads(js)
+            internal_metadata = json.loads(internal_metadata)
+
+            if rejected_reason:
+                rejected_reason = yield self._simple_select_one_onecol(
+                    table="rejections",
+                    keyvalues={"event_id": rejected_reason},
+                    retcol="reason",
+                    desc="_get_event_from_row_rejected_reason",
+                )
+
+            original_ev = FrozenEvent(
+                d,
+                internal_metadata_dict=internal_metadata,
+                rejected_reason=rejected_reason,
+            )
+
+            redacted_event = None
+            if redacted:
+                redacted_event = prune_event(original_ev)
+
+                redaction_id = yield self._simple_select_one_onecol(
+                    table="redactions",
+                    keyvalues={"redacts": redacted_event.event_id},
+                    retcol="event_id",
+                    desc="_get_event_from_row_redactions",
+                )
+
+                redacted_event.unsigned["redacted_by"] = redaction_id
+                # Get the redaction event.
+
+                because = yield self.get_event(
+                    redaction_id,
+                    check_redacted=False,
+                    allow_none=True,
+                )
+
+                if because:
+                    # It's fine to do add the event directly, since get_pdu_json
+                    # will serialise this field correctly
+                    redacted_event.unsigned["redacted_because"] = because
+
+            cache_entry = _EventCacheEntry(
+                event=original_ev,
+                redacted_event=redacted_event,
+            )
+
+            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+        defer.returnValue(cache_entry)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index 8fde1aab8e..d03858234b 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
 
 from ._base import SQLBaseStore
 
-import ujson as json
+import simplejson as json
 
 
 # The category ID for the "default" category. We don't store as null in the
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index ec02e73bc2..8612bd5ecc 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -21,14 +21,7 @@ from synapse.api.errors import StoreError
 from ._base import SQLBaseStore
 
 
-class ProfileStore(SQLBaseStore):
-    def create_profile(self, user_localpart):
-        return self._simple_insert(
-            table="profiles",
-            values={"user_id": user_localpart},
-            desc="create_profile",
-        )
-
+class ProfileWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_profileinfo(self, user_localpart):
         try:
@@ -61,14 +54,6 @@ class ProfileStore(SQLBaseStore):
             desc="get_profile_displayname",
         )
 
-    def set_profile_displayname(self, user_localpart, new_displayname):
-        return self._simple_update_one(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            updatevalues={"displayname": new_displayname},
-            desc="set_profile_displayname",
-        )
-
     def get_profile_avatar_url(self, user_localpart):
         return self._simple_select_one_onecol(
             table="profiles",
@@ -77,14 +62,6 @@ class ProfileStore(SQLBaseStore):
             desc="get_profile_avatar_url",
         )
 
-    def set_profile_avatar_url(self, user_localpart, new_avatar_url):
-        return self._simple_update_one(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            updatevalues={"avatar_url": new_avatar_url},
-            desc="set_profile_avatar_url",
-        )
-
     def get_from_remote_profile_cache(self, user_id):
         return self._simple_select_one(
             table="remote_profile_cache",
@@ -94,6 +71,31 @@ class ProfileStore(SQLBaseStore):
             desc="get_from_remote_profile_cache",
         )
 
+
+class ProfileStore(ProfileWorkerStore):
+    def create_profile(self, user_localpart):
+        return self._simple_insert(
+            table="profiles",
+            values={"user_id": user_localpart},
+            desc="create_profile",
+        )
+
+    def set_profile_displayname(self, user_localpart, new_displayname):
+        return self._simple_update_one(
+            table="profiles",
+            keyvalues={"user_id": user_localpart},
+            updatevalues={"displayname": new_displayname},
+            desc="set_profile_displayname",
+        )
+
+    def set_profile_avatar_url(self, user_localpart, new_avatar_url):
+        return self._simple_update_one(
+            table="profiles",
+            keyvalues={"user_id": user_localpart},
+            updatevalues={"avatar_url": new_avatar_url},
+            desc="set_profile_avatar_url",
+        )
+
     def add_remote_profile_cache(self, user_id, displayname, avatar_url):
         """Ensure we are caching the remote user's profiles.
 
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..04a0b59a39 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,11 +15,17 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.push.baserules import list_with_base_rules
 from synapse.api.constants import EventTypes
 from twisted.internet import defer
 
+import abc
 import logging
 import simplejson as json
 
@@ -48,7 +55,43 @@ def _load_rules(rawrules, enabled_map):
     return rules
 
 
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+                           ReceiptsWorkerStore,
+                           PusherWorkerStore,
+                           RoomMemberWorkerStore,
+                           SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
@@ -89,6 +132,22 @@ class PushRuleStore(SQLBaseStore):
             r['rule_id']: False if r['enabled'] == 0 else True for r in results
         })
 
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
     @cachedList(cached_method_name="get_push_rules_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules(self, user_ids):
@@ -228,6 +287,8 @@ class PushRuleStore(SQLBaseStore):
             results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
         defer.returnValue(results)
 
+
+class PushRuleStore(PushRulesWorkerStore):
     @defer.inlineCallbacks
     def add_push_rule(
         self, user_id, rule_id, priority_class, conditions, actions,
@@ -526,21 +587,8 @@ class PushRuleStore(SQLBaseStore):
         room stream ordering it corresponds to."""
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
 
 
 class RuleNotFoundException(Exception):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 3d8b4d5d5b..307660b99a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ import types
 logger = logging.getLogger(__name__)
 
 
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
     def _decode_pushers_rows(self, rows):
         for r in rows:
             dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
         rows = yield self.runInteraction("get_all_pushers", get_pushers)
         defer.returnValue(rows)
 
-    def get_pushers_stream_token(self):
-        return self._pushers_id_gen.get_current_token()
-
     def get_all_updated_pushers(self, last_id, current_id, limit):
         if last_id == current_id:
             return defer.succeed(([], []))
@@ -198,6 +196,11 @@ class PusherStore(SQLBaseStore):
 
         defer.returnValue(result)
 
+
+class PusherStore(PusherWorkerStore):
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     @defer.inlineCallbacks
     def add_pusher(self, user_id, access_token, kind, app_id,
                    app_display_name, device_display_name,
@@ -230,14 +233,18 @@ class PusherStore(SQLBaseStore):
             )
 
             if newly_inserted:
-                # get_if_user_has_pusher only cares if the user has
-                # at least *one* pusher.
-                self.get_if_user_has_pusher.invalidate(user_id,)
+                self.runInteraction(
+                    "add_pusher",
+                    self._invalidate_cache_and_stream,
+                    self.get_if_user_has_pusher, (user_id,)
+                )
 
     @defer.inlineCallbacks
     def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
         def delete_pusher_txn(txn, stream_id):
-            txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+            self._invalidate_cache_and_stream(
+                txn, self.get_if_user_has_pusher, (user_id,)
+            )
 
             self._simple_delete_one_txn(
                 txn,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 12b3cc7f5f..63997ed449 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,51 +15,50 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from twisted.internet import defer
 
+import abc
 import logging
-import ujson as json
+import simplejson as json
 
 
 logger = logging.getLogger(__name__)
 
 
-class ReceiptsStore(SQLBaseStore):
+class ReceiptsWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_receipt_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
     def __init__(self, db_conn, hs):
-        super(ReceiptsStore, self).__init__(db_conn, hs)
+        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
+    @abc.abstractmethod
+    def get_max_receipt_stream_id(self):
+        """Get the current max stream ID for receipts stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks()
     def get_users_with_read_receipts_in_room(self, room_id):
         receipts = yield self.get_receipts_for_room(room_id, "m.read")
         defer.returnValue(set(r['user_id'] for r in receipts))
 
-    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
-                                                    user_id):
-        if receipt_type != "m.read":
-            return
-
-        # Returns an ObservableDeferred
-        res = self.get_users_with_read_receipts_in_room.cache.get(
-            room_id, None, update_metrics=False,
-        )
-
-        if res:
-            if isinstance(res, defer.Deferred) and res.called:
-                res = res.result
-            if user_id in res:
-                # We'd only be adding to the set, so no point invalidating if the
-                # user is already there
-                return
-
-        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
         return self._simple_select_list(
@@ -270,6 +270,59 @@ class ReceiptsStore(SQLBaseStore):
         }
         defer.returnValue(results)
 
+    def get_all_updated_receipts(self, last_id, current_id, limit=None):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_receipts_txn(txn):
+            sql = (
+                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+                " FROM receipts_linearized"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+            )
+            args = [last_id, current_id]
+            if limit is not None:
+                sql += " LIMIT ?"
+                args.append(limit)
+            txn.execute(sql, args)
+
+            return txn.fetchall()
+        return self.runInteraction(
+            "get_all_updated_receipts", get_all_updated_receipts_txn
+        )
+
+    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+                                                    user_id):
+        if receipt_type != "m.read":
+            return
+
+        # Returns an ObservableDeferred
+        res = self.get_users_with_read_receipts_in_room.cache.get(
+            room_id, None, update_metrics=False,
+        )
+
+        if res:
+            if isinstance(res, defer.Deferred) and res.called:
+                res = res.result
+            if user_id in res:
+                # We'd only be adding to the set, so no point invalidating if the
+                # user is already there
+                return
+
+        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    def __init__(self, db_conn, hs):
+        # We instantiate this first as the ReceiptsWorkerStore constructor
+        # needs to be able to call get_max_receipt_stream_id
+        self._receipts_id_gen = StreamIdGenerator(
+            db_conn, "receipts_linearized", "stream_id"
+        )
+
+        super(ReceiptsStore, self).__init__(db_conn, hs)
+
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
 
@@ -457,25 +510,3 @@ class ReceiptsStore(SQLBaseStore):
                 "data": json.dumps(data),
             }
         )
-
-    def get_all_updated_receipts(self, last_id, current_id, limit=None):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_receipts_txn(txn):
-            sql = (
-                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
-                " FROM receipts_linearized"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-            )
-            args = [last_id, current_id]
-            if limit is not None:
-                sql += " LIMIT ?"
-                args.append(limit)
-            txn.execute(sql, args)
-
-            return txn.fetchall()
-        return self.runInteraction(
-            "get_all_updated_receipts", get_all_updated_receipts_txn
-        )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 95f75d6df1..6b557ca0cf 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,10 +19,70 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError, Codes
 from synapse.storage import background_updates
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 
-class RegistrationStore(background_updates.BackgroundUpdateStore):
+class RegistrationWorkerStore(SQLBaseStore):
+    @cached()
+    def get_user_by_id(self, user_id):
+        return self._simple_select_one(
+            table="users",
+            keyvalues={
+                "name": user_id,
+            },
+            retcols=["name", "password_hash", "is_guest"],
+            allow_none=True,
+            desc="get_user_by_id",
+        )
+
+    @cached()
+    def get_user_by_access_token(self, token):
+        """Get a user from the given access token.
+
+        Args:
+            token (str): The access token of a user.
+        Returns:
+            defer.Deferred: None, if the token did not match, otherwise dict
+                including the keys `name`, `is_guest`, `device_id`, `token_id`.
+        """
+        return self.runInteraction(
+            "get_user_by_access_token",
+            self._query_for_auth,
+            token
+        )
+
+    @defer.inlineCallbacks
+    def is_server_admin(self, user):
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user.to_string()},
+            retcol="admin",
+            allow_none=True,
+            desc="is_server_admin",
+        )
+
+        defer.returnValue(res if res else False)
+
+    def _query_for_auth(self, txn, token):
+        sql = (
+            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+            " access_tokens.device_id"
+            " FROM users"
+            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
+            " WHERE token = ?"
+        )
+
+        txn.execute(sql, (token,))
+        rows = self.cursor_to_dict(txn)
+        if rows:
+            return rows[0]
+
+        return None
+
+
+class RegistrationStore(RegistrationWorkerStore,
+                        background_updates.BackgroundUpdateStore):
 
     def __init__(self, db_conn, hs):
         super(RegistrationStore, self).__init__(db_conn, hs)
@@ -187,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         )
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    @cached()
-    def get_user_by_id(self, user_id):
-        return self._simple_select_one(
-            table="users",
-            keyvalues={
-                "name": user_id,
-            },
-            retcols=["name", "password_hash", "is_guest"],
-            allow_none=True,
-            desc="get_user_by_id",
-        )
-
     def get_users_by_id_case_insensitive(self, user_id):
         """Gets users that match user_id case insensitively.
         Returns a mapping of user_id -> password_hash.
@@ -304,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
         return self.runInteraction("delete_access_token", f)
 
-    @cached()
-    def get_user_by_access_token(self, token):
-        """Get a user from the given access token.
-
-        Args:
-            token (str): The access token of a user.
-        Returns:
-            defer.Deferred: None, if the token did not match, otherwise dict
-                including the keys `name`, `is_guest`, `device_id`, `token_id`.
-        """
-        return self.runInteraction(
-            "get_user_by_access_token",
-            self._query_for_auth,
-            token
-        )
-
-    @defer.inlineCallbacks
-    def is_server_admin(self, user):
-        res = yield self._simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user.to_string()},
-            retcol="admin",
-            allow_none=True,
-            desc="is_server_admin",
-        )
-
-        defer.returnValue(res if res else False)
-
     @cachedInlineCallbacks()
     def is_guest(self, user_id):
         res = yield self._simple_select_one_onecol(
@@ -344,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
         defer.returnValue(res if res else False)
 
-    def _query_for_auth(self, txn, token):
-        sql = (
-            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
-            " access_tokens.device_id"
-            " FROM users"
-            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
-            " WHERE token = ?"
-        )
-
-        txn.execute(sql, (token,))
-        rows = self.cursor_to_dict(txn)
-        if rows:
-            return rows[0]
-
-        return None
-
     @defer.inlineCallbacks
     def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
         yield self._simple_upsert("user_threepids", {
@@ -456,14 +460,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         """
         def _find_next_generated_user_id(txn):
             txn.execute("SELECT name FROM users")
-            rows = self.cursor_to_dict(txn)
 
             regex = re.compile("^@(\d+):")
 
             found = set()
 
-            for r in rows:
-                user_id = r["name"]
+            for user_id, in txn:
                 match = regex.search(user_id)
                 if match:
                     found.add(int(match.group(1)))
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index fff6652e05..908551d6d9 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,12 +16,13 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.search import SearchStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 import collections
 import logging
-import ujson as json
+import simplejson as json
 import re
 
 logger = logging.getLogger(__name__)
@@ -38,7 +39,138 @@ RatelimitOverride = collections.namedtuple(
 )
 
 
-class RoomStore(SearchStore):
+class RoomWorkerStore(SQLBaseStore):
+    def get_public_room_ids(self):
+        return self._simple_select_onecol(
+            table="rooms",
+            keyvalues={
+                "is_public": True,
+            },
+            retcol="room_id",
+            desc="get_public_room_ids",
+        )
+
+    @cached(num_args=2, max_entries=100)
+    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+        """Get pulbic rooms for a particular list, or across all lists.
+
+        Args:
+            stream_id (int)
+            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
+                means the main list, None means all lsits.
+        """
+        return self.runInteraction(
+            "get_public_room_ids_at_stream_id",
+            self.get_public_room_ids_at_stream_id_txn,
+            stream_id, network_tuple=network_tuple
+        )
+
+    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+                                             network_tuple):
+        return {
+            rm
+            for rm, vis in self.get_published_at_stream_id_txn(
+                txn, stream_id, network_tuple=network_tuple
+            ).items()
+            if vis
+        }
+
+    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+        if network_tuple:
+            # We want to get from a particular list. No aggregation required.
+
+            sql = ("""
+                SELECT room_id, visibility FROM public_room_list_stream
+                INNER JOIN (
+                    SELECT room_id, max(stream_id) AS stream_id
+                    FROM public_room_list_stream
+                    WHERE stream_id <= ? %s
+                    GROUP BY room_id
+                ) grouped USING (room_id, stream_id)
+            """)
+
+            if network_tuple.appservice_id is not None:
+                txn.execute(
+                    sql % ("AND appservice_id = ? AND network_id = ?",),
+                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+                )
+            else:
+                txn.execute(
+                    sql % ("AND appservice_id IS NULL",),
+                    (stream_id,)
+                )
+            return dict(txn)
+        else:
+            # We want to get from all lists, so we need to aggregate the results
+
+            logger.info("Executing full list")
+
+            sql = ("""
+                SELECT room_id, visibility
+                FROM public_room_list_stream
+                INNER JOIN (
+                    SELECT
+                        room_id, max(stream_id) AS stream_id, appservice_id,
+                        network_id
+                    FROM public_room_list_stream
+                    WHERE stream_id <= ?
+                    GROUP BY room_id, appservice_id, network_id
+                ) grouped USING (room_id, stream_id)
+            """)
+
+            txn.execute(
+                sql,
+                (stream_id,)
+            )
+
+            results = {}
+            # A room is visible if its visible on any list.
+            for room_id, visibility in txn:
+                results[room_id] = bool(visibility) or results.get(room_id, False)
+
+            return results
+
+    def get_public_room_changes(self, prev_stream_id, new_stream_id,
+                                network_tuple):
+        def get_public_room_changes_txn(txn):
+            then_rooms = self.get_public_room_ids_at_stream_id_txn(
+                txn, prev_stream_id, network_tuple
+            )
+
+            now_rooms_dict = self.get_published_at_stream_id_txn(
+                txn, new_stream_id, network_tuple
+            )
+
+            now_rooms_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if vis
+            )
+            now_rooms_not_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if not vis
+            )
+
+            newly_visible = now_rooms_visible - then_rooms
+            newly_unpublished = now_rooms_not_visible & then_rooms
+
+            return newly_visible, newly_unpublished
+
+        return self.runInteraction(
+            "get_public_room_changes", get_public_room_changes_txn
+        )
+
+    @cached(max_entries=10000)
+    def is_room_blocked(self, room_id):
+        return self._simple_select_one_onecol(
+            table="blocked_rooms",
+            keyvalues={
+                "room_id": room_id,
+            },
+            retcol="1",
+            allow_none=True,
+            desc="is_room_blocked",
+        )
+
+
+class RoomStore(RoomWorkerStore, SearchStore):
 
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
@@ -225,16 +357,6 @@ class RoomStore(SearchStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    def get_public_room_ids(self):
-        return self._simple_select_onecol(
-            table="rooms",
-            keyvalues={
-                "is_public": True,
-            },
-            retcol="room_id",
-            desc="get_public_room_ids",
-        )
-
     def get_room_count(self):
         """Retrieve a list of all rooms
         """
@@ -326,113 +448,6 @@ class RoomStore(SearchStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    @cached(num_args=2, max_entries=100)
-    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
-        """Get pulbic rooms for a particular list, or across all lists.
-
-        Args:
-            stream_id (int)
-            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
-                means the main list, None means all lsits.
-        """
-        return self.runInteraction(
-            "get_public_room_ids_at_stream_id",
-            self.get_public_room_ids_at_stream_id_txn,
-            stream_id, network_tuple=network_tuple
-        )
-
-    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
-                                             network_tuple):
-        return {
-            rm
-            for rm, vis in self.get_published_at_stream_id_txn(
-                txn, stream_id, network_tuple=network_tuple
-            ).items()
-            if vis
-        }
-
-    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
-        if network_tuple:
-            # We want to get from a particular list. No aggregation required.
-
-            sql = ("""
-                SELECT room_id, visibility FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT room_id, max(stream_id) AS stream_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ? %s
-                    GROUP BY room_id
-                ) grouped USING (room_id, stream_id)
-            """)
-
-            if network_tuple.appservice_id is not None:
-                txn.execute(
-                    sql % ("AND appservice_id = ? AND network_id = ?",),
-                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
-                )
-            else:
-                txn.execute(
-                    sql % ("AND appservice_id IS NULL",),
-                    (stream_id,)
-                )
-            return dict(txn)
-        else:
-            # We want to get from all lists, so we need to aggregate the results
-
-            logger.info("Executing full list")
-
-            sql = ("""
-                SELECT room_id, visibility
-                FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT
-                        room_id, max(stream_id) AS stream_id, appservice_id,
-                        network_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ?
-                    GROUP BY room_id, appservice_id, network_id
-                ) grouped USING (room_id, stream_id)
-            """)
-
-            txn.execute(
-                sql,
-                (stream_id,)
-            )
-
-            results = {}
-            # A room is visible if its visible on any list.
-            for room_id, visibility in txn:
-                results[room_id] = bool(visibility) or results.get(room_id, False)
-
-            return results
-
-    def get_public_room_changes(self, prev_stream_id, new_stream_id,
-                                network_tuple):
-        def get_public_room_changes_txn(txn):
-            then_rooms = self.get_public_room_ids_at_stream_id_txn(
-                txn, prev_stream_id, network_tuple
-            )
-
-            now_rooms_dict = self.get_published_at_stream_id_txn(
-                txn, new_stream_id, network_tuple
-            )
-
-            now_rooms_visible = set(
-                rm for rm, vis in now_rooms_dict.items() if vis
-            )
-            now_rooms_not_visible = set(
-                rm for rm, vis in now_rooms_dict.items() if not vis
-            )
-
-            newly_visible = now_rooms_visible - then_rooms
-            newly_unpublished = now_rooms_not_visible & then_rooms
-
-            return newly_visible, newly_unpublished
-
-        return self.runInteraction(
-            "get_public_room_changes", get_public_room_changes_txn
-        )
-
     def get_all_new_public_rooms(self, prev_id, current_id, limit):
         def get_all_new_public_rooms(txn):
             sql = ("""
@@ -482,18 +497,6 @@ class RoomStore(SearchStore):
         else:
             defer.returnValue(None)
 
-    @cached(max_entries=10000)
-    def is_room_blocked(self, room_id):
-        return self._simple_select_one_onecol(
-            table="blocked_rooms",
-            keyvalues={
-                "room_id": room_id,
-            },
-            retcol="1",
-            allow_none=True,
-            desc="is_room_blocked",
-        )
-
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
         yield self._simple_insert(
@@ -504,7 +507,11 @@ class RoomStore(SearchStore):
             },
             desc="block_room",
         )
-        self.is_room_blocked.invalidate((room_id,))
+        yield self.runInteraction(
+            "block_room_invalidation",
+            self._invalidate_cache_and_stream,
+            self.is_room_blocked, (room_id,),
+        )
 
     def get_media_mxcs_in_room(self, room_id):
         """Retrieves all the local and remote media MXC URIs in a given room
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 3e77fd3901..d662d1cfc0 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +18,7 @@ from twisted.internet import defer
 
 from collections import namedtuple
 
-from ._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
 from synapse.util.async import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -27,7 +28,7 @@ from synapse.api.constants import Membership, EventTypes
 from synapse.types import get_domain_from_id
 
 import logging
-import ujson as json
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
@@ -37,6 +38,11 @@ RoomsForUser = namedtuple(
     ("room_id", "sender", "membership", "event_id", "stream_ordering")
 )
 
+GetRoomsForUserWithStreamOrdering = namedtuple(
+    "_GetRoomsForUserWithStreamOrdering",
+    ("room_id", "stream_ordering",)
+)
+
 
 # We store this using a namedtuple so that we save about 3x space over using a
 # dict.
@@ -48,97 +54,7 @@ ProfileInfo = namedtuple(
 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 
 
-class RoomMemberStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
-        )
-
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="room_memberships",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
-                }
-                for event in events
-            ]
-        )
-
-        for event in events:
-            txn.call_after(
-                self._membership_stream_cache.entity_has_changed,
-                event.state_key, event.internal_metadata.stream_ordering
-            )
-            txn.call_after(
-                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
-            )
-
-            # We update the local_invites table only if the event is "current",
-            # i.e., its something that has just happened.
-            # The only current event that can also be an outlier is if its an
-            # invite that has come in across federation.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_invite_from_remote()
-            )
-            is_mine = self.hs.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self._simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        }
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
-
-                    txn.execute(sql, (
-                        event.internal_metadata.stream_ordering,
-                        event.event_id,
-                        event.room_id,
-                        event.state_key,
-                    ))
-
-    @defer.inlineCallbacks
-    def locally_reject_invite(self, user_id, room_id):
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
-        )
-
-        def f(txn, stream_ordering):
-            txn.execute(sql, (
-                stream_ordering,
-                True,
-                room_id,
-                user_id,
-            ))
-
-        with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
+class RoomMemberWorkerStore(EventsWorkerStore):
     @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
     def get_hosts_in_room(self, room_id, cache_context):
         """Returns the set of all hosts currently in the room
@@ -270,12 +186,32 @@ class RoomMemberStore(SQLBaseStore):
         return results
 
     @cachedInlineCallbacks(max_entries=500000, iterable=True)
-    def get_rooms_for_user(self, user_id):
+    def get_rooms_for_user_with_stream_ordering(self, user_id):
         """Returns a set of room_ids the user is currently joined to
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+            the rooms the user is in currently, along with the stream ordering
+            of the most recent join for that user and room.
         """
         rooms = yield self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
+        defer.returnValue(frozenset(
+            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+            for r in rooms
+        ))
+
+    @defer.inlineCallbacks
+    def get_rooms_for_user(self, user_id, on_invalidate=None):
+        """Returns a set of room_ids the user is currently joined to
+        """
+        rooms = yield self.get_rooms_for_user_with_stream_ordering(
+            user_id, on_invalidate=on_invalidate,
+        )
         defer.returnValue(frozenset(r.room_id for r in rooms))
 
     @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
@@ -295,89 +231,6 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(user_who_share_room)
 
-    def forget(self, user_id, room_id):
-        """Indicate that user_id wishes to discard history for room_id."""
-        def f(txn):
-            sql = (
-                "UPDATE"
-                "  room_memberships"
-                " SET"
-                "  forgotten = 1"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-            )
-            txn.execute(sql, (user_id, room_id))
-
-            txn.call_after(self.was_forgotten_at.invalidate_all)
-            txn.call_after(self.did_forget.invalidate, (user_id, room_id))
-            self._invalidate_cache_and_stream(
-                txn, self.who_forgot_in_room, (room_id,)
-            )
-        return self.runInteraction("forget_membership", f)
-
-    @cachedInlineCallbacks(num_args=2)
-    def did_forget(self, user_id, room_id):
-        """Returns whether user_id has elected to discard history for room_id.
-
-        Returns False if they have since re-joined."""
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  COUNT(*)"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  forgotten = 0"
-            )
-            txn.execute(sql, (user_id, room_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-        count = yield self.runInteraction("did_forget_membership", f)
-        defer.returnValue(count == 0)
-
-    @cachedInlineCallbacks(num_args=3)
-    def was_forgotten_at(self, user_id, room_id, event_id):
-        """Returns whether user_id has elected to discard history for room_id at
-        event_id.
-
-        event_id must be a membership event."""
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  forgotten"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  event_id = ?"
-            )
-            txn.execute(sql, (user_id, room_id, event_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-        forgot = yield self.runInteraction("did_forget_membership_at", f)
-        defer.returnValue(forgot == 1)
-
-    @cached()
-    def who_forgot_in_room(self, room_id):
-        return self._simple_select_list(
-            table="room_memberships",
-            retcols=("user_id", "event_id"),
-            keyvalues={
-                "room_id": room_id,
-                "forgotten": 1,
-            },
-            desc="who_forgot"
-        )
-
     def get_joined_users_from_context(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -600,6 +453,185 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(joined_hosts)
 
+    @cached(max_entries=10000, iterable=True)
+    def _get_joined_hosts_cache(self, room_id):
+        return _JoinedHostsCache(self, room_id)
+
+    @cached()
+    def who_forgot_in_room(self, room_id):
+        return self._simple_select_list(
+            table="room_memberships",
+            retcols=("user_id", "event_id"),
+            keyvalues={
+                "room_id": room_id,
+                "forgotten": 1,
+            },
+            desc="who_forgot"
+        )
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+        )
+
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store a room member in the database.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ]
+        )
+
+        for event in events:
+            txn.call_after(
+                self._membership_stream_cache.entity_has_changed,
+                event.state_key, event.internal_metadata.stream_ordering
+            )
+            txn.call_after(
+                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+            )
+
+            # We update the local_invites table only if the event is "current",
+            # i.e., its something that has just happened.
+            # The only current event that can also be an outlier is if its an
+            # invite that has come in across federation.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_invite_from_remote()
+            )
+            is_mine = self.hs.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self._simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        }
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
+
+                    txn.execute(sql, (
+                        event.internal_metadata.stream_ordering,
+                        event.event_id,
+                        event.room_id,
+                        event.state_key,
+                    ))
+
+    @defer.inlineCallbacks
+    def locally_reject_invite(self, user_id, room_id):
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+            " AND replaced_by is NULL"
+        )
+
+        def f(txn, stream_ordering):
+            txn.execute(sql, (
+                stream_ordering,
+                True,
+                room_id,
+                user_id,
+            ))
+
+        with self._stream_id_gen.get_next() as stream_ordering:
+            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
+    def forget(self, user_id, room_id):
+        """Indicate that user_id wishes to discard history for room_id."""
+        def f(txn):
+            sql = (
+                "UPDATE"
+                "  room_memberships"
+                " SET"
+                "  forgotten = 1"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+            )
+            txn.execute(sql, (user_id, room_id))
+
+            txn.call_after(self.was_forgotten_at.invalidate_all)
+            txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+            self._invalidate_cache_and_stream(
+                txn, self.who_forgot_in_room, (room_id,)
+            )
+        return self.runInteraction("forget_membership", f)
+
+    @cachedInlineCallbacks(num_args=2)
+    def did_forget(self, user_id, room_id):
+        """Returns whether user_id has elected to discard history for room_id.
+
+        Returns False if they have since re-joined."""
+        def f(txn):
+            sql = (
+                "SELECT"
+                "  COUNT(*)"
+                " FROM"
+                "  room_memberships"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+                " AND"
+                "  forgotten = 0"
+            )
+            txn.execute(sql, (user_id, room_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+        count = yield self.runInteraction("did_forget_membership", f)
+        defer.returnValue(count == 0)
+
+    @cachedInlineCallbacks(num_args=3)
+    def was_forgotten_at(self, user_id, room_id, event_id):
+        """Returns whether user_id has elected to discard history for room_id at
+        event_id.
+
+        event_id must be a membership event."""
+        def f(txn):
+            sql = (
+                "SELECT"
+                "  forgotten"
+                " FROM"
+                "  room_memberships"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+                " AND"
+                "  event_id = ?"
+            )
+            txn.execute(sql, (user_id, room_id, event_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+        forgot = yield self.runInteraction("did_forget_membership_at", f)
+        defer.returnValue(forgot == 1)
+
     @defer.inlineCallbacks
     def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(
@@ -675,10 +707,6 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(result)
 
-    @cached(max_entries=10000, iterable=True)
-    def _get_joined_hosts_cache(self, room_id):
-        return _JoinedHostsCache(self, room_id)
-
 
 class _JoinedHostsCache(object):
     """Cache for joined hosts in a room that is optimised to handle updates
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 8755bb2e49..4d725b92fe 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 
+import simplejson as json
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index 4269ac69ad..e7351c3ae6 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -17,7 +17,7 @@ import logging
 from synapse.storage.prepare_database import get_statements
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
@@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index 71b12a2731..6df57b5206 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -16,7 +16,7 @@ import logging
 
 from synapse.storage.prepare_database import get_statements
 
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
index 470ae0c005..fe6b7d196d 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -16,7 +16,7 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.prepare_database import get_statements
 
 import logging
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
@@ -49,7 +49,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "rows_inserted": 0,
             "have_added_indexes": False,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
index 83066cccc9..1e002f9db2 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -15,7 +15,7 @@
 from synapse.storage.prepare_database import get_statements
 
 import logging
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
@@ -44,7 +44,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 2755acff40..984643b057 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -16,7 +16,7 @@
 from collections import namedtuple
 import logging
 import re
-import ujson as json
+import simplejson as json
 
 from twisted.internet import defer
 
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 67d5d9969a..9e6eaaa532 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -22,12 +22,12 @@ from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.util.caches.descriptors import cached, cachedList
 
 
-class SignatureStore(SQLBaseStore):
-    """Persistence for event signatures and hashes"""
-
+class SignatureWorkerStore(SQLBaseStore):
     @cached()
     def get_event_reference_hash(self, event_id):
-        return self._get_event_reference_hashes_txn(event_id)
+        # This is a dummy function to allow get_event_reference_hashes
+        # to use its cache
+        raise NotImplementedError()
 
     @cachedList(cached_method_name="get_event_reference_hash",
                 list_name="event_ids", num_args=1)
@@ -74,6 +74,10 @@ class SignatureStore(SQLBaseStore):
         txn.execute(query, (event_id, ))
         return {k: v for k, v in txn}
 
+
+class SignatureStore(SignatureWorkerStore):
+    """Persistence for event signatures and hashes"""
+
     def _store_event_reference_hashes_txn(self, txn, events):
         """Store a hash for a PDU
         Args:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2b325e1c1f..ffa4246031 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -240,6 +240,9 @@ class StateGroupWorkerStore(SQLBaseStore):
                     (
                         "AND type = ? AND state_key = ?",
                         (etype, state_key)
+                    ) if state_key is not None else (
+                        "AND type = ?",
+                        (etype,)
                     )
                     for etype, state_key in types
                 ]
@@ -259,10 +262,19 @@ class StateGroupWorkerStore(SQLBaseStore):
                         key = (typ, state_key)
                         results[group][key] = event_id
         else:
+            where_args = []
+            where_clauses = []
+            wildcard_types = False
             if types is not None:
-                where_clause = "AND (%s)" % (
-                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-                )
+                for typ in types:
+                    if typ[1] is None:
+                        where_clauses.append("(type = ?)")
+                        where_args.extend(typ[0])
+                        wildcard_types = True
+                    else:
+                        where_clauses.append("(type = ? AND state_key = ?)")
+                        where_args.extend([typ[0], typ[1]])
+                where_clause = "AND (%s)" % (" OR ".join(where_clauses))
             else:
                 where_clause = ""
 
@@ -279,7 +291,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                     # after we finish deduping state, which requires this func)
                     args = [next_group]
                     if types:
-                        args.extend(i for typ in types for i in typ)
+                        args.extend(where_args)
 
                     txn.execute(
                         "SELECT type, state_key, event_id FROM state_groups_state"
@@ -292,9 +304,17 @@ class StateGroupWorkerStore(SQLBaseStore):
                         if (typ, state_key) not in results[group]
                     )
 
-                    # If the lengths match then we must have all the types,
-                    # so no need to go walk further down the tree.
-                    if types is not None and len(results[group]) == len(types):
+                    # If the number of entries in the (type,state_key)->event_id dict
+                    # matches the number of (type,state_keys) types we were searching
+                    # for, then we must have found them all, so no need to go walk
+                    # further down the tree... UNLESS our types filter contained
+                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
+                    # search
+                    if (
+                        types is not None and
+                        not wildcard_types and
+                        len(results[group]) == len(types)
+                    ):
                         break
 
                     next_group = self._simple_select_one_onecol_txn(
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 52bdce5be2..2956c3b3e0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -35,13 +35,16 @@ what sort order was used:
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+
 from synapse.util.caches.descriptors import cached
-from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
+import abc
 import logging
 
 
@@ -143,81 +146,41 @@ def filter_to_clause(event_filter):
     return " AND ".join(clauses), args
 
 
-class StreamStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
-        # NB this lives here instead of appservice.py so we can reuse the
-        # 'private' StreamToken class in this file.
-        if limit:
-            limit = max(limit, MAX_STREAM_SIZE)
-        else:
-            limit = MAX_STREAM_SIZE
-
-        # From and to keys should be integers from ordering.
-        from_id = RoomStreamToken.parse_stream_token(from_key)
-        to_id = RoomStreamToken.parse_stream_token(to_key)
-
-        if from_key == to_key:
-            defer.returnValue(([], to_key))
-            return
-
-        # select all the events between from/to with a sensible limit
-        sql = (
-            "SELECT e.event_id, e.room_id, e.type, s.state_key, "
-            "e.stream_ordering FROM events AS e "
-            "LEFT JOIN state_events as s ON "
-            "e.event_id = s.event_id "
-            "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
-            "ORDER BY stream_ordering ASC LIMIT %(limit)d "
-        ) % {
-            "limit": limit
-        }
-
-        def f(txn):
-            # pull out all the events between the tokens
-            txn.execute(sql, (from_id.stream, to_id.stream,))
-            rows = self.cursor_to_dict(txn)
-
-            # Logic:
-            #  - We want ALL events which match the AS room_id regex
-            #  - We want ALL events which match the rooms represented by the AS
-            #    room_alias regex
-            #  - We want ALL events for rooms that AS users have joined.
-            # This is currently supported via get_app_service_rooms (which is
-            # used for the Notifier listener rooms). We can't reasonably make a
-            # SQL query for these room IDs, so we'll pull all the events between
-            # from/to and filter in python.
-            rooms_for_as = self._get_app_service_rooms_txn(txn, service)
-            room_ids_for_as = [r.room_id for r in rooms_for_as]
-
-            def app_service_interested(row):
-                if row["room_id"] in room_ids_for_as:
-                    return True
-
-                if row["type"] == EventTypes.Member:
-                    if service.is_interested_in_user(row.get("state_key")):
-                        return True
-                return False
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
+    which can be called in the initializer.
+    """
 
-            return [r for r in rows if app_service_interested(r)]
+    __metaclass__ = abc.ABCMeta
 
-        rows = yield self.runInteraction("get_appservice_room_stream", f)
+    def __init__(self, db_conn, hs):
+        super(StreamWorkerStore, self).__init__(db_conn, hs)
 
-        ret = yield self._get_events(
-            [r["event_id"] for r in rows],
-            get_prev_content=True
+        events_max = self.get_room_max_stream_ordering()
+        event_cache_prefill, min_event_val = self._get_cache_dict(
+            db_conn, "events",
+            entity_column="room_id",
+            stream_column="stream_ordering",
+            max_value=events_max,
+        )
+        self._events_stream_cache = StreamChangeCache(
+            "EventsRoomStreamChangeCache", min_event_val,
+            prefilled_cache=event_cache_prefill,
+        )
+        self._membership_stream_cache = StreamChangeCache(
+            "MembershipStreamChangeCache", events_max,
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._stream_order_on_start = self.get_room_max_stream_ordering()
 
-        if rows:
-            key = "s%d" % max(r["stream_ordering"] for r in rows)
-        else:
-            # Assume we didn't get anything because there was nothing to
-            # get.
-            key = to_key
+    @abc.abstractmethod
+    def get_room_max_stream_ordering(self):
+        raise NotImplementedError()
 
-        defer.returnValue((ret, key))
+    @abc.abstractmethod
+    def get_room_min_stream_ordering(self):
+        raise NotImplementedError()
 
     @defer.inlineCallbacks
     def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
@@ -381,88 +344,6 @@ class StreamStore(SQLBaseStore):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def paginate_room_events(self, room_id, from_key, to_key=None,
-                             direction='b', limit=-1, event_filter=None):
-        # Tokens really represent positions between elements, but we use
-        # the convention of pointing to the event before the gap. Hence
-        # we have a bit of asymmetry when it comes to equalities.
-        args = [False, room_id]
-        if direction == 'b':
-            order = "DESC"
-            bounds = upper_bound(
-                RoomStreamToken.parse(from_key), self.database_engine
-            )
-            if to_key:
-                bounds = "%s AND %s" % (bounds, lower_bound(
-                    RoomStreamToken.parse(to_key), self.database_engine
-                ))
-        else:
-            order = "ASC"
-            bounds = lower_bound(
-                RoomStreamToken.parse(from_key), self.database_engine
-            )
-            if to_key:
-                bounds = "%s AND %s" % (bounds, upper_bound(
-                    RoomStreamToken.parse(to_key), self.database_engine
-                ))
-
-        filter_clause, filter_args = filter_to_clause(event_filter)
-
-        if filter_clause:
-            bounds += " AND " + filter_clause
-            args.extend(filter_args)
-
-        if int(limit) > 0:
-            args.append(int(limit))
-            limit_str = " LIMIT ?"
-        else:
-            limit_str = ""
-
-        sql = (
-            "SELECT * FROM events"
-            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
-            " ORDER BY topological_ordering %(order)s,"
-            " stream_ordering %(order)s %(limit)s"
-        ) % {
-            "bounds": bounds,
-            "order": order,
-            "limit": limit_str
-        }
-
-        def f(txn):
-            txn.execute(sql, args)
-
-            rows = self.cursor_to_dict(txn)
-
-            if rows:
-                topo = rows[-1]["topological_ordering"]
-                toke = rows[-1]["stream_ordering"]
-                if direction == 'b':
-                    # Tokens are positions between events.
-                    # This token points *after* the last event in the chunk.
-                    # We need it to point to the event before it in the chunk
-                    # when we are going backwards so we subtract one from the
-                    # stream part.
-                    toke -= 1
-                next_token = str(RoomStreamToken(topo, toke))
-            else:
-                # TODO (erikj): We should work out what to do here instead.
-                next_token = to_key if to_key else from_key
-
-            return rows, next_token,
-
-        rows, token = yield self.runInteraction("paginate_room_events", f)
-
-        events = yield self._get_events(
-            [r["event_id"] for r in rows],
-            get_prev_content=True
-        )
-
-        self._set_before_and_after(events, rows)
-
-        defer.returnValue((events, token))
-
-    @defer.inlineCallbacks
     def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
         rows, token = yield self.get_recent_event_ids_for_room(
             room_id, limit, end_token, from_token
@@ -534,6 +415,33 @@ class StreamStore(SQLBaseStore):
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )
 
+    def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
+        """Gets details of the first event in a room at or after a stream ordering
+
+        Args:
+            room_id (str):
+            stream_ordering (int):
+
+        Returns:
+            Deferred[(int, int, str)]:
+                (stream ordering, topological ordering, event_id)
+        """
+        def _f(txn):
+            sql = (
+                "SELECT stream_ordering, topological_ordering, event_id"
+                " FROM events"
+                " WHERE room_id = ? AND stream_ordering >= ?"
+                " AND NOT outlier"
+                " ORDER BY stream_ordering"
+                " LIMIT 1"
+            )
+            txn.execute(sql, (room_id, stream_ordering, ))
+            return txn.fetchone()
+
+        return self.runInteraction(
+            "get_room_event_after_stream_ordering", _f,
+        )
+
     @defer.inlineCallbacks
     def get_room_events_max_id(self, room_id=None):
         """Returns the current token for rooms stream.
@@ -542,7 +450,7 @@ class StreamStore(SQLBaseStore):
         `room_id` causes it to return the current room specific topological
         token.
         """
-        token = yield self._stream_id_gen.get_current_token()
+        token = yield self.get_room_max_stream_ordering()
         if room_id is None:
             defer.returnValue("s%d" % (token,))
         else:
@@ -552,12 +460,6 @@ class StreamStore(SQLBaseStore):
             )
             defer.returnValue("t%d-%d" % (topo, token))
 
-    def get_room_max_stream_ordering(self):
-        return self._stream_id_gen.get_current_token()
-
-    def get_room_min_stream_ordering(self):
-        return self._backfill_id_gen.get_current_token()
-
     def get_stream_token_for_event(self, event_id):
         """The stream token for an event
         Args:
@@ -832,3 +734,93 @@ class StreamStore(SQLBaseStore):
 
     def has_room_changed_since(self, room_id, stream_id):
         return self._events_stream_cache.has_entity_changed(room_id, stream_id)
+
+
+class StreamStore(StreamWorkerStore):
+    def get_room_max_stream_ordering(self):
+        return self._stream_id_gen.get_current_token()
+
+    def get_room_min_stream_ordering(self):
+        return self._backfill_id_gen.get_current_token()
+
+    @defer.inlineCallbacks
+    def paginate_room_events(self, room_id, from_key, to_key=None,
+                             direction='b', limit=-1, event_filter=None):
+        # Tokens really represent positions between elements, but we use
+        # the convention of pointing to the event before the gap. Hence
+        # we have a bit of asymmetry when it comes to equalities.
+        args = [False, room_id]
+        if direction == 'b':
+            order = "DESC"
+            bounds = upper_bound(
+                RoomStreamToken.parse(from_key), self.database_engine
+            )
+            if to_key:
+                bounds = "%s AND %s" % (bounds, lower_bound(
+                    RoomStreamToken.parse(to_key), self.database_engine
+                ))
+        else:
+            order = "ASC"
+            bounds = lower_bound(
+                RoomStreamToken.parse(from_key), self.database_engine
+            )
+            if to_key:
+                bounds = "%s AND %s" % (bounds, upper_bound(
+                    RoomStreamToken.parse(to_key), self.database_engine
+                ))
+
+        filter_clause, filter_args = filter_to_clause(event_filter)
+
+        if filter_clause:
+            bounds += " AND " + filter_clause
+            args.extend(filter_args)
+
+        if int(limit) > 0:
+            args.append(int(limit))
+            limit_str = " LIMIT ?"
+        else:
+            limit_str = ""
+
+        sql = (
+            "SELECT * FROM events"
+            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
+            " ORDER BY topological_ordering %(order)s,"
+            " stream_ordering %(order)s %(limit)s"
+        ) % {
+            "bounds": bounds,
+            "order": order,
+            "limit": limit_str
+        }
+
+        def f(txn):
+            txn.execute(sql, args)
+
+            rows = self.cursor_to_dict(txn)
+
+            if rows:
+                topo = rows[-1]["topological_ordering"]
+                toke = rows[-1]["stream_ordering"]
+                if direction == 'b':
+                    # Tokens are positions between events.
+                    # This token points *after* the last event in the chunk.
+                    # We need it to point to the event before it in the chunk
+                    # when we are going backwards so we subtract one from the
+                    # stream part.
+                    toke -= 1
+                next_token = str(RoomStreamToken(topo, toke))
+            else:
+                # TODO (erikj): We should work out what to do here instead.
+                next_token = to_key if to_key else from_key
+
+            return rows, next_token,
+
+        rows, token = yield self.runInteraction("paginate_room_events", f)
+
+        events = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
+
+        self._set_before_and_after(events, rows)
+
+        defer.returnValue((events, token))
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index bff73f3f04..13bff9f055 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,25 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from synapse.storage.account_data import AccountDataWorkerStore
+
 from synapse.util.caches.descriptors import cached
 from twisted.internet import defer
 
-import ujson as json
+import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class TagsStore(SQLBaseStore):
-    def get_max_account_data_stream_id(self):
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            A deferred int.
-        """
-        return self._account_data_id_gen.get_current_token()
-
+class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
     def get_tags_for_user(self, user_id):
         """Get all the tags for a user.
@@ -170,6 +164,8 @@ class TagsStore(SQLBaseStore):
             row["tag"]: json.loads(row["content"]) for row in rows
         })
 
+
+class TagsStore(TagsWorkerStore):
     @defer.inlineCallbacks
     def add_tag_to_room(self, user_id, room_id, tag, content):
         """Add a tag to a room for a user.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 8f61f7ffae..f825264ea9 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -23,7 +23,7 @@ from canonicaljson import encode_canonical_json
 from collections import namedtuple
 
 import logging
-import ujson as json
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index dfdcbb3181..d6e289ffbe 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -667,7 +667,7 @@ class UserDirectoryStore(SQLBaseStore):
             # The array of numbers are the weights for the various part of the
             # search: (domain, _, display name, localpart)
             sql = """
-                SELECT d.user_id, display_name, avatar_url
+                SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
                 %s
@@ -702,7 +702,7 @@ class UserDirectoryStore(SQLBaseStore):
             search_query = _parse_query_sqlite(search_term)
 
             sql = """
-                SELECT d.user_id, display_name, avatar_url
+                SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
                 %s