author     Michael Telatynski <7t3chguy@gmail.com>  2018-07-24 17:17:46 +0100
committer  Michael Telatynski <7t3chguy@gmail.com>  2018-07-24 17:17:46 +0100
commit     87951d3891efb5bccedf72c12b3da0d6ab482253
tree       de7d997567c66c5a4d8743c1f3b9d6b474f5cfd9 /synapse/storage
parent     if inviter_display_name == ""||None then default to inviter MXID
parent     Merge pull request #3595 from matrix-org/erikj/use_deltas
download   synapse-87951d3891efb5bccedf72c12b3da0d6ab482253.tar.xz

Merge branch 'develop' of github.com:matrix-org/synapse into t3chguy/default_inviter_display_name_3pid
Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/__init__.py | 291
-rw-r--r--  synapse/storage/_base.py | 334
-rw-r--r--  synapse/storage/account_data.py | 212
-rw-r--r--  synapse/storage/appservice.py | 112
-rw-r--r--  synapse/storage/background_updates.py | 82
-rw-r--r--  synapse/storage/client_ips.py | 56
-rw-r--r--  synapse/storage/deviceinbox.py | 20
-rw-r--r--  synapse/storage/devices.py | 74
-rw-r--r--  synapse/storage/directory.py | 59
-rw-r--r--  synapse/storage/end_to_end_keys.py | 38
-rw-r--r--  synapse/storage/engines/__init__.py | 10
-rw-r--r--  synapse/storage/engines/postgres.py | 6
-rw-r--r--  synapse/storage/engines/sqlite3.py | 23
-rw-r--r--  synapse/storage/event_federation.py | 333
-rw-r--r--  synapse/storage/event_push_actions.py | 486
-rw-r--r--  synapse/storage/events.py | 1240
-rw-r--r--  synapse/storage/events_worker.py | 436
-rw-r--r--  synapse/storage/filtering.py | 10
-rw-r--r--  synapse/storage/group_server.py | 1252
-rw-r--r--  synapse/storage/keys.py | 69
-rw-r--r--  synapse/storage/media_repository.py | 117
-rw-r--r--  synapse/storage/prepare_database.py | 76
-rw-r--r--  synapse/storage/presence.py | 15
-rw-r--r--  synapse/storage/profile.py | 155
-rw-r--r--  synapse/storage/push_rule.py | 113
-rw-r--r--  synapse/storage/pusher.py | 105
-rw-r--r--  synapse/storage/receipts.py | 223
-rw-r--r--  synapse/storage/registration.py | 311
-rw-r--r--  synapse/storage/rejections.py | 4
-rw-r--r--  synapse/storage/room.py | 475
-rw-r--r--  synapse/storage/roommember.py | 427
-rw-r--r--  synapse/storage/schema/delta/14/upgrade_appservice_db.py | 3
-rw-r--r--  synapse/storage/schema/delta/25/fts.py | 8
-rw-r--r--  synapse/storage/schema/delta/27/ts.py | 6
-rw-r--r--  synapse/storage/schema/delta/30/as_users.py | 8
-rw-r--r--  synapse/storage/schema/delta/31/search_update.py | 9
-rw-r--r--  synapse/storage/schema/delta/33/event_fields.py | 9
-rw-r--r--  synapse/storage/schema/delta/33/remote_media_ts.py | 1
-rw-r--r--  synapse/storage/schema/delta/34/cache_stream.py | 6
-rw-r--r--  synapse/storage/schema/delta/34/received_txn_purge.py | 4
-rw-r--r--  synapse/storage/schema/delta/34/sent_txn_purge.py | 4
-rw-r--r--  synapse/storage/schema/delta/37/remove_auth_idx.py | 6
-rw-r--r--  synapse/storage/schema/delta/38/postgres_fts_gist.sql | 6
-rw-r--r--  synapse/storage/schema/delta/42/user_dir.py | 2
-rw-r--r--  synapse/storage/schema/delta/43/user_share.sql | 2
-rw-r--r--  synapse/storage/schema/delta/44/expire_url_cache.sql | 41
-rw-r--r--  synapse/storage/schema/delta/45/group_server.sql | 167
-rw-r--r--  synapse/storage/schema/delta/45/profile_cache.sql | 28
-rw-r--r--  synapse/storage/schema/delta/46/drop_refresh_tokens.sql (renamed from synapse/storage/schema/delta/33/refreshtoken_device_index.sql) | 6
-rw-r--r--  synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql | 35
-rw-r--r--  synapse/storage/schema/delta/46/group_server.sql | 32
-rw-r--r--  synapse/storage/schema/delta/46/local_media_repository_url_idx.sql | 24
-rw-r--r--  synapse/storage/schema/delta/46/user_dir_null_room_ids.sql | 35
-rw-r--r--  synapse/storage/schema/delta/46/user_dir_typos.sql | 24
-rw-r--r--  synapse/storage/schema/delta/47/last_access_media.sql (renamed from synapse/storage/schema/delta/33/refreshtoken_device.sql) | 4
-rw-r--r--  synapse/storage/schema/delta/47/postgres_fts_gin.sql (renamed from synapse/storage/schema/delta/23/refresh_tokens.sql) | 10
-rw-r--r--  synapse/storage/schema/delta/47/push_actions_staging.sql | 28
-rw-r--r--  synapse/storage/schema/delta/47/state_group_seq.py | 37
-rw-r--r--  synapse/storage/schema/delta/48/add_user_consent.sql | 18
-rw-r--r--  synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql | 17
-rw-r--r--  synapse/storage/schema/delta/48/deactivated_users.sql | 25
-rw-r--r--  synapse/storage/schema/delta/48/group_unique_indexes.py | 57
-rw-r--r--  synapse/storage/schema/delta/48/groups_joinable.sql | 22
-rw-r--r--  synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql | 20
-rw-r--r--  synapse/storage/schema/delta/49/add_user_daily_visits.sql | 21
-rw-r--r--  synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql | 17
-rw-r--r--  synapse/storage/schema/delta/50/add_creation_ts_users_index.sql | 19
-rw-r--r--  synapse/storage/schema/delta/50/erasure_store.sql | 21
-rw-r--r--  synapse/storage/schema/schema_version.sql | 7
-rw-r--r--  synapse/storage/search.py | 214
-rw-r--r--  synapse/storage/signatures.py | 30
-rw-r--r--  synapse/storage/state.py | 597
-rw-r--r--  synapse/storage/stream.py | 683
-rw-r--r--  synapse/storage/tags.py | 27
-rw-r--r--  synapse/storage/transactions.py | 26
-rw-r--r--  synapse/storage/user_directory.py | 99
-rw-r--r--  synapse/storage/user_erasure_store.py | 103
-rw-r--r--  synapse/storage/util/id_generators.py | 2
78 files changed, 6793 insertions(+), 2941 deletions(-)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index b92472df33..ba88a54979 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,53 +14,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+import datetime
+import logging
+import time
+
+from dateutil import tz
 
+from synapse.api.constants import PresenceState
 from synapse.storage.devices import DeviceStore
-from .appservice import (
-    ApplicationServiceStore, ApplicationServiceTransactionStore
-)
-from ._base import LoggingTransaction
+from synapse.storage.user_erasure_store import UserErasureStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from .account_data import AccountDataStore
+from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .client_ips import ClientIpStore
+from .deviceinbox import DeviceInboxStore
 from .directory import DirectoryStore
+from .end_to_end_keys import EndToEndKeyStore
+from .engines import PostgresEngine
+from .event_federation import EventFederationStore
+from .event_push_actions import EventPushActionsStore
 from .events import EventsStore
+from .filtering import FilteringStore
+from .group_server import GroupServerStore
+from .keys import KeyStore
+from .media_repository import MediaRepositoryStore
+from .openid import OpenIdStore
 from .presence import PresenceStore, UserPresenceState
 from .profile import ProfileStore
+from .push_rule import PushRuleStore
+from .pusher import PusherStore
+from .receipts import ReceiptsStore
 from .registration import RegistrationStore
+from .rejections import RejectionsStore
 from .room import RoomStore
 from .roommember import RoomMemberStore
-from .stream import StreamStore
-from .transactions import TransactionStore
-from .keys import KeyStore
-from .event_federation import EventFederationStore
-from .pusher import PusherStore
-from .push_rule import PushRuleStore
-from .media_repository import MediaRepositoryStore
-from .rejections import RejectionsStore
-from .event_push_actions import EventPushActionsStore
-from .deviceinbox import DeviceInboxStore
-
-from .state import StateStore
-from .signatures import SignatureStore
-from .filtering import FilteringStore
-from .end_to_end_keys import EndToEndKeyStore
-
-from .receipts import ReceiptsStore
 from .search import SearchStore
+from .signatures import SignatureStore
+from .state import StateStore
+from .stream import StreamStore
 from .tags import TagsStore
-from .account_data import AccountDataStore
-from .openid import OpenIdStore
-from .client_ips import ClientIpStore
+from .transactions import TransactionStore
 from .user_directory import UserDirectoryStore
-
-from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
-from .engines import PostgresEngine
-
-from synapse.api.constants import PresenceState
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-
-import logging
-
+from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -88,6 +85,8 @@ class DataStore(RoomMemberStore, RoomStore,
                 DeviceStore,
                 DeviceInboxStore,
                 UserDirectoryStore,
+                GroupServerStore,
+                UserErasureStore,
                 ):
 
     def __init__(self, db_conn, hs):
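
The hunk above adds two more mixin bases to DataStore (GroupServerStore, UserErasureStore). The class is assembled purely by multiple inheritance: each store mixin owns a slice of the schema, and cooperative super().__init__() calls thread db_conn and hs through every class in the method resolution order, which is why this merge also changes the base signature from __init__(self, hs) to __init__(self, db_conn, hs). A minimal sketch of the pattern, with toy store names rather than Synapse's real ones:

    class BaseStore(object):
        def __init__(self, db_conn, hs):
            self.db_conn = db_conn
            self.hs = hs

    class FooStore(BaseStore):
        def __init__(self, db_conn, hs):
            self._foo_cache = {}  # each mixin does its own setup...
            super(FooStore, self).__init__(db_conn, hs)  # ...then delegates

    class BarStore(BaseStore):
        def __init__(self, db_conn, hs):
            self._bar_ids = []
            super(BarStore, self).__init__(db_conn, hs)

    class DataStore(FooStore, BarStore):
        def __init__(self, db_conn, hs):
            super(DataStore, self).__init__(db_conn, hs)

    # DataStore.__mro__ == (DataStore, FooStore, BarStore, BaseStore, object),
    # so every mixin's __init__ runs exactly once.
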
@@ -103,12 +102,6 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "events", "stream_ordering", step=-1,
             extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
         )
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn, "account_data_max_stream_id", "stream_id"
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -123,7 +116,6 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@@ -135,6 +127,9 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "pushers", "id",
             extra_tables=[("deleted_pushers", "stream_id")],
         )
+        self._group_updates_id_gen = StreamIdGenerator(
+            db_conn, "local_group_updates", "stream_id",
+        )
 
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = StreamIdGenerator(
@@ -143,27 +138,6 @@ class DataStore(RoomMemberStore, RoomStore,
         else:
             self._cache_id_gen = None
 
-        events_max = self._stream_id_gen.get_current_token()
-        event_cache_prefill, min_event_val = self._get_cache_dict(
-            db_conn, "events",
-            entity_column="room_id",
-            stream_column="stream_ordering",
-            max_value=events_max,
-        )
-        self._events_stream_cache = StreamChangeCache(
-            "EventsRoomStreamChangeCache", min_event_val,
-            prefilled_cache=event_cache_prefill,
-        )
-
-        self._membership_stream_cache = StreamChangeCache(
-            "MembershipStreamChangeCache", events_max,
-        )
-
-        account_max = self._account_data_id_gen.get_current_token()
-        self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache", account_max,
-        )
-
         self._presence_on_startup = self._get_active_presence(db_conn)
 
         presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -177,18 +151,6 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=presence_cache_prefill
         )
 
-        push_rules_prefill, push_rules_id = self._get_cache_dict(
-            db_conn, "push_rules_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
-        )
-
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache", push_rules_id,
-            prefilled_cache=push_rules_prefill,
-        )
-
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
         device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
             db_conn, "device_inbox",
@@ -223,6 +185,7 @@ class DataStore(RoomMemberStore, RoomStore,
             "DeviceListFederationStreamChangeCache", device_list_max,
         )
 
+        events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
             db_conn, "current_state_delta_stream",
             entity_column="room_id",
@@ -235,24 +198,25 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=curr_state_delta_prefill,
         )
 
-        cur = LoggingTransaction(
-            db_conn.cursor(),
-            name="_find_stream_orderings_for_times_txn",
-            database_engine=self.database_engine,
-            after_callbacks=[],
-            final_callbacks=[],
+        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+            db_conn, "local_group_updates",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._group_updates_id_gen.get_current_token(),
+            limit=1000,
         )
-        self._find_stream_orderings_for_times_txn(cur)
-        cur.close()
-
-        self.find_stream_orderings_looping_call = self._clock.looping_call(
-            self._find_stream_orderings_for_times, 10 * 60 * 1000
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache", min_group_updates_id,
+            prefilled_cache=_group_updates_prefill,
         )
 
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
-        super(DataStore, self).__init__(hs)
+        # Used in _generate_user_daily_visits to keep track of progress
+        self._last_user_visit_update = self._get_start_of_day()
+
+        super(DataStore, self).__init__(db_conn, hs)
 
     def take_presence_startup_info(self):
         active_on_startup = self._presence_on_startup
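
The group-updates hunk above repeats a pattern used throughout this initializer: _get_cache_dict() pulls the most recent (entity, stream_id) pairs out of a table, and the result prefills a StreamChangeCache so that "has anything for X changed since token T?" can usually be answered without touching the database. A toy model of the cache's contract (deliberately simplified, not the real implementation):

    class ToyStreamChangeCache(object):
        """Toy model: maps entity -> last stream position it changed at."""

        def __init__(self, name, min_known, prefilled_cache=None):
            self._min_known = min_known  # below this token we know nothing
            self._changes = dict(prefilled_cache or {})

        def entity_has_changed(self, entity, stream_pos):
            prev = self._changes.get(entity, stream_pos)
            self._changes[entity] = max(prev, stream_pos)

        def has_entity_changed(self, entity, since_token):
            if since_token < self._min_known:
                return True  # cannot prove it didn't change; be conservative
            return self._changes.get(entity, self._min_known) > since_token
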
@@ -281,13 +245,12 @@ class DataStore(RoomMemberStore, RoomStore,
 
         return [UserPresenceState(**row) for row in rows]
 
-    @defer.inlineCallbacks
     def count_daily_users(self):
         """
         Counts the number of users who used this homeserver in the last 24 hours.
         """
         def _count_users(txn):
-            yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),
+            yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
 
             sql = """
                 SELECT COALESCE(count(*), 0) FROM (
@@ -301,8 +264,154 @@ class DataStore(RoomMemberStore, RoomStore,
             count, = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_users", _count_users)
-        defer.returnValue(ret)
+        return self.runInteraction("count_users", _count_users)
+
+    def count_r30_users(self):
+        """
+        Counts the number of 30 day retained users, defined as:-
+         * Users who have created their accounts more than 30 days ago
+         * Where last seen at most 30 days ago
+         * Where account creation and last_seen are > 30 days apart
+
+         Returns counts globaly for a given user as well as breaking
+         by platform
+        """
+        def _count_r30_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+            sql = """
+                SELECT platform, COALESCE(count(*), 0) FROM (
+                     SELECT
+                        users.name, platform, users.creation_ts * 1000,
+                        MAX(uip.last_seen)
+                     FROM users
+                     INNER JOIN (
+                         SELECT
+                         user_id,
+                         last_seen,
+                         CASE
+                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
+                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+                             ELSE 'unknown'
+                         END
+                         AS platform
+                         FROM user_ips
+                     ) uip
+                     ON users.name = uip.user_id
+                     AND users.appservice_id is NULL
+                     AND users.creation_ts < ?
+                     AND uip.last_seen/1000 > ?
+                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                     GROUP BY users.name, platform, users.creation_ts
+                ) u GROUP BY platform
+            """
+
+            results = {}
+            txn.execute(sql, (thirty_days_ago_in_secs,
+                              thirty_days_ago_in_secs))
+
+            for row in txn:
+                if row[0] is 'unknown':
+                    pass
+                results[row[0]] = row[1]
+
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT users.name, users.creation_ts * 1000,
+                                                        MAX(uip.last_seen)
+                    FROM users
+                    INNER JOIN (
+                        SELECT
+                        user_id,
+                        last_seen
+                        FROM user_ips
+                    ) uip
+                    ON users.name = uip.user_id
+                    AND appservice_id is NULL
+                    AND users.creation_ts < ?
+                    AND uip.last_seen/1000 > ?
+                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                    GROUP BY users.name, users.creation_ts
+                ) u
+            """
+
+            txn.execute(sql, (thirty_days_ago_in_secs,
+                              thirty_days_ago_in_secs))
+
+            count, = txn.fetchone()
+            results['all'] = count
+
+            return results
+
+        return self.runInteraction("count_r30_users", _count_r30_users)
+
+    def _get_start_of_day(self):
+        """
+        Returns millisecond unixtime for start of UTC day.
+        """
+        now = datetime.datetime.utcnow()
+        today_start = datetime.datetime(now.year, now.month,
+                                        now.day, tzinfo=tz.tzutc())
+        return int(time.mktime(today_start.timetuple())) * 1000
+
+    def generate_user_daily_visits(self):
+        """
+        Generates daily visit data for use in cohort/ retention analysis
+        """
+        def _generate_user_daily_visits(txn):
+            logger.info("Calling _generate_user_daily_visits")
+            today_start = self._get_start_of_day()
+            a_day_in_milliseconds = 24 * 60 * 60 * 1000
+            now = self.clock.time_msec()
+
+            sql = """
+                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+                    SELECT u.user_id, u.device_id, ?
+                    FROM user_ips AS u
+                    LEFT JOIN (
+                      SELECT user_id, device_id, timestamp FROM user_daily_visits
+                      WHERE timestamp = ?
+                    ) udv
+                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+                    INNER JOIN users ON users.name=u.user_id
+                    WHERE last_seen > ? AND last_seen <= ?
+                    AND udv.timestamp IS NULL AND users.is_guest=0
+                    AND users.appservice_id IS NULL
+                    GROUP BY u.user_id, u.device_id
+            """
+
+            # This means that the day has rolled over but there could still
+            # be entries from the previous day. There is an edge case
+            # where if the user logs in at 23:59 and overwrites their
+            # last_seen at 00:01 then they will not be counted in the
+            # previous day's stats - it is important that the query is run
+            # often to minimise this case.
+            if today_start > self._last_user_visit_update:
+                yesterday_start = today_start - a_day_in_milliseconds
+                txn.execute(sql, (
+                    yesterday_start, yesterday_start,
+                    self._last_user_visit_update, today_start
+                ))
+                self._last_user_visit_update = today_start
+
+            txn.execute(sql, (
+                today_start, today_start,
+                self._last_user_visit_update,
+                now
+            ))
+            # Update _last_user_visit_update to now. The reason to do this
+            # rather just clamping to the beginning of the day is to limit
+            # the size of the join - meaning that the query can be run more
+            # frequently
+            self._last_user_visit_update = now
+
+        return self.runInteraction("generate_user_daily_visits",
+                                   _generate_user_daily_visits)
 
     def get_users(self):
         """Function to reterive a list of users in users table.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6f54036d67..1d41d8d445 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -13,36 +13,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import sys
+import threading
+import time
 
-from synapse.api.errors import StoreError
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches import CACHE_SIZE_FACTOR
-from synapse.util.caches.dictionary_cache import DictionaryCache
-from synapse.util.caches.descriptors import Cache
-from synapse.storage.engines import PostgresEngine
-import synapse.metrics
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
 
+from prometheus_client import Histogram
 
 from twisted.internet import defer
 
-import sys
-import time
-import threading
-
+from synapse.api.errors import StoreError
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import Cache
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 
 logger = logging.getLogger(__name__)
 
+try:
+    MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+    # python 3 does not have a maximum int value
+    MAX_TXN_ID = 2**63 - 1
+
 sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
 perf_logger = logging.getLogger("synapse.storage.TIME")
 
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
 
-metrics = synapse.metrics.get_metrics_for("synapse.storage")
-
-sql_scheduling_timer = metrics.register_distribution("schedule_time")
-
-sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
-sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
 
 
 class LoggingTransaction(object):
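
This file swaps Synapse's home-grown distribution metrics for prometheus_client Histograms, and with them moves from millisecond to second timings. For reference, the upstream prometheus_client API used here works like this (the observed values are illustrative):

    from prometheus_client import Histogram

    sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
    sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])

    # an unlabelled histogram is observed directly...
    sql_scheduling_timer.observe(0.004)

    # ...a labelled one must be resolved with .labels() first
    sql_query_timer.labels("SELECT").observe(0.012)
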
@@ -50,16 +52,16 @@ class LoggingTransaction(object):
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
     __slots__ = [
-        "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+        "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
     ]
 
     def __init__(self, txn, name, database_engine, after_callbacks,
-                 final_callbacks):
+                 exception_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
         object.__setattr__(self, "after_callbacks", after_callbacks)
-        object.__setattr__(self, "final_callbacks", final_callbacks)
+        object.__setattr__(self, "exception_callbacks", exception_callbacks)
 
     def call_after(self, callback, *args, **kwargs):
         """Call the given callback on the main twisted thread after the
@@ -68,8 +70,8 @@ class LoggingTransaction(object):
         """
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_finally(self, callback, *args, **kwargs):
-        self.final_callbacks.append((callback, args, kwargs))
+    def call_on_exception(self, callback, *args, **kwargs):
+        self.exception_callbacks.append((callback, args, kwargs))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -103,11 +105,11 @@ class LoggingTransaction(object):
                     "[SQL values] {%s} %r",
                     self.name, args[0]
                 )
-            except:
+            except Exception:
                 # Don't let logging failures stop SQL from working
                 pass
 
-        start = time.time() * 1000
+        start = time.time()
 
         try:
             return func(
@@ -117,9 +119,9 @@ class LoggingTransaction(object):
             logger.debug("[SQL FAIL] {%s} %s", self.name, e)
             raise
         finally:
-            msecs = (time.time() * 1000) - start
-            sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
-            sql_query_timer.inc_by(msecs, sql.split()[0])
+            secs = time.time() - start
+            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+            sql_query_timer.labels(sql.split()[0]).observe(secs)
 
 
 class PerformanceCounters(object):
@@ -129,7 +131,7 @@ class PerformanceCounters(object):
 
     def update(self, key, start_time, end_time=None):
         if end_time is None:
-            end_time = time.time() * 1000
+            end_time = time.time()
         duration = end_time - start_time
         count, cum_time = self.current_counters.get(key, (0, 0))
         count += 1
@@ -139,7 +141,7 @@ class PerformanceCounters(object):
 
     def interval(self, interval_duration, limit=3):
         counters = []
-        for name, (count, cum_time) in self.current_counters.iteritems():
+        for name, (count, cum_time) in iteritems(self.current_counters):
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
             counters.append((
                 (cum_time - prev_time) / interval_duration,
@@ -162,7 +164,7 @@ class PerformanceCounters(object):
 class SQLBaseStore(object):
     _TXN_ID = 0
 
-    def __init__(self, hs):
+    def __init__(self, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
         self._db_pool = hs.get_db_pool()
@@ -180,10 +182,6 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3,
                                       max_entries=hs.config.event_cache_size)
 
-        self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
-        )
-
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
@@ -221,14 +219,14 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
-    def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
-                         logging_context, func, *args, **kwargs):
-        start = time.time() * 1000
+    def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
+                         func, *args, **kwargs):
+        start = time.time()
         txn_id = self._TXN_ID
 
         # We don't really need these to be unique, so lets stop it from
         # growing really large.
-        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
 
         name = "%s-%x" % (desc, txn_id, )
 
@@ -242,7 +240,7 @@ class SQLBaseStore(object):
                     txn = conn.cursor()
                     txn = LoggingTransaction(
                         txn, name, self.database_engine, after_callbacks,
-                        final_callbacks,
+                        exception_callbacks,
                     )
                     r = func(txn, *args, **kwargs)
                     conn.commit()
@@ -283,73 +281,85 @@ class SQLBaseStore(object):
             logger.debug("[TXN FAIL] {%s} %s", name, e)
             raise
         finally:
-            end = time.time() * 1000
+            end = time.time()
             duration = end - start
 
-            if logging_context is not None:
-                logging_context.add_database_transaction(duration)
+            LoggingContext.current_context().add_database_transaction(duration)
 
-            transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
 
             self._current_txn_total_time += duration
             self._txn_perf_counters.update(desc, start, end)
-            sql_txn_timer.inc_by(duration, desc)
+            sql_txn_timer.labels(desc).observe(duration)
 
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
-        current_context = LoggingContext.current_context()
-
-        start_time = time.time() * 1000
+        """Starts a transaction on the database and runs a given function
 
-        after_callbacks = []
-        final_callbacks = []
+        Arguments:
+            desc (str): description of the transaction, for logging and metrics
+            func (func): callback function, which will be called with a
+                database transaction (twisted.enterprise.adbapi.Transaction) as
+                its first argument, followed by `args` and `kwargs`.
 
-        def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runInteraction") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
 
-                if self.database_engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
-
-                current_context.copy_to(context)
-                return self._new_transaction(
-                    conn, desc, after_callbacks, final_callbacks, current_context,
-                    func, *args, **kwargs
-                )
+        Returns:
+            Deferred: The result of func
+        """
+        after_callbacks = []
+        exception_callbacks = []
 
         try:
-            with PreserveLoggingContext():
-                result = yield self._db_pool.runWithConnection(
-                    inner_func, *args, **kwargs
-                )
+            result = yield self.runWithConnection(
+                self._new_transaction,
+                desc, after_callbacks, exception_callbacks, func,
+                *args, **kwargs
+            )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
                 after_callback(*after_args, **after_kwargs)
-        finally:
-            for after_callback, after_args, after_kwargs in final_callbacks:
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            for after_callback, after_args, after_kwargs in exception_callbacks:
                 after_callback(*after_args, **after_kwargs)
+            raise
 
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def runWithConnection(self, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
-        current_context = LoggingContext.current_context()
+        """Wraps the .runWithConnection() method on the underlying db_pool.
 
-        start_time = time.time() * 1000
+        Arguments:
+            func (func): callback function, which will be called with a
+                database connection (twisted.enterprise.adbapi.Connection) as
+                its first argument, followed by `args` and `kwargs`.
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        parent_context = LoggingContext.current_context()
+        if parent_context == LoggingContext.sentinel:
+            logger.warn(
+                "Starting db connection from sentinel context: metrics will be lost",
+            )
+            parent_context = None
+
+        start_time = time.time()
 
         def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runWithConnection") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+            with LoggingContext("runWithConnection", parent_context) as context:
+                sched_duration_sec = time.time() - start_time
+                sql_scheduling_timer.observe(sched_duration_sec)
+                context.add_database_scheduled(sched_duration_sec)
 
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
-                current_context.copy_to(context)
-
                 return func(conn, *args, **kwargs)
 
         with PreserveLoggingContext():
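
The rewritten runInteraction delegates to runWithConnection and now documents its contract: func runs on a database thread with an adbapi Transaction as its first argument, and the caller gets the result back as a Deferred. A hypothetical store method written against that contract (the table and column are illustrative):

    from twisted.internet import defer

    class RoomNameStore(SQLBaseStore):  # hypothetical mixin
        @defer.inlineCallbacks
        def get_room_name(self, room_id):
            def _get_room_name_txn(txn):
                txn.execute("SELECT name FROM rooms WHERE room_id = ?", (room_id,))
                row = txn.fetchone()
                return row[0] if row else None

            name = yield self.runInteraction("get_room_name", _get_room_name_txn)
            defer.returnValue(name)
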
@@ -368,7 +378,7 @@ class SQLBaseStore(object):
         Returns:
             A list of dicts where the key is the column header.
         """
-        col_headers = list(intern(column[0]) for column in cursor.description)
+        col_headers = list(intern(str(column[0])) for column in cursor.description)
         results = list(
             dict(zip(col_headers, row)) for row in cursor
         )
@@ -475,23 +485,53 @@ class SQLBaseStore(object):
 
         txn.executemany(sql, vals)
 
+    @defer.inlineCallbacks
     def _simple_upsert(self, table, keyvalues, values,
                        insertion_values={}, desc="_simple_upsert", lock=True):
         """
+
+        `lock` should generally be set to True (the default), but can be set
+        to False if either of the following are true:
+
+        * there is a UNIQUE INDEX on the key columns. In this case a conflict
+          will cause an IntegrityError in which case this function will retry
+          the update.
+
+        * we somehow know that we are the only thread which will be updating
+          this table.
+
         Args:
             table (str): The table to upsert into
             keyvalues (dict): The unique key tables and their new values
             values (dict): The nonunique columns and their new values
-            insertion_values (dict): key/values to use when inserting
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
         Returns:
             Deferred(bool): True if a new entry was created, False if an
                 existing one was updated.
         """
-        return self.runInteraction(
-            desc,
-            self._simple_upsert_txn, table, keyvalues, values, insertion_values,
-            lock
-        )
+        attempts = 0
+        while True:
+            try:
+                result = yield self.runInteraction(
+                    desc,
+                    self._simple_upsert_txn, table, keyvalues, values, insertion_values,
+                    lock=lock
+                )
+                defer.returnValue(result)
+            except self.database_engine.module.IntegrityError as e:
+                attempts += 1
+                if attempts >= 5:
+                    # don't retry forever, because things other than races
+                    # can cause IntegrityErrors
+                    raise
+
+                # presumably we raced with another transaction: let's retry.
+                logger.warn(
+                    "IntegrityError when upserting into %s; retrying: %s",
+                    table, e
+                )
 
     def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
                            lock=True):
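
_simple_upsert now runs its UPDATE-then-INSERT dance inside a bounded retry loop: with a UNIQUE index on the key columns, two racing inserts surface as an IntegrityError and the interaction is simply retried, capped at five attempts because IntegrityErrors can also indicate genuine constraint bugs. For contrast (not what this code does), the target engines later make the race disappear entirely via native upserts, available in PostgreSQL 9.5+ and SQLite 3.24+; a sketch, reusing the room_account_data columns from elsewhere in this diff:

    # Native-upsert sketch for contrast; the code above predates relying on it.
    sql = (
        "INSERT INTO room_account_data"
        " (user_id, room_id, account_data_type, stream_id, content)"
        " VALUES (?, ?, ?, ?, ?)"
        " ON CONFLICT (user_id, room_id, account_data_type)"
        " DO UPDATE SET stream_id = EXCLUDED.stream_id,"
        "               content = EXCLUDED.content"
    )
    txn.execute(sql, (user_id, room_id, account_data_type, next_id, content_json))
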
@@ -499,37 +539,38 @@ class SQLBaseStore(object):
         if lock:
             self.database_engine.lock_table(txn, table)
 
-        # Try to update
+        # First try to update.
         sql = "UPDATE %s SET %s WHERE %s" % (
             table,
             ", ".join("%s = ?" % (k,) for k in values),
             " AND ".join("%s = ?" % (k,) for k in keyvalues)
         )
-        sqlargs = values.values() + keyvalues.values()
+        sqlargs = list(values.values()) + list(keyvalues.values())
 
         txn.execute(sql, sqlargs)
-        if txn.rowcount == 0:
-            # We didn't update and rows so insert a new one
-            allvalues = {}
-            allvalues.update(keyvalues)
-            allvalues.update(values)
-            allvalues.update(insertion_values)
+        if txn.rowcount > 0:
+            # successfully updated at least one row.
+            return False
 
-            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
-                table,
-                ", ".join(k for k in allvalues),
-                ", ".join("?" for _ in allvalues)
-            )
-            txn.execute(sql, allvalues.values())
+        # We didn't update any rows so insert a new one
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
 
-            return True
-        else:
-            return False
+        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues)
+        )
+        txn.execute(sql, list(allvalues.values()))
+        # successfully inserted
+        return True
 
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning a single column from it.
+        return a single row, returning multiple columns from it.
 
         Args:
             table : string giving the table name
@@ -582,20 +623,18 @@ class SQLBaseStore(object):
 
     @staticmethod
     def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
-        if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
-        else:
-            where = ""
-
         sql = (
-            "SELECT %(retcol)s FROM %(table)s %(where)s"
+            "SELECT %(retcol)s FROM %(table)s"
         ) % {
             "retcol": retcol,
             "table": table,
-            "where": where,
         }
 
-        txn.execute(sql, keyvalues.values())
+        if keyvalues:
+            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+            txn.execute(sql, list(keyvalues.values()))
+        else:
+            txn.execute(sql)
 
         return [r[0] for r in txn]
 
@@ -606,7 +645,7 @@ class SQLBaseStore(object):
 
         Args:
             table (str): table name
-            keyvalues (dict): column names and values to select the rows with
+            keyvalues (dict|None): column names and values to select the rows with
             retcol (str): column whos value we wish to retrieve.
 
         Returns:
@@ -657,7 +696,7 @@ class SQLBaseStore(object):
                 table,
                 " AND ".join("%s = ?" % (k, ) for k in keyvalues)
             )
-            txn.execute(sql, keyvalues.values())
+            txn.execute(sql, list(keyvalues.values()))
         else:
             sql = "SELECT %s FROM %s" % (
                 ", ".join(retcols),
@@ -688,9 +727,12 @@ class SQLBaseStore(object):
         if not iterable:
             defer.returnValue(results)
 
+        # iterables can not be sliced, so convert it to a list first
+        it_list = list(iterable)
+
         chunks = [
-            iterable[i:i + batch_size]
-            for i in xrange(0, len(iterable), batch_size)
+            it_list[i:i + batch_size]
+            for i in range(0, len(it_list), batch_size)
         ]
         for chunk in chunks:
             rows = yield self.runInteraction(
@@ -730,7 +772,7 @@ class SQLBaseStore(object):
         )
         values.extend(iterable)
 
-        for key, value in keyvalues.iteritems():
+        for key, value in iteritems(keyvalues):
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -743,6 +785,33 @@ class SQLBaseStore(object):
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
+    def _simple_update(self, table, keyvalues, updatevalues, desc):
+        return self.runInteraction(
+            desc,
+            self._simple_update_txn,
+            table, keyvalues, updatevalues,
+        )
+
+    @staticmethod
+    def _simple_update_txn(txn, table, keyvalues, updatevalues):
+        if keyvalues:
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+        else:
+            where = ""
+
+        update_sql = "UPDATE %s SET %s %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in updatevalues),
+            where,
+        )
+
+        txn.execute(
+            update_sql,
+            list(updatevalues.values()) + list(keyvalues.values())
+        )
+
+        return txn.rowcount
+
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            desc="_simple_update_one"):
         """Executes an UPDATE query on the named table, setting new values for
@@ -768,27 +837,13 @@ class SQLBaseStore(object):
             table, keyvalues, updatevalues,
         )
 
-    @staticmethod
-    def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
-        if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
-        else:
-            where = ""
-
-        update_sql = "UPDATE %s SET %s %s" % (
-            table,
-            ", ".join("%s = ?" % (k,) for k in updatevalues),
-            where,
-        )
-
-        txn.execute(
-            update_sql,
-            updatevalues.values() + keyvalues.values()
-        )
+    @classmethod
+    def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+        rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
 
-        if txn.rowcount == 0:
+        if rowcount == 0:
             raise StoreError(404, "No row found")
-        if txn.rowcount > 1:
+        if rowcount > 1:
             raise StoreError(500, "More than one row matched")
 
     @staticmethod
@@ -800,7 +855,7 @@ class SQLBaseStore(object):
             " AND ".join("%s = ?" % (k,) for k in keyvalues)
         )
 
-        txn.execute(select_sql, keyvalues.values())
+        txn.execute(select_sql, list(keyvalues.values()))
 
         row = txn.fetchone()
         if not row:
@@ -838,7 +893,7 @@ class SQLBaseStore(object):
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        txn.execute(sql, keyvalues.values())
+        txn.execute(sql, list(keyvalues.values()))
         if txn.rowcount == 0:
             raise StoreError(404, "No row found")
         if txn.rowcount > 1:
@@ -856,7 +911,7 @@ class SQLBaseStore(object):
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        return txn.execute(sql, keyvalues.values())
+        return txn.execute(sql, list(keyvalues.values()))
 
     def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
         return self.runInteraction(
@@ -888,7 +943,7 @@ class SQLBaseStore(object):
         )
         values.extend(iterable)
 
-        for key, value in keyvalues.iteritems():
+        for key, value in iteritems(keyvalues):
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -928,7 +983,7 @@ class SQLBaseStore(object):
         txn.close()
 
         if cache:
-            min_val = min(cache.itervalues())
+            min_val = min(itervalues(cache))
         else:
             min_val = max_value
 
@@ -951,7 +1006,8 @@ class SQLBaseStore(object):
             # __exit__ called after the transaction finishes.
             ctx = self._cache_id_gen.get_next()
             stream_id = ctx.__enter__()
-            txn.call_finally(ctx.__exit__, None, None, None)
+            txn.call_on_exception(ctx.__exit__, None, None, None)
+            txn.call_after(ctx.__exit__, None, None, None)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             self._simple_insert_txn(
@@ -1042,7 +1098,7 @@ class SQLBaseStore(object):
                 " AND ".join("%s = ?" % (k,) for k in keyvalues),
                 " ? ASC LIMIT ? OFFSET ?"
             )
-            txn.execute(sql, keyvalues.values() + pagevalues)
+            txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
         else:
             sql = "SELECT %s FROM %s ORDER BY %s" % (
                 ", ".join(retcols),
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index aa84ffc2b0..bbc3355c73 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from twisted.internet import defer
+import abc
+import logging
 
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from canonicaljson import json
 
-import ujson as json
-import logging
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
 
 
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_account_data_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        account_max = self.get_max_account_data_stream_id()
+        self._account_data_stream_cache = StreamChangeCache(
+            "AccountDataAndTagsChangeCache", account_max,
+        )
+
+        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+    @abc.abstractmethod
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream ID for account data stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
 
     @cached()
     def get_account_data_for_user(self, user_id):
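
AccountDataWorkerStore marks itself abstract with __metaclass__ = abc.ABCMeta (the Python 2 spelling; Python 3 writes `class Foo(metaclass=abc.ABCMeta)`). That lets its __init__ call get_max_account_data_stream_id before any subclass code runs, while instantiating the base class itself is refused. The shape of the pattern, shown in Python 3 syntax with toy names:

    import abc

    class WorkerStore(metaclass=abc.ABCMeta):
        def __init__(self):
            # safe here: any instantiable subclass must implement it
            self.max_id = self.get_max_account_data_stream_id()

        @abc.abstractmethod
        def get_max_account_data_stream_id(self):
            raise NotImplementedError()

    class MasterStore(WorkerStore):
        def get_max_account_data_stream_id(self):
            return 42

    MasterStore()    # ok
    # WorkerStore()  # TypeError: can't instantiate abstract class
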
@@ -63,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
-    @cachedInlineCallbacks(num_args=2)
+    @cachedInlineCallbacks(num_args=2, max_entries=5000)
     def get_global_account_data_by_type_for_user(self, data_type, user_id):
         """
         Returns:
@@ -85,25 +114,7 @@ class AccountDataStore(SQLBaseStore):
         else:
             defer.returnValue(None)
 
-    @cachedList(cached_method_name="get_global_account_data_by_type_for_user",
-                num_args=2, list_name="user_ids", inlineCallbacks=True)
-    def get_global_account_data_by_type_for_users(self, data_type, user_ids):
-        rows = yield self._simple_select_many_batch(
-            table="account_data",
-            column="user_id",
-            iterable=user_ids,
-            keyvalues={
-                "account_data_type": data_type,
-            },
-            retcols=("user_id", "content",),
-            desc="get_global_account_data_by_type_for_users",
-        )
-
-        defer.returnValue({
-            row["user_id"]: json.loads(row["content"]) if row["content"] else None
-            for row in rows
-        })
-
+    @cached(num_args=2)
     def get_account_data_for_room(self, user_id, room_id):
         """Get all the client account_data for a user for a room.
 
@@ -127,6 +138,38 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
+    @cached(num_args=3, max_entries=5000)
+    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+        """Get the client account_data of given type for a user for a room.
+
+        Args:
+            user_id(str): The user to get the account_data for.
+            room_id(str): The room to get the account_data for.
+            account_data_type (str): The account data type to get.
+        Returns:
+            A deferred of the room account_data for that type, or None if
+            there isn't any set.
+        """
+        def get_account_data_for_room_and_type_txn(txn):
+            content_json = self._simple_select_one_onecol_txn(
+                txn,
+                table="room_account_data",
+                keyvalues={
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "account_data_type": account_data_type,
+                },
+                retcol="content",
+                allow_none=True
+            )
+
+            return json.loads(content_json) if content_json else None
+
+        return self.runInteraction(
+            "get_account_data_for_room_and_type",
+            get_account_data_for_room_and_type_txn,
+        )
+
     def get_all_updated_account_data(self, last_global_id, last_room_id,
                                      current_id, limit):
         """Get all the client account_data that has changed on the server
@@ -209,6 +252,36 @@ class AccountDataStore(SQLBaseStore):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
+    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+            "m.ignored_user_list", ignorer_user_id,
+            on_invalidate=cache_context.invalidate,
+        )
+        if not ignored_account_data:
+            defer.returnValue(False)
+
+        defer.returnValue(
+            ignored_user_id in ignored_account_data.get("ignored_users", {})
+        )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    def __init__(self, db_conn, hs):
+        self._account_data_id_gen = StreamIdGenerator(
+            db_conn, "account_data_max_stream_id", "stream_id"
+        )
+
+        super(AccountDataStore, self).__init__(db_conn, hs)
+
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream id for the private user data stream
+
+        Returns:
+            A deferred int.
+        """
+        return self._account_data_id_gen.get_current_token()
+
     @defer.inlineCallbacks
     def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
         """Add some account_data to a room for a user.
@@ -222,9 +295,12 @@ class AccountDataStore(SQLBaseStore):
         """
         content_json = json.dumps(content)
 
-        def add_account_data_txn(txn, next_id):
-            self._simple_upsert_txn(
-                txn,
+        with self._account_data_id_gen.get_next() as next_id:
+            # no need to lock here as room_account_data has a unique constraint
+            # on (user_id, room_id, account_data_type) so _simple_upsert will
+            # retry if there is a conflict.
+            yield self._simple_upsert(
+                desc="add_room_account_data",
                 table="room_account_data",
                 keyvalues={
                     "user_id": user_id,
@@ -234,18 +310,23 @@ class AccountDataStore(SQLBaseStore):
                 values={
                     "stream_id": next_id,
                     "content": content_json,
-                }
-            )
-            txn.call_after(
-                self._account_data_stream_cache.entity_has_changed,
-                user_id, next_id,
+                },
+                lock=False,
             )
-            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
-            self._update_max_stream_id(txn, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction(
-                "add_room_account_data", add_account_data_txn, next_id
+            # it's theoretically possible for the above to succeed and the
+            # below to fail - in which case we might reuse a stream id on
+            # restart, and the above update might not get propagated. That
+            # doesn't sound any worse than the whole update getting lost,
+            # which is what would happen if we combined the two into one
+            # transaction.
+            yield self._update_max_stream_id(next_id)
+
+            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_account_data_for_room.invalidate((user_id, room_id,))
+            self.get_account_data_for_room_and_type.prefill(
+                (user_id, room_id, account_data_type,), content,
             )
 
         result = self._account_data_id_gen.get_current_token()
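
add_account_data_to_room is restructured around the id generator's context manager: `with self._account_data_id_gen.get_next() as next_id:` allocates a fresh stream id up front, and the generator only advances its published current token once the block exits, so readers never observe a token whose writes may still be in flight. A toy model of that contract (the real StreamIdGenerator also has to track several ids in flight at once):

    import contextlib
    import threading

    class ToyStreamIdGenerator(object):
        """Toy model of get_next()/get_current_token(), not the real thing."""

        def __init__(self, current):
            self._lock = threading.Lock()
            self._current = current
            self._next = current

        @contextlib.contextmanager
        def get_next(self):
            with self._lock:
                self._next += 1
                allocated = self._next
            yield allocated  # the caller does its writes with this id
            with self._lock:
                # publish only after the caller's writes have happened
                self._current = max(self._current, allocated)

        def get_current_token(self):
            return self._current
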
@@ -263,9 +344,12 @@ class AccountDataStore(SQLBaseStore):
         """
         content_json = json.dumps(content)
 
-        def add_account_data_txn(txn, next_id):
-            self._simple_upsert_txn(
-                txn,
+        with self._account_data_id_gen.get_next() as next_id:
+            # no need to lock here as account_data has a unique constraint on
+            # (user_id, account_data_type) so _simple_upsert will retry if
+            # there is a conflict.
+            yield self._simple_upsert(
+                desc="add_user_account_data",
                 table="account_data",
                 keyvalues={
                     "user_id": user_id,
@@ -274,37 +358,43 @@ class AccountDataStore(SQLBaseStore):
                 values={
                     "stream_id": next_id,
                     "content": content_json,
-                }
+                },
+                lock=False,
             )
-            txn.call_after(
-                self._account_data_stream_cache.entity_has_changed,
+
+            # it's theoretically possible for the above to succeed and the
+            # below to fail - in which case we might reuse a stream id on
+            # restart, and the above update might not get propagated. That
+            # doesn't sound any worse than the whole update getting lost,
+            # which is what would happen if we combined the two into one
+            # transaction.
+            yield self._update_max_stream_id(next_id)
+
+            self._account_data_stream_cache.entity_has_changed(
                 user_id, next_id,
             )
-            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
-            txn.call_after(
-                self.get_global_account_data_by_type_for_user.invalidate,
+            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_by_type_for_user.invalidate(
                 (account_data_type, user_id,)
             )
-            self._update_max_stream_id(txn, next_id)
-
-        with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction(
-                "add_user_account_data", add_account_data_txn, next_id
-            )
 
         result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
-    def _update_max_stream_id(self, txn, next_id):
+    def _update_max_stream_id(self, next_id):
         """Update the max stream_id
 
         Args:
-            txn: The database cursor
             next_id(int): The revision to advance to.
         """
-        update_max_id_sql = (
-            "UPDATE account_data_max_stream_id"
-            " SET stream_id = ?"
-            " WHERE stream_id < ?"
+        def _update(txn):
+            update_max_id_sql = (
+                "UPDATE account_data_max_stream_id"
+                " SET stream_id = ?"
+                " WHERE stream_id < ?"
+            )
+            txn.execute(update_max_id_sql, (next_id, next_id))
+        return self.runInteraction(
+            "update_account_data_max_stream_id",
+            _update,
         )
-        txn.execute(update_max_id_sql, (next_id, next_id))
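
The hunk above splits what used to be a single transaction into an upsert followed by a separate stream-id bump. A toy, in-memory model of that two-phase pattern (all names here are illustrative stand-ins, not Synapse APIs):

class ToyAccountDataStore(object):
    def __init__(self):
        self.rows = {}           # (user_id, data_type) -> (stream_id, content)
        self.max_stream_id = 0
        self._next = 0

    def add_account_data(self, user_id, data_type, content):
        self._next += 1
        next_id = self._next
        # Phase 1: upsert the row, tagged with the new stream id. The unique
        # key on (user_id, data_type) makes this safe without a table lock.
        self.rows[(user_id, data_type)] = (next_id, content)
        # Phase 2: advance the max stream id separately. A crash between the
        # phases can only reuse a stream id on restart, never half-apply.
        self.max_stream_id = max(self.max_stream_id, next_id)
        return self.max_stream_id

store = ToyAccountDataStore()
assert store.add_account_data("@alice:hs", "m.tag", {"tags": {}}) == 1
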
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index c63935cb07..9f12b360bc 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,15 +15,16 @@
 # limitations under the License.
 import logging
 import re
-import simplejson as json
+
+from canonicaljson import json
+
 from twisted.internet import defer
 
-from synapse.api.constants import Membership
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
-from synapse.storage.roommember import RoomsForUser
-from ._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
 
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -46,17 +48,16 @@ def _make_exclusive_regex(services_cache):
     return exclusive_user_regex
 
 
-class ApplicationServiceStore(SQLBaseStore):
-
-    def __init__(self, hs):
-        super(ApplicationServiceStore, self).__init__(hs)
-        self.hostname = hs.hostname
+class ApplicationServiceWorkerStore(SQLBaseStore):
+    def __init__(self, db_conn, hs):
         self.services_cache = load_appservices(
             hs.hostname,
             hs.config.app_service_config_files
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
+        super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+
     def get_app_services(self):
         return self.services_cache
 
@@ -99,83 +100,30 @@ class ApplicationServiceStore(SQLBaseStore):
                 return service
         return None
 
-    def get_app_service_rooms(self, service):
-        """Get a list of RoomsForUser for this application service.
-
-        Application services may be "interested" in lots of rooms depending on
-        the room ID, the room aliases, or the members in the room. This function
-        takes all of these into account and returns a list of RoomsForUser which
-        represent the entire list of room IDs that this application service
-        wants to know about.
+    def get_app_service_by_id(self, as_id):
+        """Get the application service with the given appservice ID.
 
         Args:
-            service: The application service to get a room list for.
+            as_id (str): The application service ID.
         Returns:
-            A list of RoomsForUser.
+            synapse.appservice.ApplicationService or None.
         """
-        return self.runInteraction(
-            "get_app_service_rooms",
-            self._get_app_service_rooms_txn,
-            service,
-        )
-
-    def _get_app_service_rooms_txn(self, txn, service):
-        # get all rooms matching the room ID regex.
-        room_entries = self._simple_select_list_txn(
-            txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
-        )
-        matching_room_list = set([
-            r["room_id"] for r in room_entries if
-            service.is_interested_in_room(r["room_id"])
-        ])
-
-        # resolve room IDs for matching room alias regex.
-        room_alias_mappings = self._simple_select_list_txn(
-            txn=txn, table="room_aliases", keyvalues=None,
-            retcols=["room_id", "room_alias"]
-        )
-        matching_room_list |= set([
-            r["room_id"] for r in room_alias_mappings if
-            service.is_interested_in_alias(r["room_alias"])
-        ])
-
-        # get all rooms for every user for this AS. This is scoped to users on
-        # this HS only.
-        user_list = self._simple_select_list_txn(
-            txn=txn, table="users", keyvalues=None, retcols=["name"]
-        )
-        user_list = [
-            u["name"] for u in user_list if
-            service.is_interested_in_user(u["name"])
-        ]
-        rooms_for_user_matching_user_id = set()  # RoomsForUser list
-        for user_id in user_list:
-            # FIXME: This assumes this store is linked with RoomMemberStore :(
-            rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
-                txn=txn,
-                user_id=user_id,
-                membership_list=[Membership.JOIN]
-            )
-            rooms_for_user_matching_user_id |= set(rooms_for_user)
-
-        # make RoomsForUser tuples for room ids and aliases which are not in the
-        # main rooms_for_user_list - e.g. they are rooms which do not have AS
-        # registered users in it.
-        known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
-        missing_rooms_for_user = [
-            RoomsForUser(r, service.sender, "join") for r in
-            matching_room_list if r not in known_room_ids
-        ]
-        rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
-
-        return rooms_for_user_matching_user_id
+        for service in self.services_cache:
+            if service.id == as_id:
+                return service
+        return None
 
 
-class ApplicationServiceTransactionStore(SQLBaseStore):
+class ApplicationServiceStore(ApplicationServiceWorkerStore):
+    # This is currently empty due to there not being any AS storage functions
+    # that can't be run on the workers. Since this may change in future, and
+    # to keep consistency with the other stores, we keep this empty class for
+    # now.
+    pass
 
-    def __init__(self, hs):
-        super(ApplicationServiceTransactionStore, self).__init__(hs)
 
+class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
+                                               EventsWorkerStore):
     @defer.inlineCallbacks
     def get_appservices_by_state(self, state):
         """Get a list of application services based on their state.
@@ -420,3 +368,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
         events = yield self._get_events(event_ids)
 
         defer.returnValue((upper_bound, events))
+
+
+class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
+    # This is currently empty due to there not being any AS storage functions
+    # that can't be run on the workers. Since this may change in future, and
+    # to keep consistency with the other stores, we keep this empty class for
+    # now.
+    pass
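
The empty subclass above follows the worker-store layering this branch applies across the storage module. A minimal sketch of the convention (class names are illustrative; the real classes inherit from SQLBaseStore and take (db_conn, hs)):

class FooWorkerStore(object):
    """Read-side methods that are safe to run on worker processes."""
    def get_foo(self, foo_id):
        return {"id": foo_id}

class FooStore(FooWorkerStore):
    """Write-side methods that must run on the master process.

    Kept as an empty subclass while there are none, for consistency with
    the other stores.
    """
    pass
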
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 7157fb1dfb..5fe1ca2de7 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,15 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import synapse.util.async
 
-from ._base import SQLBaseStore
-from . import engines
+import logging
+
+from canonicaljson import json
 
 from twisted.internet import defer
 
-import ujson as json
-import logging
+from synapse.metrics.background_process_metrics import run_as_background_process
+
+from . import engines
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -80,25 +82,30 @@ class BackgroundUpdateStore(SQLBaseStore):
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, hs):
-        super(BackgroundUpdateStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(BackgroundUpdateStore, self).__init__(db_conn, hs)
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}
+        self._all_done = False
 
-    @defer.inlineCallbacks
     def start_doing_background_updates(self):
-        logger.info("Starting background schema updates")
+        run_as_background_process(
+            "background_updates", self._run_background_updates,
+        )
 
+    @defer.inlineCallbacks
+    def _run_background_updates(self):
+        logger.info("Starting background schema updates")
         while True:
-            yield synapse.util.async.sleep(
+            yield self.hs.get_clock().sleep(
                 self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
 
             try:
                 result = yield self.do_next_background_update(
                     self.BACKGROUND_UPDATE_DURATION_MS
                 )
-            except:
+            except Exception:
                 logger.exception("Error doing update")
             else:
                 if result is None:
@@ -106,9 +113,41 @@ class BackgroundUpdateStore(SQLBaseStore):
                         "No more background updates to do."
                         " Unscheduling background update task."
                     )
+                    self._all_done = True
                     defer.returnValue(None)
 
     @defer.inlineCallbacks
+    def has_completed_background_updates(self):
+        """Check if all the background updates have completed
+
+        Returns:
+            Deferred[bool]: True if all background updates have completed
+        """
+        # if we've previously determined that there is nothing left to do, that
+        # is easy
+        if self._all_done:
+            defer.returnValue(True)
+
+        # obviously, if we have things in our queue, we're not done.
+        if self._background_update_queue:
+            defer.returnValue(False)
+
+        # otherwise, check if there are updates to be run. This is important,
+        # as we may be running on a worker which doesn't perform the bg updates
+        # itself, but still wants to wait for them to happen.
+        updates = yield self._simple_select_onecol(
+            "background_updates",
+            keyvalues=None,
+            retcol="1",
+            desc="check_background_updates",
+        )
+        if not updates:
+            self._all_done = True
+            defer.returnValue(True)
+
+        defer.returnValue(False)
+
+    @defer.inlineCallbacks
     def do_next_background_update(self, desired_duration_ms):
         """Does some amount of work on the next queued background update
 
@@ -209,6 +248,25 @@ class BackgroundUpdateStore(SQLBaseStore):
         """
         self._background_update_handlers[update_name] = update_handler
 
+    def register_noop_background_update(self, update_name):
+        """Register a noop handler for a background update.
+
+        This is useful when we previously did a background update, but no
+        longer wish to do the update. In this case the background update should
+        be removed from the schema delta files, but there may still be some
+        users who have the background update queued, so this method should
+        also be called to clear the update.
+
+        Args:
+            update_name (str): Name of update
+        """
+        @defer.inlineCallbacks
+        def noop_update(progress, batch_size):
+            yield self._end_background_update(update_name)
+            defer.returnValue(1)
+
+        self.register_background_update_handler(update_name, noop_update)
+
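A self-contained model of how a noop handler drains a stale queued update; MiniUpdater and its queue list are toy stand-ins for the background_updates table and the real update loop:

class MiniUpdater(object):
    def __init__(self):
        self.handlers = {}
        self.queue = ["old_unused_update"]  # rows left in background_updates

    def register_noop_background_update(self, name):
        def noop(progress, batch_size):
            self.queue.remove(name)  # stands in for _end_background_update
            return 1
        self.handlers[name] = noop

    def run_once(self):
        for name in list(self.queue):
            self.handlers[name](progress={}, batch_size=100)

updater = MiniUpdater()
updater.register_noop_background_update("old_unused_update")
updater.run_once()
assert not updater.queue  # the stale update has been cleared
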
     def register_background_index_update(self, update_name, index_name,
                                          table, columns, where_clause=None,
                                          unique=False,
@@ -269,7 +327,7 @@ class BackgroundUpdateStore(SQLBaseStore):
             # Sqlite doesn't support concurrent creation of indexes.
             #
             # We don't use partial indices on SQLite as it wasn't introduced
-            # until 3.8, and wheezy has 3.7
+            # until 3.8, and wheezy and CentOS 7 have 3.7
             #
             # We assume that sqlite doesn't give us invalid indices; however
             # we may still end up with the index existing but the
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index fc468ea185..77ae10da3d 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,13 +15,15 @@
 
 import logging
 
-from twisted.internet import defer, reactor
+from six import iteritems
 
-from ._base import Cache
-from . import background_updates
+from twisted.internet import defer
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches import CACHE_SIZE_FACTOR
 
+from . import background_updates
+from ._base import Cache
 
 logger = logging.getLogger(__name__)
 
@@ -32,14 +34,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000
 
 
 class ClientIpStore(background_updates.BackgroundUpdateStore):
-    def __init__(self, hs):
+    def __init__(self, db_conn, hs):
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
             max_entries=50000 * CACHE_SIZE_FACTOR,
         )
 
-        super(ClientIpStore, self).__init__(hs)
+        super(ClientIpStore, self).__init__(db_conn, hs)
 
         self.register_background_index_update(
             "user_ips_device_index",
@@ -48,17 +50,35 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             columns=["user_id", "device_id", "last_seen"],
         )
 
+        self.register_background_index_update(
+            "user_ips_last_seen_index",
+            index_name="user_ips_last_seen",
+            table="user_ips",
+            columns=["user_id", "last_seen"],
+        )
+
+        self.register_background_index_update(
+            "user_ips_last_seen_only_index",
+            index_name="user_ips_last_seen_only",
+            table="user_ips",
+            columns=["last_seen"],
+        )
+
         # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
         self._batch_row_update = {}
 
         self._client_ip_looper = self._clock.looping_call(
             self._update_client_ips_batch, 5 * 1000
         )
-        reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
+        self.hs.get_reactor().addSystemEventTrigger(
+            "before", "shutdown", self._update_client_ips_batch
+        )
 
-    def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
-        now = int(self._clock.time_msec())
-        key = (user.to_string(), access_token, ip)
+    def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
+                         now=None):
+        if not now:
+            now = int(self._clock.time_msec())
+        key = (user_id, access_token, ip)
 
         try:
             last_seen = self.client_ip_last_seen.get(key)
@@ -74,16 +94,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         self._batch_row_update[key] = (user_agent, device_id, now)
 
     def _update_client_ips_batch(self):
-        to_update = self._batch_row_update
-        self._batch_row_update = {}
-        return self.runInteraction(
-            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+        def update():
+            to_update = self._batch_row_update
+            self._batch_row_update = {}
+            return self.runInteraction(
+                "_update_client_ips_batch", self._update_client_ips_batch_txn,
+                to_update,
+            )
+
+        run_as_background_process(
+            "update_client_ips", update,
         )
 
     def _update_client_ips_batch_txn(self, txn, to_update):
         self.database_engine.lock_table(txn, "user_ips")
 
-        for entry in to_update.iteritems():
+        for entry in iteritems(to_update):
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
             self._simple_upsert_txn(
@@ -215,5 +241,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 "user_agent": user_agent,
                 "last_seen": last_seen,
             }
-            for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+            for (access_token, ip), (user_agent, last_seen) in iteritems(results)
         ))
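
A synchronous sketch of the batching scheme above: writes accumulate in a dict keyed by (user_id, access_token, ip), the last-seen cache deduplicates within LAST_SEEN_GRANULARITY, and the flush swaps the dict out before writing so new inserts cannot race it. All names are stand-ins for the real store.

LAST_SEEN_GRANULARITY = 120 * 1000

class IpBatcher(object):
    def __init__(self):
        self._batch = {}      # (user_id, token, ip) -> (agent, device, ts)
        self._last_seen = {}  # stand-in for the client_ip_last_seen Cache

    def insert(self, user_id, token, ip, agent, device, now):
        key = (user_id, token, ip)
        seen = self._last_seen.get(key)
        # Skip the write if we stored this key recently enough.
        if seen is not None and now - seen < LAST_SEEN_GRANULARITY:
            return
        self._last_seen[key] = now
        self._batch[key] = (agent, device, now)

    def flush(self, write_rows):
        # Swap the dict first so concurrent inserts go into a fresh batch.
        to_update, self._batch = self._batch, {}
        write_rows(to_update)

b = IpBatcher()
b.insert("@alice:hs", "tok", "1.2.3.4", "UA", "DEV", now=0)
b.insert("@alice:hs", "tok", "1.2.3.4", "UA", "DEV", now=1000)  # deduped
b.flush(lambda rows: None)
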
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 0b62b493d5..73646da025 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -14,14 +14,14 @@
 # limitations under the License.
 
 import logging
-import ujson
 
-from twisted.internet import defer
+from canonicaljson import json
 
-from .background_updates import BackgroundUpdateStore
+from twisted.internet import defer
 
 from synapse.util.caches.expiringcache import ExpiringCache
 
+from .background_updates import BackgroundUpdateStore
 
 logger = logging.getLogger(__name__)
 
@@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
 class DeviceInboxStore(BackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, hs):
-        super(DeviceInboxStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(DeviceInboxStore, self).__init__(db_conn, hs)
 
         self.register_background_index_update(
             "device_inbox_stream_index",
@@ -85,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             )
             rows = []
             for destination, edu in remote_messages_by_destination.items():
-                edu_json = ujson.dumps(edu)
+                edu_json = json.dumps(edu)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
@@ -177,7 +177,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                     " WHERE user_id = ?"
                 )
                 txn.execute(sql, (user_id,))
-                message_json = ujson.dumps(messages_by_device["*"])
+                message_json = json.dumps(messages_by_device["*"])
                 for row in txn:
                     # Add the message for all devices for this user on this
                     # server.
@@ -199,7 +199,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                     # Only insert into the local inbox if the device exists on
                     # this server
                     device = row[0]
-                    message_json = ujson.dumps(messages_by_device[device])
+                    message_json = json.dumps(messages_by_device[device])
                     messages_json_for_user[device] = message_json
 
             if messages_json_for_user:
@@ -253,7 +253,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             messages = []
             for row in txn:
                 stream_pos = row[0]
-                messages.append(ujson.loads(row[1]))
+                messages.append(json.loads(row[1]))
             if len(messages) < limit:
                 stream_pos = current_stream_id
             return (messages, stream_pos)
@@ -389,7 +389,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             messages = []
             for row in txn:
                 stream_pos = row[0]
-                messages.append(ujson.loads(row[1]))
+                messages.append(json.loads(row[1]))
             if len(messages) < limit:
                 stream_pos = current_stream_id
             return (messages, stream_pos)
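
Both query helpers in this file end with the same pagination idiom: fetch at most `limit` rows ordered by stream position, and if fewer come back, report the caller's current_stream_id as the new position, since the reader has caught up. A standalone sketch:

def page_messages(rows, limit, current_stream_id):
    # rows: (stream_pos, message) pairs, already ordered and capped at
    # `limit` by the SQL query.
    messages = []
    stream_pos = current_stream_id
    for pos, message in rows:
        stream_pos = pos
        messages.append(message)
    if len(messages) < limit:
        # A short page means everything up to current_stream_id is drained.
        stream_pos = current_stream_id
    return messages, stream_pos

msgs, pos = page_messages([(7, "a"), (9, "b")], limit=10, current_stream_id=42)
assert pos == 42
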
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index bb27fd1f70..cc3cdf2ebc 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,21 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import ujson as json
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, Cache
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
+from ._base import Cache, SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
 
 class DeviceStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(DeviceStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(DeviceStore, self).__init__(db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
@@ -245,17 +248,31 @@ class DeviceStore(SQLBaseStore):
 
     def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
                                                    content, stream_id):
-        self._simple_upsert_txn(
-            txn,
-            table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            values={
-                "content": json.dumps(content),
-            }
-        )
+        if content.get("deleted"):
+            self._simple_delete_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+            )
+
+            txn.call_after(
+                self.device_id_exists_cache.invalidate, (user_id, device_id,)
+            )
+        else:
+            self._simple_upsert_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "content": json.dumps(content),
+                }
+            )
 
         txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
         txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -360,10 +377,10 @@ class DeviceStore(SQLBaseStore):
             return (now_stream_id, [])
 
         if len(query_map) >= 20:
-            now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+            now_stream_id = max(stream_id for stream_id in itervalues(query_map))
 
         devices = self._get_e2e_device_keys_txn(
-            txn, query_map.keys(), include_all_devices=True
+            txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
         )
 
         prev_sent_id_sql = """
@@ -373,13 +390,13 @@ class DeviceStore(SQLBaseStore):
         """
 
         results = []
-        for user_id, user_devices in devices.iteritems():
+        for user_id, user_devices in iteritems(devices):
             # The prev_id for the first row is always the last row before
             # `from_stream_id`
             txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
             rows = txn.fetchall()
             prev_id = rows[0][0]
-            for device_id, device in user_devices.iteritems():
+            for device_id, device in iteritems(user_devices):
                 stream_id = query_map[(user_id, device_id)]
                 result = {
                     "user_id": user_id,
@@ -390,12 +407,15 @@ class DeviceStore(SQLBaseStore):
 
                 prev_id = stream_id
 
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = json.loads(key_json)
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
+                if device is not None:
+                    key_json = device.get("key_json", None)
+                    if key_json:
+                        result["keys"] = json.loads(key_json)
+                    device_display_name = device.get("device_display_name", None)
+                    if device_display_name:
+                        result["device_display_name"] = device_display_name
+                else:
+                    result["deleted"] = True
 
                 results.append(result)
 
@@ -483,7 +503,7 @@ class DeviceStore(SQLBaseStore):
         if devices:
             user_devices = devices[user_id]
             results = []
-            for device_id, device in user_devices.iteritems():
+            for device_id, device in iteritems(user_devices):
                 result = {
                     "device_id": device_id,
                 }
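
A distilled version of the delete-versus-upsert branch added above: a remote device entry marked "deleted" removes the cached row (a tombstone) rather than storing the deletion, with a plain dict standing in for the device_lists_remote_cache table:

def update_remote_device_cache(cache, user_id, device_id, content):
    key = (user_id, device_id)
    if content.get("deleted"):
        # Tombstone: drop the cached row instead of upserting the deletion.
        cache.pop(key, None)
    else:
        cache[key] = content

cache = {}
update_remote_device_cache(cache, "@alice:hs", "DEV1", {"keys": {}})
update_remote_device_cache(cache, "@alice:hs", "DEV1", {"deleted": True})
assert ("@alice:hs", "DEV1") not in cache
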
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 79e7c540ad..808194236a 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
-
-from synapse.api.errors import SynapseError
+from collections import namedtuple
 
 from twisted.internet import defer
 
-from collections import namedtuple
+from synapse.api.errors import SynapseError
+from synapse.util.caches.descriptors import cached
 
+from ._base import SQLBaseStore
 
 RoomAliasMapping = namedtuple(
     "RoomAliasMapping",
@@ -29,8 +28,7 @@ RoomAliasMapping = namedtuple(
 )
 
 
-class DirectoryStore(SQLBaseStore):
-
+class DirectoryWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_association_from_room_alias(self, room_alias):
         """ Get's the room_id and server list for a given room_alias
@@ -69,6 +67,28 @@ class DirectoryStore(SQLBaseStore):
             RoomAliasMapping(room_id, room_alias.to_string(), servers)
         )
 
+    def get_room_alias_creator(self, room_alias):
+        return self._simple_select_one_onecol(
+            table="room_aliases",
+            keyvalues={
+                "room_alias": room_alias,
+            },
+            retcol="creator",
+            desc="get_room_alias_creator",
+            allow_none=True
+        )
+
+    @cached(max_entries=5000)
+    def get_aliases_for_room(self, room_id):
+        return self._simple_select_onecol(
+            "room_aliases",
+            {"room_id": room_id},
+            "room_alias",
+            desc="get_aliases_for_room",
+        )
+
+
+class DirectoryStore(DirectoryWorkerStore):
     @defer.inlineCallbacks
     def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
         """ Creates an associatin between  a room alias and room_id/servers
@@ -116,17 +136,6 @@ class DirectoryStore(SQLBaseStore):
             )
         defer.returnValue(ret)
 
-    def get_room_alias_creator(self, room_alias):
-        return self._simple_select_one_onecol(
-            table="room_aliases",
-            keyvalues={
-                "room_alias": room_alias,
-            },
-            retcol="creator",
-            desc="get_room_alias_creator",
-            allow_none=True
-        )
-
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
         room_id = yield self.runInteraction(
@@ -135,7 +144,6 @@ class DirectoryStore(SQLBaseStore):
             room_alias,
         )
 
-        self.get_aliases_for_room.invalidate((room_id,))
         defer.returnValue(room_id)
 
     def _delete_room_alias_txn(self, txn, room_alias):
@@ -160,17 +168,12 @@ class DirectoryStore(SQLBaseStore):
             (room_alias.to_string(),)
         )
 
-        return room_id
-
-    @cached(max_entries=5000)
-    def get_aliases_for_room(self, room_id):
-        return self._simple_select_onecol(
-            "room_aliases",
-            {"room_id": room_id},
-            "room_alias",
-            desc="get_aliases_for_room",
+        self._invalidate_cache_and_stream(
+            txn, self.get_aliases_for_room, (room_id,)
         )
 
+        return room_id
+
     def update_aliases_for_room(self, old_room_id, new_room_id, creator):
         def _update_aliases_for_room_txn(txn):
             sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2cebb203c6..523b4360c3 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from six import iteritems
+
+from canonicaljson import encode_canonical_json, json
+
 from twisted.internet import defer
 
 from synapse.util.caches.descriptors import cached
 
-from canonicaljson import encode_canonical_json
-import ujson as json
-
 from ._base import SQLBaseStore
 
 
@@ -63,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_e2e_device_keys(self, query_list, include_all_devices=False):
+    def get_e2e_device_keys(
+        self, query_list, include_all_devices=False,
+        include_deleted_devices=False,
+    ):
         """Fetch a list of device keys.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
             include_all_devices (bool): whether to include entries for devices
                 that don't have device keys
+            include_deleted_devices (bool): whether to include null entries for
+                devices which no longer exist (but were in the query_list).
+                This option only takes effect if include_all_devices is true.
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -78,19 +85,28 @@ class EndToEndKeyStore(SQLBaseStore):
 
         results = yield self.runInteraction(
             "get_e2e_device_keys", self._get_e2e_device_keys_txn,
-            query_list, include_all_devices,
+            query_list, include_all_devices, include_deleted_devices,
         )
 
-        for user_id, device_keys in results.iteritems():
-            for device_id, device_info in device_keys.iteritems():
+        for user_id, device_keys in iteritems(results):
+            for device_id, device_info in iteritems(device_keys):
                 device_info["keys"] = json.loads(device_info.pop("key_json"))
 
         defer.returnValue(results)
 
-    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
+    def _get_e2e_device_keys_txn(
+        self, txn, query_list, include_all_devices=False,
+        include_deleted_devices=False,
+    ):
         query_clauses = []
         query_params = []
 
+        if include_all_devices is False:
+            include_deleted_devices = False
+
+        if include_deleted_devices:
+            deleted_devices = set(query_list)
+
         for (user_id, device_id) in query_list:
             query_clause = "user_id = ?"
             query_params.append(user_id)
@@ -118,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore):
 
         result = {}
         for row in rows:
+            if include_deleted_devices:
+                deleted_devices.remove((row["user_id"], row["device_id"]))
             result.setdefault(row["user_id"], {})[row["device_id"]] = row
 
+        if include_deleted_devices:
+            for user_id, device_id in deleted_devices:
+                result.setdefault(user_id, {})[device_id] = None
+
         return result
 
     @defer.inlineCallbacks
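
The include_deleted_devices bookkeeping above boils down to a set difference: start from every queried (user_id, device_id) pair, discard each pair a row comes back for, and emit None for whatever remains. A standalone sketch of just that merge step:

def merge_device_rows(query_list, rows, include_deleted_devices=True):
    # rows: dicts with "user_id" and "device_id" keys, as the query returns.
    deleted = set(query_list) if include_deleted_devices else set()
    result = {}
    for row in rows:
        deleted.discard((row["user_id"], row["device_id"]))
        result.setdefault(row["user_id"], {})[row["device_id"]] = row
    # Anything left was asked about but no longer exists.
    for user_id, device_id in deleted:
        result.setdefault(user_id, {})[device_id] = None
    return result

res = merge_device_rows(
    [("@alice:hs", "D1"), ("@alice:hs", "D2")],
    [{"user_id": "@alice:hs", "device_id": "D1"}],
)
assert res["@alice:hs"]["D2"] is None
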
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 338b495611..e2f9de8451 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -13,13 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
+import platform
+
 from ._base import IncorrectDatabaseSetup
 from .postgres import PostgresEngine
 from .sqlite3 import Sqlite3Engine
 
-import importlib
-
-
 SUPPORTED_MODULE = {
     "sqlite3": Sqlite3Engine,
     "psycopg2": PostgresEngine,
@@ -31,6 +31,10 @@ def create_engine(database_config):
     engine_class = SUPPORTED_MODULE.get(name, None)
 
     if engine_class:
+        # pypy requires psycopg2cffi rather than psycopg2
+        if (name == "psycopg2" and
+                platform.python_implementation() == "PyPy"):
+            name = "psycopg2cffi"
         module = importlib.import_module(name)
         return engine_class(module, database_config)
 
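The PyPy branch above rewrites the module name before importing it, since psycopg2 has no PyPy build while psycopg2cffi exposes a compatible API. The same pattern in isolation:

import importlib
import platform

def load_db_module(name):
    # pypy requires psycopg2cffi rather than psycopg2; the swapped module
    # is drop-in compatible for the calls the engine makes.
    if name == "psycopg2" and platform.python_implementation() == "PyPy":
        name = "psycopg2cffi"
    return importlib.import_module(name)
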
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
 
     def lock_table(self, txn, table):
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        txn.execute("SELECT nextval('state_group_id_seq')")
+        return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..19949fc474 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.prepare_database import prepare_database
-
 import struct
+import threading
+
+from synapse.storage.prepare_database import prepare_database
 
 
 class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
     def __init__(self, database_module, database_config):
         self.module = database_module
 
+        # The current max state_group, or None if we haven't looked
+        # in the DB yet.
+        self._current_state_group_id = None
+        self._current_state_group_id_lock = threading.Lock()
+
     def check_database(self, txn):
         pass
 
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
     def lock_table(self, txn, table):
         return
 
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        # We do application-level locking here since, if we're using sqlite,
+        # we are running as a single-process synapse.
+        with self._current_state_group_id_lock:
+            if self._current_state_group_id is None:
+                txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+                self._current_state_group_id = txn.fetchone()[0]
+
+            self._current_state_group_id += 1
+            return self._current_state_group_id
+
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
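A compact model of the SQLite get_next_state_group_id above: a lazily seeded, lock-protected in-process counter, safe only because a SQLite deployment is a single synapse process (Postgres uses a database sequence instead):

import threading

class StateGroupIdCounter(object):
    def __init__(self, load_max):
        self._load_max = load_max  # e.g. SELECT COALESCE(max(id), 0) ...
        self._current = None
        self._lock = threading.Lock()

    def next_id(self):
        with self._lock:
            if self._current is None:
                # First call: seed from the database exactly once.
                self._current = self._load_max()
            self._current += 1
            return self._current

c = StateGroupIdCounter(lambda: 41)
assert c.next_id() == 42 and c.next_id() == 43
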
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index e8133de2fa..8d366d1b91 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -12,45 +12,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import random
 
-from twisted.internet import defer
+from six.moves import range
+from six.moves.queue import Empty, PriorityQueue
 
-from ._base import SQLBaseStore
-from synapse.api.errors import StoreError
-from synapse.util.caches.descriptors import cached
 from unpaddedbase64 import encode_base64
 
-import logging
-from Queue import PriorityQueue, Empty
+from twisted.internet import defer
 
+from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
 
-class EventFederationStore(SQLBaseStore):
-    """ Responsible for storing and serving up the various graphs associated
-    with an event. Including the main event graph and the auth chains for an
-    event.
-
-    Also has methods for getting the front (latest) and back (oldest) edges
-    of the event graphs. These are used to generate the parents for new events
-    and backfilling from another server respectively.
-    """
-
-    EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
-
-    def __init__(self, hs):
-        super(EventFederationStore, self).__init__(hs)
-
-        self.register_background_update_handler(
-            self.EVENT_AUTH_STATE_ONLY,
-            self._background_delete_non_state_event_auth,
-        )
-
-        hs.get_clock().looping_call(
-            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
-        )
-
+class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
+                                 SQLBaseStore):
     def get_auth_chain(self, event_ids, include_given=False):
         """Get auth events for given event_ids. The events *must* be state events.
 
@@ -97,7 +79,7 @@ class EventFederationStore(SQLBaseStore):
             front_list = list(front)
             chunks = [
                 front_list[x:x + 100]
-                for x in xrange(0, len(front), 100)
+                for x in range(0, len(front), 100)
             ]
             for chunk in chunks:
                 txn.execute(
@@ -152,7 +134,47 @@ class EventFederationStore(SQLBaseStore):
             retcol="event_id",
         )
 
+    @defer.inlineCallbacks
+    def get_prev_events_for_room(self, room_id):
+        """
+        Gets a subset of the current forward extremities in the given room.
+
+        Limits the result to 10 extremities, so that we can avoid creating
+        events which refer to hundreds of prev_events.
+
+        Args:
+            room_id (str): room_id
+
+        Returns:
+            Deferred[list[(str, dict[str, str], int)]]
+                for each event, a tuple of (event_id, hashes, depth)
+                where *hashes* is a map from algorithm to hash.
+        """
+        res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
+        if len(res) > 10:
+            # Sort by reverse depth, so we point to the most recent.
+            res.sort(key=lambda a: -a[2])
+
+            # we use half of the limit for the actual most recent events, and
+            # the other half to randomly point to some of the older events, to
+            # make sure that we don't completely ignore the older events.
+            res = res[0:5] + random.sample(res[5:], 5)
+
+        defer.returnValue(res)
+
     def get_latest_event_ids_and_hashes_in_room(self, room_id):
+        """
+        Gets the current forward extremities in the given room
+
+        Args:
+            room_id (str): room_id
+
+        Returns:
+            Deferred[list[(str, dict[str, str], int)]]
+                for each event, a tuple of (event_id, hashes, depth)
+                where *hashes* is a map from algorithm to hash.
+        """
+
         return self.runInteraction(
             "get_latest_event_ids_and_hashes_in_room",
             self._get_latest_event_ids_and_hashes_in_room,
@@ -201,22 +223,6 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )
 
-    @defer.inlineCallbacks
-    def get_max_depth_of_events(self, event_ids):
-        sql = (
-            "SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
-        ) % (",".join(["?"] * len(event_ids)),)
-
-        rows = yield self._execute(
-            "get_max_depth_of_events", None,
-            sql, *event_ids
-        )
-
-        if rows:
-            defer.returnValue(rows[0][0])
-        else:
-            defer.returnValue(1)
-
     def _get_min_depth_interaction(self, txn, room_id):
         min_depth = self._simple_select_one_onecol_txn(
             txn,
@@ -228,88 +234,6 @@ class EventFederationStore(SQLBaseStore):
 
         return int(min_depth) if min_depth is not None else None
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
-        min_depth = self._get_min_depth_interaction(txn, room_id)
-
-        if min_depth and depth >= min_depth:
-            return
-
-        self._simple_upsert_txn(
-            txn,
-            table="room_depth",
-            keyvalues={
-                "room_id": room_id,
-            },
-            values={
-                "min_depth": depth,
-            },
-        )
-
-    def _handle_mult_prev_events(self, txn, events):
-        """
-        For the given event, update the event edges table and forward and
-        backward extremities tables.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="event_edges",
-            values=[
-                {
-                    "event_id": ev.event_id,
-                    "prev_event_id": e_id,
-                    "room_id": ev.room_id,
-                    "is_state": False,
-                }
-                for ev in events
-                for e_id, _ in ev.prev_events
-            ],
-        )
-
-        self._update_backward_extremeties(txn, events)
-
-    def _update_backward_extremeties(self, txn, events):
-        """Updates the event_backward_extremities tables based on the new/updated
-        events being persisted.
-
-        This is called for new events *and* for events that were outliers, but
-        are now being persisted as non-outliers.
-
-        Forward extremities are handled when we first start persisting the events.
-        """
-        events_by_room = {}
-        for ev in events:
-            events_by_room.setdefault(ev.room_id, []).append(ev)
-
-        query = (
-            "INSERT INTO event_backward_extremities (event_id, room_id)"
-            " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-            " )"
-            " AND NOT EXISTS ("
-            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            " AND outlier = ?"
-            " )"
-        )
-
-        txn.executemany(query, [
-            (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
-            for ev in events for e_id, _ in ev.prev_events
-            if not ev.internal_metadata.is_outlier()
-        ])
-
-        query = (
-            "DELETE FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-        )
-        txn.executemany(
-            query,
-            [
-                (ev.event_id, ev.room_id) for ev in events
-                if not ev.internal_metadata.is_outlier()
-            ]
-        )
-
     def get_forward_extremeties_for_room(self, room_id, stream_ordering):
         """For a given room_id and stream_ordering, return the forward
         extremities of the room at that point in "time".
@@ -371,28 +295,6 @@ class EventFederationStore(SQLBaseStore):
             get_forward_extremeties_for_room_txn
         )
 
-    def _delete_old_forward_extrem_cache(self):
-        def _delete_old_forward_extrem_cache_txn(txn):
-            # Delete entries older than a month, while making sure we don't delete
-            # the only entries for a room.
-            sql = ("""
-                DELETE FROM stream_ordering_to_exterm
-                WHERE
-                room_id IN (
-                    SELECT room_id
-                    FROM stream_ordering_to_exterm
-                    WHERE stream_ordering > ?
-                ) AND stream_ordering < ?
-            """)
-            txn.execute(
-                sql,
-                (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
-            )
-        return self.runInteraction(
-            "_delete_old_forward_extrem_cache",
-            _delete_old_forward_extrem_cache_txn
-        )
-
     def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
@@ -522,6 +424,135 @@ class EventFederationStore(SQLBaseStore):
 
         return event_results
 
+
+class EventFederationStore(EventFederationWorkerStore):
+    """ Responsible for storing and serving up the various graphs associated
+    with an event. Including the main event graph and the auth chains for an
+    event.
+
+    Also has methods for getting the front (latest) and back (oldest) edges
+    of the event graphs. These are used to generate the parents for new events
+    and backfilling from another server respectively.
+    """
+
+    EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
+    def __init__(self, db_conn, hs):
+        super(EventFederationStore, self).__init__(db_conn, hs)
+
+        self.register_background_update_handler(
+            self.EVENT_AUTH_STATE_ONLY,
+            self._background_delete_non_state_event_auth,
+        )
+
+        hs.get_clock().looping_call(
+            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+        )
+
+    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+        min_depth = self._get_min_depth_interaction(txn, room_id)
+
+        if min_depth and depth >= min_depth:
+            return
+
+        self._simple_upsert_txn(
+            txn,
+            table="room_depth",
+            keyvalues={
+                "room_id": room_id,
+            },
+            values={
+                "min_depth": depth,
+            },
+        )
+
+    def _handle_mult_prev_events(self, txn, events):
+        """
+        For the given event, update the event edges table and forward and
+        backward extremities tables.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="event_edges",
+            values=[
+                {
+                    "event_id": ev.event_id,
+                    "prev_event_id": e_id,
+                    "room_id": ev.room_id,
+                    "is_state": False,
+                }
+                for ev in events
+                for e_id, _ in ev.prev_events
+            ],
+        )
+
+        self._update_backward_extremeties(txn, events)
+
+    def _update_backward_extremeties(self, txn, events):
+        """Updates the event_backward_extremities tables based on the new/updated
+        events being persisted.
+
+        This is called for new events *and* for events that were outliers, but
+        are now being persisted as non-outliers.
+
+        Forward extremities are handled when we first start persisting the events.
+        """
+        events_by_room = {}
+        for ev in events:
+            events_by_room.setdefault(ev.room_id, []).append(ev)
+
+        query = (
+            "INSERT INTO event_backward_extremities (event_id, room_id)"
+            " SELECT ?, ? WHERE NOT EXISTS ("
+            " SELECT 1 FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+            " )"
+            " AND NOT EXISTS ("
+            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+            " AND outlier = ?"
+            " )"
+        )
+
+        txn.executemany(query, [
+            (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+            for ev in events for e_id, _ in ev.prev_events
+            if not ev.internal_metadata.is_outlier()
+        ])
+
+        query = (
+            "DELETE FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+        )
+        txn.executemany(
+            query,
+            [
+                (ev.event_id, ev.room_id) for ev in events
+                if not ev.internal_metadata.is_outlier()
+            ]
+        )
+
+    def _delete_old_forward_extrem_cache(self):
+        def _delete_old_forward_extrem_cache_txn(txn):
+            # Delete entries older than a month, while making sure we don't delete
+            # the only entries for a room.
+            sql = ("""
+                DELETE FROM stream_ordering_to_exterm
+                WHERE
+                room_id IN (
+                    SELECT room_id
+                    FROM stream_ordering_to_exterm
+                    WHERE stream_ordering > ?
+                ) AND stream_ordering < ?
+            """)
+            txn.execute(
+                sql,
+                (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
+            )
+        return self.runInteraction(
+            "_delete_old_forward_extrem_cache",
+            _delete_old_forward_extrem_cache_txn
+        )
+
     def clean_room_for_join(self, room_id):
         return self.runInteraction(
             "clean_room_for_join",
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index d6d8723b4a..29b511ae5e 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,15 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+import logging
+
+from six import iteritems
+
+from canonicaljson import json
+
 from twisted.internet import defer
-from synapse.util.async import sleep
-from synapse.util.caches.descriptors import cachedInlineCallbacks
-from synapse.types import RoomStreamToken
-from .stream import lower_bound
 
-import logging
-import ujson as json
+from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
 
@@ -62,59 +64,29 @@ def _deserialize_action(actions, is_highlight):
         return DEFAULT_NOTIF_ACTION
 
 
-class EventPushActionsStore(SQLBaseStore):
-    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
-    def __init__(self, hs):
-        super(EventPushActionsStore, self).__init__(hs)
+class EventPushActionsWorkerStore(SQLBaseStore):
+    def __init__(self, db_conn, hs):
+        super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
 
-        self.register_background_index_update(
-            self.EPA_HIGHLIGHT_INDEX,
-            index_name="event_push_actions_u_highlight",
-            table="event_push_actions",
-            columns=["user_id", "stream_ordering"],
-        )
+        # These get correctly set by _find_stream_orderings_for_times_txn
+        self.stream_ordering_month_ago = None
+        self.stream_ordering_day_ago = None
 
-        self.register_background_index_update(
-            "event_push_actions_highlights_index",
-            index_name="event_push_actions_highlights_index",
-            table="event_push_actions",
-            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
-            where_clause="highlight=1"
+        cur = LoggingTransaction(
+            db_conn.cursor(),
+            name="_find_stream_orderings_for_times_txn",
+            database_engine=self.database_engine,
+            after_callbacks=[],
+            exception_callbacks=[],
         )
+        self._find_stream_orderings_for_times_txn(cur)
+        cur.close()
 
-        self._doing_notif_rotation = False
-        self._rotate_notif_loop = self._clock.looping_call(
-            self._rotate_notifs, 30 * 60 * 1000
+        self.find_stream_orderings_looping_call = self._clock.looping_call(
+            self._find_stream_orderings_for_times, 10 * 60 * 1000
         )
-
-    def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
-        """
-        Args:
-            event: the event set actions for
-            tuples: list of tuples of (user_id, actions)
-        """
-        values = []
-        for uid, actions in tuples:
-            is_highlight = 1 if _action_has_highlight(actions) else 0
-
-            values.append({
-                'room_id': event.room_id,
-                'event_id': event.event_id,
-                'user_id': uid,
-                'actions': _serialize_action(actions, is_highlight),
-                'stream_ordering': event.internal_metadata.stream_ordering,
-                'topological_ordering': event.depth,
-                'notif': 1,
-                'highlight': is_highlight,
-            })
-
-        for uid, __ in tuples:
-            txn.call_after(
-                self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                (event.room_id, uid)
-            )
-        self._simple_insert_many_txn(txn, "event_push_actions", values)
+        self._rotate_delay = 3
+        self._rotate_count = 10000
 
     @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
@@ -130,7 +102,7 @@ class EventPushActionsStore(SQLBaseStore):
     def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
                                           last_read_event_id):
         sql = (
-            "SELECT stream_ordering, topological_ordering"
+            "SELECT stream_ordering"
             " FROM events"
             " WHERE room_id = ? AND event_id = ?"
         )
@@ -142,17 +114,12 @@ class EventPushActionsStore(SQLBaseStore):
             return {"notify_count": 0, "highlight_count": 0}
 
         stream_ordering = results[0][0]
-        topological_ordering = results[0][1]
 
         return self._get_unread_counts_by_pos_txn(
-            txn, room_id, user_id, topological_ordering, stream_ordering
+            txn, room_id, user_id, stream_ordering
         )
 
-    def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
-                                      stream_ordering):
-        token = RoomStreamToken(
-            topological_ordering, stream_ordering
-        )
+    def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
 
         # First get number of notifications.
         # We don't need to put a notif=1 clause as all rows always have
@@ -163,10 +130,10 @@ class EventPushActionsStore(SQLBaseStore):
             " WHERE"
             " user_id = ?"
             " AND room_id = ?"
-            " AND %s"
-        ) % (lower_bound(token, self.database_engine, inclusive=False),)
+            " AND stream_ordering > ?"
+        )
 
-        txn.execute(sql, (user_id, room_id))
+        txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
         notify_count = row[0] if row else 0
 
@@ -186,10 +153,10 @@ class EventPushActionsStore(SQLBaseStore):
             " highlight = 1"
             " AND user_id = ?"
             " AND room_id = ?"
-            " AND %s"
-        ) % (lower_bound(token, self.database_engine, inclusive=False),)
+            " AND stream_ordering > ?"
+        )
 
-        txn.execute(sql, (user_id, room_id))
+        txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
         highlight_count = row[0] if row else 0
 
@@ -240,7 +207,6 @@ class EventPushActionsStore(SQLBaseStore):
                 "   ep.highlight "
                 " FROM ("
                 "   SELECT room_id,"
-                "       MAX(topological_ordering) as topological_ordering,"
                 "       MAX(stream_ordering) as stream_ordering"
                 "   FROM events"
                 "   INNER JOIN receipts_linearized USING (room_id, event_id)"
@@ -250,13 +216,7 @@ class EventPushActionsStore(SQLBaseStore):
                 " event_push_actions AS ep"
                 " WHERE"
                 "   ep.room_id = rl.room_id"
-                "   AND ("
-                "       ep.topological_ordering > rl.topological_ordering"
-                "       OR ("
-                "           ep.topological_ordering = rl.topological_ordering"
-                "           AND ep.stream_ordering > rl.stream_ordering"
-                "       )"
-                "   )"
+                "   AND ep.stream_ordering > rl.stream_ordering"
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
@@ -349,7 +309,6 @@ class EventPushActionsStore(SQLBaseStore):
                 "  ep.highlight, e.received_ts"
                 " FROM ("
                 "   SELECT room_id,"
-                "       MAX(topological_ordering) as topological_ordering,"
                 "       MAX(stream_ordering) as stream_ordering"
                 "   FROM events"
                 "   INNER JOIN receipts_linearized USING (room_id, event_id)"
@@ -360,13 +319,7 @@ class EventPushActionsStore(SQLBaseStore):
                 " INNER JOIN events AS e USING (room_id, event_id)"
                 " WHERE"
                 "   ep.room_id = rl.room_id"
-                "   AND ("
-                "       ep.topological_ordering > rl.topological_ordering"
-                "       OR ("
-                "           ep.topological_ordering = rl.topological_ordering"
-                "           AND ep.stream_ordering > rl.stream_ordering"
-                "       )"
-                "   )"
+                "   AND ep.stream_ordering > rl.stream_ordering"
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
@@ -432,6 +385,290 @@ class EventPushActionsStore(SQLBaseStore):
         # Now return the first `limit`
         defer.returnValue(notifs[:limit])
 
+    def add_push_actions_to_staging(self, event_id, user_id_actions):
+        """Add the push actions for the event to the push action staging area.
+
+        Args:
+            event_id (str)
+            user_id_actions (dict[str, list[dict|str]]): A dictionary mapping
+                user_id to list of push actions, where an action can either be
+                a string or dict.
+
+        Returns:
+            Deferred
+        """
+
+        if not user_id_actions:
+            return
+
+        # This is a helper function for generating the necessary tuple that
+        # can be used to insert into the `event_push_actions_staging` table.
+        def _gen_entry(user_id, actions):
+            is_highlight = 1 if _action_has_highlight(actions) else 0
+            return (
+                event_id,  # event_id column
+                user_id,  # user_id column
+                _serialize_action(actions, is_highlight),  # actions column
+                1,  # notif column
+                is_highlight,  # highlight column
+            )
+
+        def _add_push_actions_to_staging_txn(txn):
+            # We don't use _simple_insert_many here to avoid the overhead
+            # of generating lists of dicts.
+
+            sql = """
+                INSERT INTO event_push_actions_staging
+                    (event_id, user_id, actions, notif, highlight)
+                VALUES (?, ?, ?, ?, ?)
+            """
+
+            txn.executemany(sql, (
+                _gen_entry(user_id, actions)
+                for user_id, actions in iteritems(user_id_actions)
+            ))
+
+        return self.runInteraction(
+            "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+        )
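
The staged rows are written with executemany over a generator of tuples, so no
intermediate list of dicts is ever built. A minimal, runnable sketch of the
same pattern against sqlite3 (the highlight test and the action serialisation
here are simplified stand-ins for _action_has_highlight / _serialize_action):

    import sqlite3

    def _gen_entry(event_id, user_id, actions):
        # Simplified: the real code serialises the actions and pulls the
        # highlight tweak out into its own column.
        is_highlight = 1 if "highlight" in actions else 0
        return (event_id, user_id, repr(actions), 1, is_highlight)

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_push_actions_staging"
        " (event_id TEXT, user_id TEXT, actions TEXT,"
        "  notif INTEGER, highlight INTEGER)"
    )
    user_id_actions = {"@alice:hs": ["notify"], "@bob:hs": ["highlight"]}
    # The generator is consumed lazily by executemany; the rows never
    # exist as a materialised list.
    conn.executemany(
        "INSERT INTO event_push_actions_staging VALUES (?, ?, ?, ?, ?)",
        (_gen_entry("$ev1", uid, acts) for uid, acts in user_id_actions.items()),
    )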
+
+    @defer.inlineCallbacks
+    def remove_push_actions_from_staging(self, event_id):
+        """Called if we failed to persist the event to ensure that stale push
+        actions don't build up in the DB
+
+        Args:
+            event_id (str)
+        """
+
+        try:
+            res = yield self._simple_delete(
+                table="event_push_actions_staging",
+                keyvalues={
+                    "event_id": event_id,
+                },
+                desc="remove_push_actions_from_staging",
+            )
+            defer.returnValue(res)
+        except Exception:
+            # this method is called from an exception handler, so propagating
+            # another exception here really isn't helpful - there's nothing
+            # the caller can do about it. Just log the exception and move on.
+            logger.exception(
+                "Error removing push actions after event persistence failure",
+            )
+
+    @defer.inlineCallbacks
+    def _find_stream_orderings_for_times(self):
+        yield self.runInteraction(
+            "_find_stream_orderings_for_times",
+            self._find_stream_orderings_for_times_txn
+        )
+
+    def _find_stream_orderings_for_times_txn(self, txn):
+        logger.info("Searching for stream ordering 1 month ago")
+        self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
+            txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
+        )
+        logger.info(
+            "Found stream ordering 1 month ago: it's %d",
+            self.stream_ordering_month_ago
+        )
+        logger.info("Searching for stream ordering 1 day ago")
+        self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
+            txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
+        )
+        logger.info(
+            "Found stream ordering 1 day ago: it's %d",
+            self.stream_ordering_day_ago
+        )
+
+    def find_first_stream_ordering_after_ts(self, ts):
+        """Gets the stream ordering corresponding to a given timestamp.
+
+        Specifically, finds the stream_ordering of the first event that was
+        received on or after the timestamp. This is done by a binary search on
+        the events table: since there is no index on received_ts, it is
+        relatively slow.
+
+        Args:
+            ts (int): timestamp in millis
+
+        Returns:
+            Deferred[int]: stream ordering of the first event received on/after
+                the timestamp
+        """
+        return self.runInteraction(
+            "_find_first_stream_ordering_after_ts_txn",
+            self._find_first_stream_ordering_after_ts_txn,
+            ts,
+        )
+
+    @staticmethod
+    def _find_first_stream_ordering_after_ts_txn(txn, ts):
+        """
+        Find the stream_ordering of the first event that was received on or
+        after a given timestamp. This is relatively slow as there is no index
+        on received_ts but we can then use this to delete push actions before
+        this.
+
+        received_ts must necessarily be in the same order as stream_ordering
+        and stream_ordering is indexed, so we manually binary search using
+        stream_ordering
+
+        Args:
+            txn (twisted.enterprise.adbapi.Transaction):
+            ts (int): timestamp to search for
+
+        Returns:
+            int: stream ordering
+        """
+        txn.execute("SELECT MAX(stream_ordering) FROM events")
+        max_stream_ordering = txn.fetchone()[0]
+
+        if max_stream_ordering is None:
+            return 0
+
+        # We want the first stream_ordering in which received_ts is greater
+        # than or equal to ts. Call this point X.
+        #
+        # We maintain the invariants:
+        #
+        #   range_start <= X <= range_end
+        #
+        range_start = 0
+        range_end = max_stream_ordering + 1
+
+        # Given a stream_ordering, look up the timestamp at that
+        # stream_ordering.
+        #
+        # The array may be sparse (we may be missing some stream_orderings).
+        # We treat a gap as having the same value as the preceding entry,
+        # because we will pick the lowest stream_ordering which satisfies
+        # our requirement of received_ts >= ts.
+        #
+        # For example, if our array of events indexed by stream_ordering is
+        # [10, <none>, 20], we should treat this as being equivalent to
+        # [10, 10, 20].
+        #
+        sql = (
+            "SELECT received_ts FROM events"
+            " WHERE stream_ordering <= ?"
+            " ORDER BY stream_ordering DESC"
+            " LIMIT 1"
+        )
+
+        while range_end - range_start > 0:
+            middle = (range_end + range_start) // 2
+            txn.execute(sql, (middle,))
+            row = txn.fetchone()
+            if row is None:
+                # no rows with stream_ordering<=middle
+                range_start = middle + 1
+                continue
+
+            middle_ts = row[0]
+            if ts > middle_ts:
+                # we got a timestamp lower than the one we were looking for.
+                # definitely need to look higher: X > middle.
+                range_start = middle + 1
+            else:
+                # we got a timestamp higher than (or the same as) the one we
+                # were looking for. We aren't yet sure about the point we
+                # looked up, but we can be sure that X <= middle.
+                range_end = middle
+
+        return range_end
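
The same search can be exercised in isolation. A sketch with an in-memory
"events table" (a dict of stream_ordering -> received_ts, with gaps), where
the linear scan inside the loop stands in for the SQL lookup above:

    def first_ordering_at_or_after(received_ts_by_ordering, ts):
        # Smallest stream_ordering whose effective received_ts >= ts;
        # gaps inherit the preceding entry's timestamp, as described above.
        if not received_ts_by_ordering:
            return 0
        range_start = 0
        range_end = max(received_ts_by_ordering) + 1
        while range_end - range_start > 0:
            middle = (range_end + range_start) // 2
            at_or_below = [o for o in received_ts_by_ordering if o <= middle]
            if not at_or_below:
                range_start = middle + 1  # no rows with stream_ordering <= middle
                continue
            middle_ts = received_ts_by_ordering[max(at_or_below)]
            if ts > middle_ts:
                range_start = middle + 1  # X is strictly above middle
            else:
                range_end = middle        # X <= middle
        return range_end

    # Orderings 1 and 3 exist; 2 inherits ts=10, mirroring [10, 10, 20].
    assert first_ordering_at_or_after({1: 10, 3: 20}, 15) == 3
    assert first_ordering_at_or_after({1: 10, 3: 20}, 10) == 1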
+
+
+class EventPushActionsStore(EventPushActionsWorkerStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+    def __init__(self, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(db_conn, hs)
+
+        self.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
+        self.register_background_index_update(
+            "event_push_actions_highlights_index",
+            index_name="event_push_actions_highlights_index",
+            table="event_push_actions",
+            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+            where_clause="highlight=1"
+        )
+
+        self._doing_notif_rotation = False
+        self._rotate_notif_loop = self._clock.looping_call(
+            self._rotate_notifs, 30 * 60 * 1000
+        )
+
+    def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
+                                                  all_events_and_contexts):
+        """Handles moving push actions from staging table to main
+        event_push_actions table for all events in `events_and_contexts`.
+
+        Also ensures that all events in `all_events_and_contexts` are removed
+        from the push action staging area.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc., that wouldn't appear in
+                events_and_contexts.
+        """
+
+        sql = """
+            INSERT INTO event_push_actions (
+                room_id, event_id, user_id, actions, stream_ordering,
+                topological_ordering, notif, highlight
+            )
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            FROM event_push_actions_staging
+            WHERE event_id = ?
+        """
+
+        if events_and_contexts:
+            txn.executemany(sql, (
+                (
+                    event.room_id, event.internal_metadata.stream_ordering,
+                    event.depth, event.event_id,
+                )
+                for event, _ in events_and_contexts
+            ))
+
+        for event, _ in events_and_contexts:
+            user_ids = self._simple_select_onecol_txn(
+                txn,
+                table="event_push_actions_staging",
+                keyvalues={
+                    "event_id": event.event_id,
+                },
+                retcol="user_id",
+            )
+
+            for uid in user_ids:
+                txn.call_after(
+                    self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                    (event.room_id, uid,)
+                )
+
+        # Now we delete the staging area for *all* events that were being
+        # persisted.
+        txn.executemany(
+            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+            (
+                (event.event_id,)
+                for event, _ in all_events_and_contexts
+            )
+        )
+
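
The INSERT ... SELECT moves the staged per-user rows in one statement, with
the per-event columns (room_id and the orderings) bound as parameters. A
runnable sqlite3 sketch with a trimmed column set (table and value names
illustrative):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE staging (event_id TEXT, user_id TEXT, notif INTEGER);
        CREATE TABLE actions (room_id TEXT, event_id TEXT, user_id TEXT,
                              stream_ordering INTEGER, notif INTEGER);
        INSERT INTO staging VALUES ('$ev', '@alice:hs', 1), ('$ev', '@bob:hs', 1);
    """)
    # Parameters fill the per-event columns; the SELECT carries the staged rows.
    conn.execute(
        "INSERT INTO actions"
        " SELECT ?, event_id, user_id, ?, notif FROM staging WHERE event_id = ?",
        ("!room:hs", 42, "$ev"),
    )
    conn.execute("DELETE FROM staging WHERE event_id = ?", ("$ev",))
    print(conn.execute("SELECT user_id, stream_ordering FROM actions").fetchall())
    # [('@alice:hs', 42), ('@bob:hs', 42)]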
     @defer.inlineCallbacks
     def get_push_actions_for_user(self, user_id, before=None, limit=50,
                                   only_highlight=False):
@@ -509,10 +746,10 @@ class EventPushActionsStore(SQLBaseStore):
         )
 
     def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
-                                            topological_ordering, stream_ordering):
+                                            stream_ordering):
         """
         Purges old push actions for a user and room before a given
-        topological_ordering.
+        stream_ordering.
 
         However, we keep a month's worth of highlighted notifications, so that
         users can still get a list of recent highlights.
@@ -521,7 +758,7 @@ class EventPushActionsStore(SQLBaseStore):
             txn: The transaction
             room_id: Room ID to delete from
             user_id: user ID to delete for
-            topological_ordering: The lowest topological ordering which will
+            stream_ordering: The lowest stream ordering which will
                                   not be deleted.
         """
         txn.call_after(
@@ -540,9 +777,9 @@ class EventPushActionsStore(SQLBaseStore):
         txn.execute(
             "DELETE FROM event_push_actions "
             " WHERE user_id = ? AND room_id = ? AND "
-            " topological_ordering <= ?"
+            " stream_ordering <= ?"
             " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago)
         )
 
         txn.execute("""
@@ -551,69 +788,6 @@ class EventPushActionsStore(SQLBaseStore):
         """, (room_id, user_id, stream_ordering))
 
     @defer.inlineCallbacks
-    def _find_stream_orderings_for_times(self):
-        yield self.runInteraction(
-            "_find_stream_orderings_for_times",
-            self._find_stream_orderings_for_times_txn
-        )
-
-    def _find_stream_orderings_for_times_txn(self, txn):
-        logger.info("Searching for stream ordering 1 month ago")
-        self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
-            txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
-        )
-        logger.info(
-            "Found stream ordering 1 month ago: it's %d",
-            self.stream_ordering_month_ago
-        )
-        logger.info("Searching for stream ordering 1 day ago")
-        self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
-            txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
-        )
-        logger.info(
-            "Found stream ordering 1 day ago: it's %d",
-            self.stream_ordering_day_ago
-        )
-
-    def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
-        """
-        Find the stream_ordering of the first event that was received after
-        a given timestamp. This is relatively slow as there is no index on
-        received_ts but we can then use this to delete push actions before
-        this.
-
-        received_ts must necessarily be in the same order as stream_ordering
-        and stream_ordering is indexed, so we manually binary search using
-        stream_ordering
-        """
-        txn.execute("SELECT MAX(stream_ordering) FROM events")
-        max_stream_ordering = txn.fetchone()[0]
-
-        if max_stream_ordering is None:
-            return 0
-
-        range_start = 0
-        range_end = max_stream_ordering
-
-        sql = (
-            "SELECT received_ts FROM events"
-            " WHERE stream_ordering > ?"
-            " ORDER BY stream_ordering"
-            " LIMIT 1"
-        )
-
-        while range_end - range_start > 1:
-            middle = int((range_end + range_start) / 2)
-            txn.execute(sql, (middle,))
-            middle_ts = txn.fetchone()[0]
-            if ts > middle_ts:
-                range_start = middle
-            else:
-                range_end = middle
-
-        return range_end
-
-    @defer.inlineCallbacks
     def _rotate_notifs(self):
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
             return
@@ -629,7 +803,7 @@ class EventPushActionsStore(SQLBaseStore):
                 )
                 if caught_up:
                     break
-                yield sleep(5)
+                yield self.hs.get_clock().sleep(self._rotate_delay)
         finally:
             self._doing_notif_rotation = False
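
The rotation is a guarded catch-up loop: a flag stops overlapping runs when
the looping call fires while a rotation is still in flight, and the loop
sleeps between batches until a pass reports it has caught up. A plain-Python
sketch of the shape (time.sleep standing in for the Twisted clock; the batch
step is a stub):

    import time

    _doing_rotation = False

    def _rotate_once():
        # Stub: rotate one bounded batch; return True once caught up.
        return True

    def rotate_loop(delay=3):
        global _doing_rotation
        if _doing_rotation:
            return  # reentrancy guard
        _doing_rotation = True
        try:
            while True:
                if _rotate_once():
                    break
                time.sleep(delay)  # yield between batches to limit DB pressure
        finally:
            _doing_rotation = False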
 
@@ -650,8 +824,8 @@ class EventPushActionsStore(SQLBaseStore):
         txn.execute("""
             SELECT stream_ordering FROM event_push_actions
             WHERE stream_ordering > ?
-            ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
-        """, (old_rotate_stream_ordering,))
+            ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
+        """, (old_rotate_stream_ordering, self._rotate_count))
         stream_row = txn.fetchone()
         if stream_row:
             offset_stream_ordering, = stream_row
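
LIMIT 1 OFFSET ? returns the ordering that lies exactly _rotate_count rows
past the last rotation point, which is what caps each batch. A sqlite3 sketch:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_push_actions (stream_ordering INTEGER)")
    conn.executemany(
        "INSERT INTO event_push_actions VALUES (?)", [(i,) for i in range(100)]
    )
    old_rotate, batch_size = 10, 25
    row = conn.execute(
        "SELECT stream_ordering FROM event_push_actions"
        " WHERE stream_ordering > ?"
        " ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?",
        (old_rotate, batch_size),
    ).fetchone()
    # Rows 11..35 are skipped by the offset, so the boundary row is 36.
    print(row)  # (36,)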
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7002b3752e..906a405031 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,64 +13,59 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import SQLBaseStore
 
-from twisted.internet import defer, reactor
+import itertools
+import logging
+from collections import OrderedDict, deque, namedtuple
+from functools import wraps
 
-from synapse.events import FrozenEvent, USE_FROZEN_DICTS
-from synapse.events.utils import prune_event
+from six import iteritems
+from six.moves import range
 
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import (
-    preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
-)
-from synapse.util.logutils import log_function
-from synapse.util.metrics import Measure
-from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
-from synapse.util.caches.descriptors import cached
-from synapse.types import get_domain_from_id
+from canonicaljson import json
+from prometheus_client import Counter
 
-from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple, OrderedDict
-from functools import wraps
+from twisted.internet import defer
 
 import synapse.metrics
-
-import logging
-import ujson as json
-
+from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
 # these are only included to make the type annotations work
-from synapse.events import EventBase    # noqa: F401
-from synapse.events.snapshot import EventContext   # noqa: F401
+from synapse.events import EventBase  # noqa: F401
+from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.events_worker import EventsWorkerStore
+from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.util.async import ObservableDeferred
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
+from synapse.util.logutils import log_function
+from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
 
+persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
+event_counter = Counter("synapse_storage_events_persisted_events_sep", "",
+                        ["type", "origin_type", "origin_entity"])
 
-metrics = synapse.metrics.get_metrics_for(__name__)
-persist_event_counter = metrics.register_counter("persisted_events")
-event_counter = metrics.register_counter(
-    "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
-)
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
 
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+    "synapse_storage_events_state_delta_single_event", "")
 
-def encode_json(json_object):
-    if USE_FROZEN_DICTS:
-        # ujson doesn't like frozen_dicts
-        return encode_canonical_json(json_object)
-    else:
-        return json.dumps(json_object, ensure_ascii=False)
+# The number of times we are recalculating state when we could have reasonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+    "synapse_storage_events_state_delta_reuse_delta", "")
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
+def encode_json(json_object):
+    return frozendict_json_encoder.encode(json_object)
 
 
 class _EventPeristenceQueue(object):
@@ -88,19 +84,29 @@ class _EventPeristenceQueue(object):
     def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
 
+        NB: due to the normal usage pattern of this method, it does *not*
+        follow the synapse logcontext rules, and leaves the logcontext in
+        place whether or not the returned deferred is ready.
+
         Args:
             room_id (str):
             events_and_contexts (list[(EventBase, EventContext)]):
             backfilled (bool):
+
+        Returns:
+            defer.Deferred: a deferred which will resolve once the events are
+                persisted. Runs its callbacks *without* a logcontext.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
+            # if the last item in the queue has the same `backfilled` setting,
+            # we can just add these new events to that item.
             end_item = queue[-1]
             if end_item.backfilled == backfilled:
                 end_item.events_and_contexts.extend(events_and_contexts)
                 return end_item.deferred.observe()
 
-        deferred = ObservableDeferred(defer.Deferred())
+        deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
 
         queue.append(self._EventPersistQueueItem(
             events_and_contexts=events_and_contexts,
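
The coalescing rule is purely structural: a new batch is merged into the tail
item when its backfilled flag matches, so one persistence pass handles both. A
sketch with plain data structures (deferreds omitted):

    from collections import deque, namedtuple

    QueueItem = namedtuple("QueueItem", ["events_and_contexts", "backfilled"])
    queues = {}

    def add_to_queue(room_id, events_and_contexts, backfilled):
        queue = queues.setdefault(room_id, deque())
        if queue and queue[-1].backfilled == backfilled:
            # Same kind of batch as the tail item: extend rather than append.
            queue[-1].events_and_contexts.extend(events_and_contexts)
            return
        queue.append(QueueItem(list(events_and_contexts), backfilled))

    add_to_queue("!room:hs", ["$e1"], backfilled=False)
    add_to_queue("!room:hs", ["$e2"], backfilled=False)  # merged into item 1
    add_to_queue("!room:hs", ["$e3"], backfilled=True)   # flag differs: new item
    assert len(queues["!room:hs"]) == 2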
@@ -113,11 +119,11 @@ class _EventPeristenceQueue(object):
     def handle_queue(self, room_id, per_item_callback):
         """Attempts to handle the queue for a room if not already being handled.
 
-        The given callback will be invoked with for each item in the queue,1
+        The given callback will be invoked for each item in the queue,
         of type _EventPersistQueueItem. The per_item_callback will continuously
         be called with new items, unless the queue becomes empty. The return
         value of the function will be given to the deferreds waiting on the item,
-        exceptions will be passed to the deferres as well.
+        exceptions will be passed to the deferreds as well.
 
         This function should therefore be called whenever anything is added
         to the queue.
@@ -136,18 +142,23 @@ class _EventPeristenceQueue(object):
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
+                    # handle_queue_loop runs in the sentinel logcontext, so
+                    # there is no need to preserve_fn when running the
+                    # callbacks on the deferred.
                     try:
                         ret = yield per_item_callback(item)
-                        item.deferred.callback(ret)
-                    except Exception as e:
-                        item.deferred.errback(e)
+                        with PreserveLoggingContext():
+                            item.deferred.callback(ret)
+                    except Exception:
+                        item.deferred.errback()
             finally:
                 queue = self._event_persist_queues.pop(room_id, None)
                 if queue:
                     self._event_persist_queues[room_id] = queue
                 self._currently_persisting_rooms.discard(room_id)
 
-        preserve_fn(handle_queue_loop)()
+        # set handle_queue_loop off in the background
+        run_as_background_process("persist_events", handle_queue_loop)
 
     def _get_drainining_queue(self, room_id):
         queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -183,13 +194,12 @@ def _retry_on_integrity_error(func):
     return f
 
 
-class EventsStore(SQLBaseStore):
+class EventsStore(EventsWorkerStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
-    def __init__(self, hs):
-        super(EventsStore, self).__init__(hs)
-        self._clock = hs.get_clock()
+    def __init__(self, db_conn, hs):
+        super(EventsStore, self).__init__(db_conn, hs)
         self.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
@@ -220,6 +230,8 @@ class EventsStore(SQLBaseStore):
 
         self._event_persist_queue = _EventPeristenceQueue()
 
+        self._state_resolution_handler = hs.get_state_resolution_handler()
+
     def persist_events(self, events_and_contexts, backfilled=False):
         """
         Write events to the database
@@ -232,8 +244,8 @@ class EventsStore(SQLBaseStore):
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
         deferreds = []
-        for room_id, evs_ctxs in partitioned.iteritems():
-            d = preserve_fn(self._event_persist_queue.add_to_queue)(
+        for room_id, evs_ctxs in iteritems(partitioned):
+            d = self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
             )
@@ -242,7 +254,7 @@ class EventsStore(SQLBaseStore):
         for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        return preserve_context_over_deferred(
+        return make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
@@ -267,7 +279,7 @@ class EventsStore(SQLBaseStore):
 
         self._maybe_start_persisting(event.room_id)
 
-        yield preserve_context_over_deferred(deferred)
+        yield make_deferred_yieldable(deferred)
 
         max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -275,10 +287,11 @@ class EventsStore(SQLBaseStore):
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
         def persisting_queue(item):
-            yield self._persist_events(
-                item.events_and_contexts,
-                backfilled=item.backfilled,
-            )
+            with Measure(self._clock, "persist_events"):
+                yield self._persist_events(
+                    item.events_and_contexts,
+                    backfilled=item.backfilled,
+                )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
@@ -316,7 +329,7 @@ class EventsStore(SQLBaseStore):
 
             chunks = [
                 events_and_contexts[x:x + 100]
-                for x in xrange(0, len(events_and_contexts), 100)
+                for x in range(0, len(events_and_contexts), 100)
             ]
 
             for chunk in chunks:
@@ -325,8 +338,23 @@ class EventsStore(SQLBaseStore):
 
                 # NB: Assumes that we are only persisting events for one room
                 # at a time.
+
+                # map room_id->list[event_ids] giving the new forward
+                # extremities in each room
                 new_forward_extremeties = {}
+
+                # map room_id->(type,state_key)->event_id tracking the full
+                # state in each room after adding these events.
+                # This is simply used to prefill the get_current_state_ids
+                # cache
                 current_state_for_room = {}
+
+                # map room_id->(to_delete, to_insert) where to_delete is a list
+                # of type/state keys to remove from current state, and to_insert
+                # is a map (type,key)->event_id giving the state delta in each
+                # room
+                state_delta_for_room = {}
+
                 if not backfilled:
                     with Measure(self._clock, "_calculate_state_and_extrem"):
                         # Work out the new "current state" for each room.
@@ -338,7 +366,7 @@ class EventsStore(SQLBaseStore):
                                 (event, context)
                             )
 
-                        for room_id, ev_ctx_rm in events_by_room.iteritems():
+                        for room_id, ev_ctx_rm in iteritems(events_by_room):
                             # Work out new extremities by recursively adding and removing
                             # the new events.
                             latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -348,7 +376,8 @@ class EventsStore(SQLBaseStore):
                                 room_id, ev_ctx_rm, latest_event_ids
                             )
 
-                            if new_latest_event_ids == set(latest_event_ids):
+                            latest_event_ids = set(latest_event_ids)
+                            if new_latest_event_ids == latest_event_ids:
                                 # No change in extremities, so no change in state
                                 continue
 
@@ -369,11 +398,63 @@ class EventsStore(SQLBaseStore):
                                 if all_single_prev_not_state:
                                     continue
 
-                            state = yield self._calculate_state_delta(
-                                room_id, ev_ctx_rm, new_latest_event_ids
+                            state_delta_counter.inc()
+                            if len(new_latest_event_ids) == 1:
+                                state_delta_single_event_counter.inc()
+
+                                # This is a fairly handwavey check to see if we could
+                                # have guessed what the delta would have been when
+                                # processing one of these events.
+                                # What we're interested in is if the latest extremities
+                                # were the same when we created the event as they are
+                                # now. When this server creates a new event (as opposed
+                                # to receiving it over federation) it will use the
+                                # forward extremities as the prev_events, so we can
+                                # guess this by looking at the prev_events and checking
+                                # if they match the current forward extremities.
+                                for ev, _ in ev_ctx_rm:
+                                    prev_event_ids = set(e for e, _ in ev.prev_events)
+                                    if latest_event_ids == prev_event_ids:
+                                        state_delta_reuse_delta_counter.inc()
+                                        break
+
+                            logger.info(
+                                "Calculating state delta for room %s", room_id,
                             )
-                            if state:
-                                current_state_for_room[room_id] = state
+                            with Measure(
+                                self._clock,
+                                "persist_events.get_new_state_after_events",
+                            ):
+                                res = yield self._get_new_state_after_events(
+                                    room_id,
+                                    ev_ctx_rm,
+                                    latest_event_ids,
+                                    new_latest_event_ids,
+                                )
+                                current_state, delta_ids = res
+
+                            # If either is not None then there has been a change,
+                            # and we need to work out the delta (or use the one
+                            # given).
+                            if delta_ids is not None:
+                                # If there is a delta we know that we've
+                                # only added or replaced state, never
+                                # removed keys entirely.
+                                state_delta_for_room[room_id] = ([], delta_ids)
+                            elif current_state is not None:
+                                with Measure(
+                                    self._clock,
+                                    "persist_events.calculate_state_delta",
+                                ):
+                                    delta = yield self._calculate_state_delta(
+                                        room_id, current_state,
+                                    )
+                                state_delta_for_room[room_id] = delta
+
+                            # If we have the current_state then let's prefill
+                            # the cache with it.
+                            if current_state is not None:
+                                current_state_for_room[room_id] = current_state
 
                 yield self.runInteraction(
                     "persist_events",
@@ -381,10 +462,13 @@ class EventsStore(SQLBaseStore):
                     events_and_contexts=chunk,
                     backfilled=backfilled,
                     delete_existing=delete_existing,
-                    current_state_for_room=current_state_for_room,
+                    state_delta_for_room=state_delta_for_room,
                     new_forward_extremeties=new_forward_extremeties,
                 )
-                persist_event_counter.inc_by(len(chunk))
+                persist_event_counter.inc(len(chunk))
+                synapse.metrics.event_persisted_position.set(
+                    chunk[-1][0].internal_metadata.stream_ordering,
+                )
                 for event, context in chunk:
                     if context.app_service:
                         origin_type = "local"
@@ -396,14 +480,14 @@ class EventsStore(SQLBaseStore):
                         origin_type = "remote"
                         origin_entity = get_domain_from_id(event.sender)
 
-                    event_counter.inc(event.type, origin_type, origin_entity)
+                    event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-                for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+                for room_id, new_state in iteritems(current_state_for_room):
                     self.get_current_state_ids.prefill(
                         (room_id, ), new_state
                     )
 
-                for room_id, latest_event_ids in new_forward_extremeties.iteritems():
+                for room_id, latest_event_ids in iteritems(new_forward_extremeties):
                     self.get_latest_event_ids_in_room.prefill(
                         (room_id,), list(latest_event_ids)
                     )
@@ -450,183 +534,187 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
-        """Calculate the new state deltas for a room.
+    def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+                                    new_latest_event_ids):
+        """Calculate the current state dict after adding some new events to
+        a room
 
-        Assumes that we are only persisting events for one room at a time.
+        Args:
+            room_id (str):
+                room to which the events are being added. Used for logging etc
+
+            events_context (list[(EventBase, EventContext)]):
+                events and contexts which are being added to the room
+
+            old_latest_event_ids (iterable[str]):
+                the old forward extremities for the room.
+
+            new_latest_event_ids (iterable[str]):
+                the new forward extremities for the room.
 
         Returns:
-            3-tuple (to_delete, to_insert, new_state) where both are state dicts,
-            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
-            first be deleted from current_state_events, `to_insert` are entries
-            to insert. `new_state` is the full set of state.
-            May return None if there are no changes to be applied.
+            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
+            Returns a tuple of two state maps, the first being the full new current
+            state and the second being the delta to the existing current state.
+            If both are None then there has been no change.
+
+            If there has been a change then we only return the delta if it has
+            already been calculated. Conversely, if we do know the delta, then
+            the new current state is only returned if we've already calculated
+            it.
         """
-        # Now we need to work out the different state sets for
-        # each state extremities
-        state_sets = []
-        state_groups = set()
-        missing_event_ids = []
-        was_updated = False
+
+        if not new_latest_event_ids:
+            return
+
+        # map from state_group to ((type, key) -> event_id) state map
+        state_groups_map = {}
+
+        # Map from (prev state group, new state group) -> delta state dict
+        state_group_deltas = {}
+
+        for ev, ctx in events_context:
+            if ctx.state_group is None:
+                # I don't think this can happen, but let's double-check
+                raise Exception(
+                    "Context for new extremity event %s has no state "
+                    "group" % (ev.event_id, ),
+                )
+
+            if ctx.state_group in state_groups_map:
+                continue
+
+            # We're only interested in pulling out state that has already
+            # been cached in the context. We'll pull stuff out of the DB later
+            # if necessary.
+            current_state_ids = ctx.get_cached_current_state_ids()
+            if current_state_ids is not None:
+                state_groups_map[ctx.state_group] = current_state_ids
+
+            if ctx.prev_group:
+                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+        # We need to map the event_ids to their state groups. First, let's
+        # check if the event is one we're persisting, in which case we can
+        # pull the state group from its context.
+        # Otherwise we need to pull the state group from the database.
+
+        # Set of events we need to fetch groups for. (We know none of the old
+        # extremities are going to be in events_context).
+        missing_event_ids = set(old_latest_event_ids)
+
+        event_id_to_state_group = {}
         for event_id in new_latest_event_ids:
-            # First search in the list of new events we're adding,
-            # and then use the current state from that
+            # First search in the list of new events we're adding.
             for ev, ctx in events_context:
                 if event_id == ev.event_id:
-                    if ctx.current_state_ids is None:
-                        raise Exception("Unknown current state")
-
-                    # If we've already seen the state group don't bother adding
-                    # it to the state sets again
-                    if ctx.state_group not in state_groups:
-                        state_sets.append(ctx.current_state_ids)
-                        if ctx.delta_ids or hasattr(ev, "state_key"):
-                            was_updated = True
-                        if ctx.state_group:
-                            # Add this as a seen state group (if it has a state
-                            # group)
-                            state_groups.add(ctx.state_group)
+                    event_id_to_state_group[event_id] = ctx.state_group
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
                 # the state from the database
-                was_updated = True
-                missing_event_ids.append(event_id)
+                missing_event_ids.add(event_id)
 
         if missing_event_ids:
-            # Now pull out the state for any missing events from DB
+            # Now pull out the state groups for any missing events from DB
             event_to_groups = yield self._get_state_group_for_events(
                 missing_event_ids,
             )
+            event_id_to_state_group.update(event_to_groups)
 
-            groups = set(event_to_groups.itervalues()) - state_groups
+        # State groups of old_latest_event_ids
+        old_state_groups = set(
+            event_id_to_state_group[evid] for evid in old_latest_event_ids
+        )
 
-            if groups:
-                group_to_state = yield self._get_state_for_groups(groups)
-                state_sets.extend(group_to_state.itervalues())
+        # State groups of new_latest_event_ids
+        new_state_groups = set(
+            event_id_to_state_group[evid] for evid in new_latest_event_ids
+        )
 
-        if not new_latest_event_ids:
-            current_state = {}
-        elif was_updated:
-            if len(state_sets) == 1:
-                # If there is only one state set, then we know what the current
-                # state is.
-                current_state = state_sets[0]
-            else:
-                # We work out the current state by passing the state sets to the
-                # state resolution algorithm. It may ask for some events, including
-                # the events we have yet to persist, so we need a slightly more
-                # complicated event lookup function than simply looking the events
-                # up in the db.
-                events_map = {ev.event_id: ev for ev, _ in events_context}
-
-                @defer.inlineCallbacks
-                def get_events(ev_ids):
-                    # We get the events by first looking at the list of events we
-                    # are trying to persist, and then fetching the rest from the DB.
-                    db = []
-                    to_return = {}
-                    for ev_id in ev_ids:
-                        ev = events_map.get(ev_id, None)
-                        if ev:
-                            to_return[ev_id] = ev
-                        else:
-                            db.append(ev_id)
-
-                    if db:
-                        evs = yield self.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False,
-                        )
-                        to_return.update(evs)
-                    defer.returnValue(to_return)
-
-                current_state = yield resolve_events(
-                    state_sets,
-                    state_map_factory=get_events,
-                )
-        else:
-            return
+        # If the old and new groups are the same then we don't need to do
+        # anything.
+        if old_state_groups == new_state_groups:
+            defer.returnValue((None, None))
 
-        existing_state = yield self.get_current_state_ids(room_id)
+        if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+            # If we're going from one state group to another, let's check if
+            # we have a delta for that transition. If we do then we can just
+            # return that.
 
-        existing_events = set(existing_state.itervalues())
-        new_events = set(ev_id for ev_id in current_state.itervalues())
-        changed_events = existing_events ^ new_events
+            new_state_group = next(iter(new_state_groups))
+            old_state_group = next(iter(old_state_groups))
 
-        if not changed_events:
-            return
+            delta_ids = state_group_deltas.get(
+                (old_state_group, new_state_group,), None
+            )
+            if delta_ids is not None:
+                # We have a delta from the existing to new current state,
+                # so let's just return that. If we happen to already have
+                # the current state in memory then lets also return that,
+                # but it doesn't matter if we don't.
+                new_state = state_groups_map.get(new_state_group)
+                defer.returnValue((new_state, delta_ids))
+
+        # Now that we have calculated new_state_groups we need to get
+        # their state IDs so we can resolve to a single state set.
+        missing_state = new_state_groups - set(state_groups_map)
+        if missing_state:
+            group_to_state = yield self._get_state_for_groups(missing_state)
+            state_groups_map.update(group_to_state)
+
+        if len(new_state_groups) == 1:
+            # If there is only one state group, then we know what the current
+            # state is.
+            defer.returnValue((state_groups_map[new_state_groups.pop()], None))
+
+        # Ok, we need to defer to the state handler to resolve our state sets.
+
+        def get_events(ev_ids):
+            return self.get_events(
+                ev_ids, get_prev_content=False, check_redacted=False,
+            )
 
-        to_delete = {
-            key: ev_id for key, ev_id in existing_state.iteritems()
-            if ev_id in changed_events
+        state_groups = {
+            sg: state_groups_map[sg] for sg in new_state_groups
         }
-        events_to_insert = (new_events - existing_events)
-        to_insert = {
-            key: ev_id for key, ev_id in current_state.iteritems()
-            if ev_id in events_to_insert
-        }
-
-        defer.returnValue((to_delete, to_insert, current_state))
-
-    @defer.inlineCallbacks
-    def get_event(self, event_id, check_redacted=True,
-                  get_prev_content=False, allow_rejected=False,
-                  allow_none=False):
-        """Get an event from the database by event_id.
-
-        Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
-                include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
-                False throw an exception.
 
-        Returns:
-            Deferred : A FrozenEvent.
-        """
-        events = yield self._get_events(
-            [event_id],
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
+        events_map = {ev.event_id: ev for ev, _ in events_context}
+        logger.debug("calling resolve_state_groups from preserve_events")
+        res = yield self._state_resolution_handler.resolve_state_groups(
+            room_id, state_groups, events_map, get_events
         )
 
-        if not events and not allow_none:
-            raise SynapseError(404, "Could not find event %s" % (event_id,))
-
-        defer.returnValue(events[0] if events else None)
+        defer.returnValue((res.state, None))
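
The short-circuit order above (no change -> known delta -> single known group
-> full resolution) can be shown with plain dicts. A sketch under the same
assumptions (single-element group sets; names illustrative):

    state_group_deltas = {
        # (prev_group, new_group) -> delta state map
        (1, 2): {("m.room.topic", ""): "$topic"},
    }

    def new_state_or_delta(old_groups, new_groups, state_groups_map):
        if old_groups == new_groups:
            return None, None  # no change at all
        if len(old_groups) == 1 and len(new_groups) == 1:
            old_sg, new_sg = next(iter(old_groups)), next(iter(new_groups))
            delta = state_group_deltas.get((old_sg, new_sg))
            if delta is not None:
                # Delta known: return the full state too only if cached.
                return state_groups_map.get(new_sg), delta
        if len(new_groups) == 1:
            return state_groups_map[next(iter(new_groups))], None
        raise NotImplementedError("would defer to state resolution")

    assert new_state_or_delta({1}, {2}, {}) == (None, {("m.room.topic", ""): "$topic"})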
 
     @defer.inlineCallbacks
-    def get_events(self, event_ids, check_redacted=True,
-                   get_prev_content=False, allow_rejected=False):
-        """Get events from the database
+    def _calculate_state_delta(self, room_id, current_state):
+        """Calculate the new state deltas for a room.
 
-        Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
-                include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+        Assumes that we are only persisting events for one room at a time.
 
         Returns:
-            Deferred : Dict from event_id to event.
+            tuple[list, dict] (to_delete, to_insert): where to_delete are the
+            type/state_keys to remove from current_state_events and `to_insert`
+            are the updates to current_state_events.
         """
-        events = yield self._get_events(
-            event_ids,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
-        )
+        existing_state = yield self.get_current_state_ids(room_id)
+
+        to_delete = [
+            key for key in existing_state
+            if key not in current_state
+        ]
+
+        to_insert = {
+            key: ev_id for key, ev_id in iteritems(current_state)
+            if ev_id != existing_state.get(key)
+        }
 
-        defer.returnValue({e.event_id: e for e in events})
+        defer.returnValue((to_delete, to_insert))
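
The delta itself is a pure dict diff over (type, state_key) -> event_id maps:
a key vanishing entirely goes to to_delete, while a new or changed event_id
goes to to_insert. A worked sketch (event IDs illustrative):

    existing_state = {
        ("m.room.name", ""): "$old_name",
        ("m.room.member", "@a:hs"): "$join",
    }
    current_state = {
        ("m.room.member", "@a:hs"): "$join",   # unchanged: not inserted
        ("m.room.topic", ""): "$topic",        # new key: inserted
    }

    to_delete = [key for key in existing_state if key not in current_state]
    to_insert = {
        key: ev_id for key, ev_id in current_state.items()
        if ev_id != existing_state.get(key)
    }

    assert to_delete == [("m.room.name", "")]
    assert to_insert == {("m.room.topic", ""): "$topic"}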
 
     @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False, current_state_for_room={},
+                            delete_existing=False, state_delta_for_room={},
                             new_forward_extremeties={}):
         """Insert some number of room events into the necessary database tables.
 
@@ -642,19 +730,21 @@ class EventsStore(SQLBaseStore):
             delete_existing (bool): True to purge existing table rows for the
                 events from the database. This is useful when retrying due to
                 IntegrityError.
-            current_state_for_room (dict[str, (list[str], list[str])]):
+            state_delta_for_room (dict[str, (list, dict)]):
                 The current-state delta for each room. For each room, a tuple
-                (to_delete, to_insert), being a list of event ids to be removed
-                from the current state, and a list of event ids to be added to
+                (to_delete, to_insert), being a list of type/state keys to be
+                removed from the current state, and a state set to be added to
                 the current state.
             new_forward_extremeties (dict[str, list[str]]):
                 The new forward extremities for each room. For each room, a
                 list of the event ids which are the forward extremities.
 
         """
+        all_events_and_contexts = events_and_contexts
+
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
-        self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+        self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
 
         self._update_forward_extremities_txn(
             txn,
@@ -698,9 +788,8 @@ class EventsStore(SQLBaseStore):
             events_and_contexts=events_and_contexts,
         )
 
-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
+        # Insert into event_to_state_groups.
+        self._store_event_state_mappings_txn(txn, events_and_contexts)
 
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
@@ -715,15 +804,53 @@ class EventsStore(SQLBaseStore):
         self._update_metadata_tables_txn(
             txn,
             events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
             backfilled=backfilled,
         )
 
     def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
-        for room_id, current_state_tuple in state_delta_by_room.iteritems():
-                to_delete, to_insert, _ = current_state_tuple
+        for room_id, current_state_tuple in iteritems(state_delta_by_room):
+                to_delete, to_insert = current_state_tuple
+
+                # First we add entries to the current_state_delta_stream. We
+                # do this before updating the current_state_events table so
+                # that we can use it to calculate the `prev_event_id`. (This
+                # allows us to not have to pull out the existing state
+                # unnecessarily).
+                sql = """
+                    INSERT INTO current_state_delta_stream
+                    (stream_id, room_id, type, state_key, event_id, prev_event_id)
+                    SELECT ?, ?, ?, ?, ?, (
+                        SELECT event_id FROM current_state_events
+                        WHERE room_id = ? AND type = ? AND state_key = ?
+                    )
+                """
+                txn.executemany(sql, (
+                    (
+                        max_stream_order, room_id, etype, state_key, None,
+                        room_id, etype, state_key,
+                    )
+                    for etype, state_key in to_delete
+                    # We sanity check that we're deleting rather than updating
+                    if (etype, state_key) not in to_insert
+                ))
+                txn.executemany(sql, (
+                    (
+                        max_stream_order, room_id, etype, state_key, ev_id,
+                        room_id, etype, state_key,
+                    )
+                    for (etype, state_key), ev_id in iteritems(to_insert)
+                ))
+
+                # Now we actually update the current_state_events table
+
                 txn.executemany(
-                    "DELETE FROM current_state_events WHERE event_id = ?",
-                    [(ev_id,) for ev_id in to_delete.itervalues()],
+                    "DELETE FROM current_state_events"
+                    " WHERE room_id = ? AND type = ? AND state_key = ?",
+                    (
+                        (room_id, etype, state_key)
+                        for etype, state_key in itertools.chain(to_delete, to_insert)
+                    ),
                 )
 
                 self._simple_insert_many_txn(
@@ -736,30 +863,12 @@ class EventsStore(SQLBaseStore):
                             "type": key[0],
                             "state_key": key[1],
                         }
-                        for key, ev_id in to_insert.iteritems()
+                        for key, ev_id in iteritems(to_insert)
                     ],
                 )
 
-                state_deltas = {key: None for key in to_delete}
-                state_deltas.update(to_insert)
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="current_state_delta_stream",
-                    values=[
-                        {
-                            "stream_id": max_stream_order,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": ev_id,
-                            "prev_event_id": to_delete.get(key, None),
-                        }
-                        for key, ev_id in state_deltas.iteritems()
-                    ]
-                )
-
-                self._curr_state_delta_stream_cache.entity_has_changed(
+                txn.call_after(
+                    self._curr_state_delta_stream_cache.entity_has_changed,
                     room_id, max_stream_order,
                 )
 
@@ -771,19 +880,23 @@ class EventsStore(SQLBaseStore):
                 # and which we have added, then we invalidate the caches for all
                 # those users.
                 members_changed = set(
-                    state_key for ev_type, state_key in state_deltas
+                    state_key
+                    for ev_type, state_key in itertools.chain(to_delete, to_insert)
                     if ev_type == EventTypes.Member
                 )
 
                 for member in members_changed:
                     self._invalidate_cache_and_stream(
-                        txn, self.get_rooms_for_user, (member,)
+                        txn, self.get_rooms_for_user_with_stream_ordering, (member,)
                     )
 
                 for host in set(get_domain_from_id(u) for u in members_changed):
                     self._invalidate_cache_and_stream(
                         txn, self.is_host_joined, (room_id, host)
                     )
+                    self._invalidate_cache_and_stream(
+                        txn, self.was_host_joined, (room_id, host)
+                    )
 
                 self._invalidate_cache_and_stream(
                     txn, self.get_users_in_room, (room_id,)
@@ -795,7 +908,7 @@ class EventsStore(SQLBaseStore):
 
     def _update_forward_extremities_txn(self, txn, new_forward_extremities,
                                         max_stream_order):
-        for room_id, new_extrem in new_forward_extremities.iteritems():
+        for room_id, new_extrem in iteritems(new_forward_extremities):
             self._simple_delete_txn(
                 txn,
                 table="event_forward_extremities",
@@ -813,7 +926,7 @@ class EventsStore(SQLBaseStore):
                     "event_id": ev_id,
                     "room_id": room_id,
                 }
-                for room_id, new_extrem in new_forward_extremities.iteritems()
+                for room_id, new_extrem in iteritems(new_forward_extremities)
                 for ev_id in new_extrem
             ],
         )
@@ -830,7 +943,7 @@ class EventsStore(SQLBaseStore):
                     "event_id": event_id,
                     "stream_ordering": max_stream_order,
                 }
-                for room_id, new_extrem in new_forward_extremities.iteritems()
+                for room_id, new_extrem in iteritems(new_forward_extremities)
                 for event_id in new_extrem
             ]
         )
@@ -858,7 +971,7 @@ class EventsStore(SQLBaseStore):
                         new_events_and_contexts[event.event_id] = (event, context)
             else:
                 new_events_and_contexts[event.event_id] = (event, context)
-        return new_events_and_contexts.values()
+        return list(new_events_and_contexts.values())
 
     def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
         """Update min_depth for each room
@@ -884,7 +997,7 @@ class EventsStore(SQLBaseStore):
                     event.depth, depth_updates.get(event.room_id, event.depth)
                 )
 
-        for room_id, depth in depth_updates.iteritems():
+        for room_id, depth in iteritems(depth_updates):
             self._update_min_depth_for_room_txn(txn, room_id, depth)
 
     def _update_outliers_txn(self, txn, events_and_contexts):
@@ -932,10 +1045,9 @@ class EventsStore(SQLBaseStore):
                 # an outlier in the database. We now have some state at that
                 # so we need to update the state_groups table with that state.
 
-                # insert into the state_group, state_groups_state and
-                # event_to_state_groups tables.
+                # insert into event_to_state_groups.
                 try:
-                    self._store_mult_state_groups_txn(txn, ((event, context),))
+                    self._store_event_state_mappings_txn(txn, ((event, context),))
                 except Exception:
                     logger.exception("")
                     raise
@@ -1001,7 +1113,6 @@ class EventsStore(SQLBaseStore):
                 "event_edge_hashes",
                 "event_edges",
                 "event_forward_extremities",
-                "event_push_actions",
                 "event_reference_hashes",
                 "event_search",
                 "event_signatures",
@@ -1021,6 +1132,14 @@ class EventsStore(SQLBaseStore):
                 [(ev.event_id,) for ev, _ in events_and_contexts]
             )
 
+        for table in (
+            "event_push_actions",
+        ):
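+            # event_push_actions is indexed on (room_id, event_id) rather
+            # than on event_id alone (see the matching comment in the purge
+            # path below), hence this separate delete using both columns.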
+            txn.executemany(
+                "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
+                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts]
+            )
+
     def _store_event_txn(self, txn, events_and_contexts):
         """Insert new events into the event and event_json tables
 
@@ -1110,27 +1229,33 @@ class EventsStore(SQLBaseStore):
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
-    def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+    def _update_metadata_tables_txn(self, txn, events_and_contexts,
+                                    all_events_and_contexts, backfilled):
         """Update all the miscellaneous tables for new events
 
         Args:
             txn (twisted.enterprise.adbapi.Connection): db connection
             events_and_contexts (list[(EventBase, EventContext)]): events
                 we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc., that wouldn't appear in
+                events_and_contexts.
             backfilled (bool): True if the events were backfilled
         """
 
+        # Insert all the push actions into the event_push_actions table.
+        self._set_push_actions_for_event_and_users_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+        )
+
         if not events_and_contexts:
             # nothing to do here
             return
 
         for event, context in events_and_contexts:
-            # Insert all the push actions into the event_push_actions table.
-            if context.push_actions:
-                self._set_push_actions_for_event_and_users_txn(
-                    txn, event, context.push_actions
-                )
-
             if event.type == EventTypes.Redaction and event.redacts is not None:
                 # Remove the entries in the event_push_actions table for the
                 # redacted event.
@@ -1263,7 +1388,7 @@ class EventsStore(SQLBaseStore):
                 " WHERE e.event_id IN (%s)"
             ) % (",".join(["?"] * len(ev_map)),)
 
-            txn.execute(sql, ev_map.keys())
+            txn.execute(sql, list(ev_map))
             rows = self.cursor_to_dict(txn)
             for row in rows:
                 event = ev_map[row["event_id"]]
@@ -1302,13 +1427,49 @@ class EventsStore(SQLBaseStore):
 
         defer.returnValue(set(r["event_id"] for r in rows))
 
-    def have_events(self, event_ids):
+    @defer.inlineCallbacks
+    def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.
 
+        Args:
+            event_ids (iterable[str]):
+
         Returns:
-            dict: Has an entry for each event id we already have seen. Maps to
-            the rejected reason string if we rejected the event, else maps to
-            None.
+            Deferred[set[str]]: The events we have already seen.
+        """
+        results = set()
+
+        def have_seen_events_txn(txn, chunk):
+            sql = (
+                "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+                % (",".join("?" * len(chunk)), )
+            )
+            txn.execute(sql, chunk)
+            for (event_id, ) in txn:
+                results.add(event_id)
+
+        # break the input up into chunks of 100
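+        # (iter(callable, sentinel) calls the lambda repeatedly until it
+        # returns the sentinel [] - i.e. until islice has exhausted the
+        # input iterator.)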
+        input_iterator = iter(event_ids)
+        for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+                          []):
+            yield self.runInteraction(
+                "have_seen_events",
+                have_seen_events_txn,
+                chunk,
+            )
+        defer.returnValue(results)
+
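+    # Illustrative usage of have_seen_events (not part of this change):
+    #     seen = yield self.have_seen_events(event_ids)
+    #     missing = [e for e in event_ids if e not in seen]
+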
+    def get_seen_events_with_rejections(self, event_ids):
+        """Given a list of event ids, check if we rejected them.
+
+        Args:
+            event_ids (list[str])
+
+        Returns:
+            Deferred[dict[str, str|None]]:
+                Has an entry for each event id we already have seen. Maps to
+                the rejected reason string if we rejected the event, else maps
+                to None.
         """
         if not event_ids:
             return defer.succeed({})
@@ -1330,295 +1491,7 @@ class EventsStore(SQLBaseStore):
 
             return res
 
-        return self.runInteraction(
-            "have_events", f,
-        )
-
-    @defer.inlineCallbacks
-    def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False, allow_rejected=False):
-        if not event_ids:
-            defer.returnValue([])
-
-        event_id_list = event_ids
-        event_ids = set(event_ids)
-
-        event_entry_map = self._get_events_from_cache(
-            event_ids,
-            allow_rejected=allow_rejected,
-        )
-
-        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
-        if missing_events_ids:
-            missing_events = yield self._enqueue_events(
-                missing_events_ids,
-                check_redacted=check_redacted,
-                allow_rejected=allow_rejected,
-            )
-
-            event_entry_map.update(missing_events)
-
-        events = []
-        for event_id in event_id_list:
-            entry = event_entry_map.get(event_id, None)
-            if not entry:
-                continue
-
-            if allow_rejected or not entry.event.rejected_reason:
-                if check_redacted and entry.redacted_event:
-                    event = entry.redacted_event
-                else:
-                    event = entry.event
-
-                events.append(event)
-
-                if get_prev_content:
-                    if "replaces_state" in event.unsigned:
-                        prev = yield self.get_event(
-                            event.unsigned["replaces_state"],
-                            get_prev_content=False,
-                            allow_none=True,
-                        )
-                        if prev:
-                            event.unsigned = dict(event.unsigned)
-                            event.unsigned["prev_content"] = prev.content
-                            event.unsigned["prev_sender"] = prev.sender
-
-        defer.returnValue(events)
-
-    def _invalidate_get_event_cache(self, event_id):
-            self._get_event_cache.invalidate((event_id,))
-
-    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
-        """Fetch events from the caches
-
-        Args:
-            events (list(str)): list of event_ids to fetch
-            allow_rejected (bool): Whether to teturn events that were rejected
-            update_metrics (bool): Whether to update the cache hit ratio metrics
-
-        Returns:
-            dict of event_id -> _EventCacheEntry for each event_id in cache. If
-            allow_rejected is `False` then there will still be an entry but it
-            will be `None`
-        """
-        event_map = {}
-
-        for event_id in events:
-            ret = self._get_event_cache.get(
-                (event_id,), None,
-                update_metrics=update_metrics,
-            )
-            if not ret:
-                continue
-
-            if allow_rejected or not ret.event.rejected_reason:
-                event_map[event_id] = ret
-            else:
-                event_map[event_id] = None
-
-        return event_map
-
-    def _do_fetch(self, conn):
-        """Takes a database connection and waits for requests for events from
-        the _event_fetch_list queue.
-        """
-        event_list = []
-        i = 0
-        while True:
-            try:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
-                            self._event_fetch_ongoing -= 1
-                            return
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
-
-                event_id_lists = zip(*event_list)[0]
-                event_ids = [
-                    item for sublist in event_id_lists for item in sublist
-                ]
-
-                rows = self._new_transaction(
-                    conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
-                )
-
-                row_dict = {
-                    r["event_id"]: r
-                    for r in rows
-                }
-
-                # We only want to resolve deferreds from the main thread
-                def fire(lst, res):
-                    for ids, d in lst:
-                        if not d.called:
-                            try:
-                                with PreserveLoggingContext():
-                                    d.callback([
-                                        res[i]
-                                        for i in ids
-                                        if i in res
-                                    ])
-                            except:
-                                logger.exception("Failed to callback")
-                with PreserveLoggingContext():
-                    reactor.callFromThread(fire, event_list, row_dict)
-            except Exception as e:
-                logger.exception("do_fetch")
-
-                # We only want to resolve deferreds from the main thread
-                def fire(evs):
-                    for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(e)
-
-                if event_list:
-                    with PreserveLoggingContext():
-                        reactor.callFromThread(fire, event_list)
-
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
-        """Fetches events from the database using the _event_fetch_list. This
-        allows batch and bulk fetching of events - it allows us to fetch events
-        without having to create a new transaction for each request for events.
-        """
-        if not events:
-            defer.returnValue({})
-
-        events_d = defer.Deferred()
-        with self._event_fetch_lock:
-            self._event_fetch_list.append(
-                (events, events_d)
-            )
-
-            self._event_fetch_lock.notify()
-
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            with PreserveLoggingContext():
-                self.runWithConnection(
-                    self._do_fetch
-                )
-
-        logger.debug("Loading %d events", len(events))
-        with PreserveLoggingContext():
-            rows = yield events_d
-        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
-        if not allow_rejected:
-            rows[:] = [r for r in rows if not r["rejects"]]
-
-        res = yield preserve_context_over_deferred(defer.gatherResults(
-            [
-                preserve_fn(self._get_event_from_row)(
-                    row["internal_metadata"], row["json"], row["redacts"],
-                    rejected_reason=row["rejects"],
-                )
-                for row in rows
-            ],
-            consumeErrors=True
-        ))
-
-        defer.returnValue({
-            e.event.event_id: e
-            for e in res if e
-        })
-
-    def _fetch_event_rows(self, txn, events):
-        rows = []
-        N = 200
-        for i in range(1 + len(events) / N):
-            evs = events[i * N:(i + 1) * N]
-            if not evs:
-                break
-
-            sql = (
-                "SELECT "
-                " e.event_id as event_id, "
-                " e.internal_metadata,"
-                " e.json,"
-                " r.redacts as redacts,"
-                " rej.event_id as rejects "
-                " FROM event_json as e"
-                " LEFT JOIN rejections as rej USING (event_id)"
-                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
-                " WHERE e.event_id IN (%s)"
-            ) % (",".join(["?"] * len(evs)),)
-
-            txn.execute(sql, evs)
-            rows.extend(self.cursor_to_dict(txn))
-
-        return rows
-
-    @defer.inlineCallbacks
-    def _get_event_from_row(self, internal_metadata, js, redacted,
-                            rejected_reason=None):
-        with Measure(self._clock, "_get_event_from_row"):
-            d = json.loads(js)
-            internal_metadata = json.loads(internal_metadata)
-
-            if rejected_reason:
-                rejected_reason = yield self._simple_select_one_onecol(
-                    table="rejections",
-                    keyvalues={"event_id": rejected_reason},
-                    retcol="reason",
-                    desc="_get_event_from_row_rejected_reason",
-                )
-
-            original_ev = FrozenEvent(
-                d,
-                internal_metadata_dict=internal_metadata,
-                rejected_reason=rejected_reason,
-            )
-
-            redacted_event = None
-            if redacted:
-                redacted_event = prune_event(original_ev)
-
-                redaction_id = yield self._simple_select_one_onecol(
-                    table="redactions",
-                    keyvalues={"redacts": redacted_event.event_id},
-                    retcol="event_id",
-                    desc="_get_event_from_row_redactions",
-                )
-
-                redacted_event.unsigned["redacted_by"] = redaction_id
-                # Get the redaction event.
-
-                because = yield self.get_event(
-                    redaction_id,
-                    check_redacted=False,
-                    allow_none=True,
-                )
-
-                if because:
-                    # It's fine to do add the event directly, since get_pdu_json
-                    # will serialise this field correctly
-                    redacted_event.unsigned["redacted_because"] = because
-
-            cache_entry = _EventCacheEntry(
-                event=original_ev,
-                redacted_event=redacted_event,
-            )
-
-            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
-        defer.returnValue(cache_entry)
+        return self.runInteraction("get_rejection_reasons", f)
 
     @defer.inlineCallbacks
     def count_daily_messages(self):
@@ -1778,7 +1651,7 @@ class EventsStore(SQLBaseStore):
 
             chunks = [
                 event_ids[i:i + 100]
-                for i in xrange(0, len(event_ids), 100)
+                for i in range(0, len(event_ids), 100)
             ]
             for chunk in chunks:
                 ev_rows = self._simple_select_many_txn(
@@ -2005,15 +1878,32 @@ class EventsStore(SQLBaseStore):
             )
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
-    def delete_old_state(self, room_id, topological_ordering):
+    def purge_history(
+        self, room_id, token, delete_local_events,
+    ):
+        """Deletes room history before a certain point
+
+        Args:
+            room_id (str):
+
+            token (str): A topological token to delete events before
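+                (topological tokens are typically of the form
+                "t<topological>-<stream>", e.g. "t123-456")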
+
+            delete_local_events (bool):
+                if True, we will delete local events as well as remote ones
+                (instead of just marking them as outliers and deleting their
+                state groups).
+        """
+
         return self.runInteraction(
-            "delete_old_state",
-            self._delete_old_state_txn, room_id, topological_ordering
+            "purge_history",
+            self._purge_history_txn, room_id, token,
+            delete_local_events,
         )
 
-    def _delete_old_state_txn(self, txn, room_id, topological_ordering):
-        """Deletes old room state
-        """
+    def _purge_history_txn(
+        self, txn, room_id, token_str, delete_local_events,
+    ):
+        token = RoomStreamToken.parse(token_str)
 
         # Tables that should be pruned:
         #     event_auth
@@ -2035,6 +1925,37 @@ class EventsStore(SQLBaseStore):
         #     state_groups
         #     state_groups_state
 
+        # we will build a temporary table listing the events so that we don't
+        # have to keep shovelling the list back and forth across the
+        # connection. Annoyingly the python sqlite driver commits the
+        # transaction on CREATE, so let's do this first.
+        #
+        # furthermore, we might already have the table from a previous (failed)
+        # purge attempt, so let's drop the table first.
+
+        txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+        txn.execute(
+            "CREATE TEMPORARY TABLE events_to_purge ("
+            "    event_id TEXT NOT NULL,"
+            "    should_delete BOOLEAN NOT NULL"
+            ")"
+        )
+
+        # create an index on should_delete because later we'll be looking for
+        # the should_delete / shouldn't_delete subsets
+        txn.execute(
+            "CREATE INDEX events_to_purge_should_delete"
+            " ON events_to_purge(should_delete)",
+        )
+
+        # We do joins against events_to_purge for e.g. calculating state
+        # groups to purge, etc., so let's make an index.
+        txn.execute(
+            "CREATE INDEX events_to_purge_id"
+            " ON events_to_purge(event_id)",
+        )
+
         # First ensure that we're not about to delete all the forward extremities
         txn.execute(
             "SELECT e.event_id, e.depth FROM events as e "
@@ -2047,7 +1968,7 @@ class EventsStore(SQLBaseStore):
         rows = txn.fetchall()
         max_depth = max(row[1] for row in rows)
 
-        if max_depth <= topological_ordering:
+        if max_depth <= token.topological:
             # We need to ensure we don't delete all the events from the database
             # otherwise we wouldn't be able to send any events (due to not
             # having any backwards extremities)
@@ -2055,42 +1976,48 @@ class EventsStore(SQLBaseStore):
                 400, "topological_ordering is greater than forward extremeties"
             )
 
-        logger.debug("[purge] looking for events to delete")
+        logger.info("[purge] looking for events to delete")
+
+        should_delete_expr = "state_key IS NULL"
+        should_delete_params = ()
+        if not delete_local_events:
+            should_delete_expr += " AND event_id NOT LIKE ?"
+            should_delete_params += ("%:" + self.hs.hostname, )
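+            # (Local event IDs end in ":<our server name>", so the NOT LIKE
+            # clause keeps our own events out of the should_delete set.)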
+
+        should_delete_params += (room_id, token.topological)
 
         txn.execute(
-            "SELECT event_id, state_key FROM events"
-            " LEFT JOIN state_events USING (room_id, event_id)"
-            " WHERE room_id = ? AND topological_ordering < ?",
-            (room_id, topological_ordering,)
+            "INSERT INTO events_to_purge"
+            " SELECT event_id, %s"
+            " FROM events AS e LEFT JOIN state_events USING (event_id)"
+            " WHERE e.room_id = ? AND topological_ordering < ?" % (
+                should_delete_expr,
+            ),
+            should_delete_params,
+        )
+        txn.execute(
+            "SELECT event_id, should_delete FROM events_to_purge"
         )
         event_rows = txn.fetchall()
-
-        to_delete = [
-            (event_id,) for event_id, state_key in event_rows
-            if state_key is None and not self.hs.is_mine_id(event_id)
-        ]
         logger.info(
-            "[purge] found %i events before cutoff, of which %i are remote"
-            " non-state events to delete", len(event_rows), len(to_delete))
-
-        for event_id, state_key in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+            "[purge] found %i events before cutoff, of which %i can be deleted",
+            len(event_rows), sum(1 for e in event_rows if e[1]),
+        )
 
-        logger.debug("[purge] Finding new backward extremities")
+        logger.info("[purge] Finding new backward extremities")
 
         # We calculate the new entries for the backward extremities by finding
-        # all events that point to events that are to be purged
+        # events to be purged that are pointed to by events we're not going to
+        # purge.
         txn.execute(
-            "SELECT DISTINCT e.event_id FROM events as e"
-            " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
-            " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
-            " WHERE e.room_id = ? AND e.topological_ordering < ?"
-            " AND e2.topological_ordering >= ?",
-            (room_id, topological_ordering, topological_ordering)
+            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+            " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+            " WHERE ep2.event_id IS NULL",
         )
         new_backwards_extrems = txn.fetchall()
 
-        logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
+        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
 
         txn.execute(
             "DELETE FROM event_backward_extremities WHERE room_id = ?",
@@ -2106,34 +2033,39 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
-        logger.debug("[purge] finding redundant state groups")
+        logger.info("[purge] finding redundant state groups")
 
         # Get all state groups that are only referenced by events that are
         # to be deleted.
-        txn.execute(
-            "SELECT state_group FROM event_to_state_groups"
-            " INNER JOIN events USING (event_id)"
-            " WHERE state_group IN ("
-            "   SELECT DISTINCT state_group FROM events"
-            "   INNER JOIN event_to_state_groups USING (event_id)"
-            "   WHERE room_id = ? AND topological_ordering < ?"
-            " )"
-            " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
-            (room_id, topological_ordering, topological_ordering)
-        )
+        # This works by first getting state groups that we may want to delete,
+        # joining against event_to_state_groups to get events that use that
+        # state group, then left joining against events_to_purge again. Any
+        # state group where the left join produces *no NULLs* is referenced
+        # only by events that are going to be purged.
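+        # (E.g. if state group 42 is referenced only by two events, both of
+        # which are in events_to_purge, the left join yields no NULL
+        # ep.event_id rows for it, and group 42 is selected for deletion.)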
+        txn.execute("""
+            SELECT state_group FROM
+            (
+                SELECT DISTINCT state_group FROM events_to_purge
+                INNER JOIN event_to_state_groups USING (event_id)
+            ) AS sp
+            INNER JOIN event_to_state_groups USING (state_group)
+            LEFT JOIN events_to_purge AS ep USING (event_id)
+            GROUP BY state_group
+            HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0
+        """)
 
         state_rows = txn.fetchall()
-        logger.debug("[purge] found %i redundant state groups", len(state_rows))
+        logger.info("[purge] found %i redundant state groups", len(state_rows))
 
         # make a set of the redundant state groups, so that we can look them up
         # efficiently
         state_groups_to_delete = set([sg for sg, in state_rows])
 
         # Now we get all the state groups that rely on these state groups
-        logger.debug("[purge] finding state groups which depend on redundant"
-                     " state groups")
+        logger.info("[purge] finding state groups which depend on redundant"
+                    " state groups")
         remaining_state_groups = []
-        for i in xrange(0, len(state_rows), 100):
+        for i in range(0, len(state_rows), 100):
             chunk = [sg for sg, in state_rows[i:i + 100]]
             # look for state groups whose prev_state_group is one we are about
             # to delete
@@ -2156,7 +2088,7 @@ class EventsStore(SQLBaseStore):
         # Now we turn the state groups that reference to-be-deleted state
         # groups to non delta versions.
         for sg in remaining_state_groups:
-            logger.debug("[purge] de-delta-ing remaining state group %s", sg)
+            logger.info("[purge] de-delta-ing remaining state group %s", sg)
             curr_state = self._get_state_groups_from_groups_txn(
                 txn, [sg], types=None
             )
@@ -2189,11 +2121,11 @@ class EventsStore(SQLBaseStore):
                         "state_key": key[1],
                         "event_id": state_id,
                     }
-                    for key, state_id in curr_state.iteritems()
+                    for key, state_id in iteritems(curr_state)
                 ],
             )
 
-        logger.debug("[purge] removing redundant state groups")
+        logger.info("[purge] removing redundant state groups")
         txn.executemany(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             state_rows
@@ -2203,18 +2135,15 @@ class EventsStore(SQLBaseStore):
             state_rows
         )
 
-        # Delete all non-state
-        logger.debug("[purge] removing events from event_to_state_groups")
-        txn.executemany(
-            "DELETE FROM event_to_state_groups WHERE event_id = ?",
-            [(event_id,) for event_id, _ in event_rows]
-        )
-
-        logger.debug("[purge] updating room_depth")
+        logger.info("[purge] removing events from event_to_state_groups")
         txn.execute(
-            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (topological_ordering, room_id,)
+            "DELETE FROM event_to_state_groups "
+            "WHERE event_id IN (SELECT event_id from events_to_purge)"
         )
+        for event_id, _ in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (
+                event_id,
+            ))
 
         # Delete all remote non-state events
         for table in (
@@ -2226,28 +2155,75 @@ class EventsStore(SQLBaseStore):
             "event_edge_hashes",
             "event_edges",
             "event_forward_extremities",
-            "event_push_actions",
             "event_reference_hashes",
             "event_search",
             "event_signatures",
             "rejections",
         ):
-            logger.debug("[purge] removing remote non-state events from %s", table)
+            logger.info("[purge] removing events from %s", table)
 
-            txn.executemany(
-                "DELETE FROM %s WHERE event_id = ?" % (table,),
-                to_delete
+            txn.execute(
+                "DELETE FROM %s WHERE event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+            )
+
+        # event_push_actions lacks an index on event_id, and has one on
+        # (room_id, event_id) instead.
+        for table in (
+            "event_push_actions",
+        ):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+                (room_id, )
             )
 
         # Mark all state and own events as outliers
-        logger.debug("[purge] marking remaining events as outliers")
-        txn.executemany(
+        logger.info("[purge] marking remaining events as outliers")
+        txn.execute(
             "UPDATE events SET outlier = ?"
-            " WHERE event_id = ?",
-            [
-                (True, event_id,) for event_id, state_key in event_rows
-                if state_key is not None or self.hs.is_mine_id(event_id)
-            ]
+            " WHERE event_id IN ("
+            "    SELECT event_id FROM events_to_purge "
+            "    WHERE NOT should_delete"
+            ")",
+            (True,),
+        )
+
+        # synapse tries to take out an exclusive lock on room_depth whenever it
+        # persists events (because of the upsert), and once we run this update, we
+        # will block that for the rest of our transaction.
+        #
+        # So, let's stick it at the end so that we don't block event
+        # persistence.
+        #
+        # We do this by calculating the minimum depth of the backwards
+        # extremities. However, the events in event_backward_extremities
+        # are ones we don't have yet, so we need to look at the events that
+        # point to them via the event_edges table.
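+        # (For example: if the events pointing at the remaining backward
+        # extremities have depths 10 and 12, min_depth becomes 10.)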
+        txn.execute("""
+            SELECT COALESCE(MIN(depth), 0)
+            FROM event_backward_extremities AS eb
+            INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+            INNER JOIN events AS e ON e.event_id = eg.event_id
+            WHERE eb.room_id = ?
+        """, (room_id,))
+        min_depth, = txn.fetchone()
+
+        logger.info("[purge] updating room_depth to %d", min_depth)
+
+        txn.execute(
+            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+            (min_depth, room_id,)
+        )
+
+        # finally, drop the temp table. this will commit the txn in sqlite,
+        # so make sure to keep this actually last.
+        txn.execute(
+            "DROP TABLE events_to_purge"
         )
 
         logger.info("[purge] done")
@@ -2260,7 +2236,7 @@ class EventsStore(SQLBaseStore):
         to_2, so_2 = yield self._get_event_ordering(event_id2)
         defer.returnValue((to_1, so_1) > (to_2, so_2))
 
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks(max_entries=5000)
     def _get_event_ordering(self, event_id):
         res = yield self._simple_select_one(
             table="events",
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 0000000000..f28239a808
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,436 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from collections import namedtuple
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+# these are only included to make the type annotations work
+from synapse.events import EventBase  # noqa: F401
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.events.utils import prune_event
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import (
+    LoggingContext,
+    PreserveLoggingContext,
+    make_deferred_yieldable,
+    run_in_background,
+)
+from synapse.util.metrics import Measure
+
+from ._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `_enqueue_events` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thin air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
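+# (i.e. an idle fetcher thread re-polls the queue up to EVENT_QUEUE_ITERATIONS
+# times, waiting EVENT_QUEUE_TIMEOUT_S between polls, before it exits.)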
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+    def get_received_ts(self, event_id):
+        """Get received_ts (when it was persisted) for the event.
+
+        Raises an exception for unknown events.
+
+        Args:
+            event_id (str)
+
+        Returns:
+            Deferred[int|None]: Timestamp in milliseconds, or None for events
+            that were persisted before received_ts was implemented.
+        """
+        return self._simple_select_one_onecol(
+            table="events",
+            keyvalues={
+                "event_id": event_id,
+            },
+            retcol="received_ts",
+            desc="get_received_ts",
+        )
+
+    @defer.inlineCallbacks
+    def get_event(self, event_id, check_redacted=True,
+                  get_prev_content=False, allow_rejected=False,
+                  allow_none=False):
+        """Get an event from the database by event_id.
+
+        Args:
+            event_id (str): The event_id of the event to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+            allow_none (bool): If True, return None if no event found, if
+                False throw an exception.
+
+        Returns:
+            Deferred : A FrozenEvent.
+        """
+        events = yield self._get_events(
+            [event_id],
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        if not events and not allow_none:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        defer.returnValue(events[0] if events else None)
+
+    @defer.inlineCallbacks
+    def get_events(self, event_ids, check_redacted=True,
+                   get_prev_content=False, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred : Dict from event_id to event.
+        """
+        events = yield self._get_events(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        defer.returnValue({e.event_id: e for e in events})
+
+    @defer.inlineCallbacks
+    def _get_events(self, event_ids, check_redacted=True,
+                    get_prev_content=False, allow_rejected=False):
+        if not event_ids:
+            defer.returnValue([])
+
+        event_id_list = event_ids
+        event_ids = set(event_ids)
+
+        event_entry_map = self._get_events_from_cache(
+            event_ids,
+            allow_rejected=allow_rejected,
+        )
+
+        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+        if missing_events_ids:
+            log_ctx = LoggingContext.current_context()
+            log_ctx.record_event_fetch(len(missing_events_ids))
+
+            missing_events = yield self._enqueue_events(
+                missing_events_ids,
+                check_redacted=check_redacted,
+                allow_rejected=allow_rejected,
+            )
+
+            event_entry_map.update(missing_events)
+
+        events = []
+        for event_id in event_id_list:
+            entry = event_entry_map.get(event_id, None)
+            if not entry:
+                continue
+
+            if allow_rejected or not entry.event.rejected_reason:
+                if check_redacted and entry.redacted_event:
+                    event = entry.redacted_event
+                else:
+                    event = entry.event
+
+                events.append(event)
+
+                if get_prev_content:
+                    if "replaces_state" in event.unsigned:
+                        prev = yield self.get_event(
+                            event.unsigned["replaces_state"],
+                            get_prev_content=False,
+                            allow_none=True,
+                        )
+                        if prev:
+                            event.unsigned = dict(event.unsigned)
+                            event.unsigned["prev_content"] = prev.content
+                            event.unsigned["prev_sender"] = prev.sender
+
+        defer.returnValue(events)
+
+    def _invalidate_get_event_cache(self, event_id):
+        self._get_event_cache.invalidate((event_id,))
+
+    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+        """Fetch events from the caches
+
+        Args:
+            events (list(str)): list of event_ids to fetch
+            allow_rejected (bool): Whether to return events that were rejected
+            update_metrics (bool): Whether to update the cache hit ratio metrics
+
+        Returns:
+            dict of event_id -> _EventCacheEntry for each event_id in cache. If
+            allow_rejected is `False` then there will still be an entry but it
+            will be `None`
+        """
+        event_map = {}
+
+        for event_id in events:
+            ret = self._get_event_cache.get(
+                (event_id,), None,
+                update_metrics=update_metrics,
+            )
+            if not ret:
+                continue
+
+            if allow_rejected or not ret.event.rejected_reason:
+                event_map[event_id] = ret
+            else:
+                event_map[event_id] = None
+
+        return event_map
+
+    def _do_fetch(self, conn):
+        """Takes a database connection and waits for requests for events from
+        the _event_fetch_list queue.
+        """
+        i = 0
+        while True:
+            with self._event_fetch_lock:
+                event_list = self._event_fetch_list
+                self._event_fetch_list = []
+
+                if not event_list:
+                    single_threaded = self.database_engine.single_threaded
+                    if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+                        self._event_fetch_ongoing -= 1
+                        return
+                    else:
+                        self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                        i += 1
+                        continue
+                i = 0
+
+            self._fetch_event_list(conn, event_list)
+
+    def _fetch_event_list(self, conn, event_list):
+        """Handle a load of requests from the _event_fetch_list queue
+
+        Args:
+            conn (twisted.enterprise.adbapi.Connection): database connection
+
+            event_list (list[Tuple[list[str], Deferred]]):
+                The fetch requests. Each entry consists of a list of event
+                ids to be fetched, and a deferred to be completed once the
+                events have been fetched.
+
+        """
+        with Measure(self._clock, "_fetch_event_list"):
+            try:
+                event_id_lists = list(zip(*event_list))[0]
+                event_ids = [
+                    item for sublist in event_id_lists for item in sublist
+                ]
+
+                rows = self._new_transaction(
+                    conn, "do_fetch", [], [],
+                    self._fetch_event_rows, event_ids,
+                )
+
+                row_dict = {
+                    r["event_id"]: r
+                    for r in rows
+                }
+
+                # We only want to resolve deferreds from the main thread
+                def fire(lst, res):
+                    for ids, d in lst:
+                        if not d.called:
+                            try:
+                                with PreserveLoggingContext():
+                                    d.callback([
+                                        res[i]
+                                        for i in ids
+                                        if i in res
+                                    ])
+                            except Exception:
+                                logger.exception("Failed to callback")
+                with PreserveLoggingContext():
+                    self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
+            except Exception as e:
+                logger.exception("do_fetch")
+
+                # We only want to resolve deferreds from the main thread
+                def fire(evs):
+                    for _, d in evs:
+                        if not d.called:
+                            with PreserveLoggingContext():
+                                d.errback(e)
+
+                with PreserveLoggingContext():
+                    self.hs.get_reactor().callFromThread(fire, event_list)
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+        """Fetches events from the database using the _event_fetch_list. This
+        allows batch and bulk fetching of events - it allows us to fetch events
+        without having to create a new transaction for each request for events.
+        """
+        if not events:
+            defer.returnValue({})
+
+        events_d = defer.Deferred()
+        with self._event_fetch_lock:
+            self._event_fetch_list.append(
+                (events, events_d)
+            )
+
+            self._event_fetch_lock.notify()
+
+            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+                self._event_fetch_ongoing += 1
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            run_as_background_process(
+                "fetch_events",
+                self.runWithConnection,
+                self._do_fetch,
+            )
+
+        logger.debug("Loading %d events", len(events))
+        with PreserveLoggingContext():
+            rows = yield events_d
+        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+
+        if not allow_rejected:
+            rows[:] = [r for r in rows if not r["rejects"]]
+
+        res = yield make_deferred_yieldable(defer.gatherResults(
+            [
+                run_in_background(
+                    self._get_event_from_row,
+                    row["internal_metadata"], row["json"], row["redacts"],
+                    rejected_reason=row["rejects"],
+                )
+                for row in rows
+            ],
+            consumeErrors=True
+        ))
+
+        defer.returnValue({
+            e.event.event_id: e
+            for e in res if e
+        })
+
+    def _fetch_event_rows(self, txn, events):
+        rows = []
+        N = 200
+        for i in range(1 + len(events) // N):
+            evs = events[i * N:(i + 1) * N]
+            if not evs:
+                break
+
+            sql = (
+                "SELECT "
+                " e.event_id as event_id, "
+                " e.internal_metadata,"
+                " e.json,"
+                " r.redacts as redacts,"
+                " rej.event_id as rejects "
+                " FROM event_json as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE e.event_id IN (%s)"
+            ) % (",".join(["?"] * len(evs)),)
+
+            txn.execute(sql, evs)
+            rows.extend(self.cursor_to_dict(txn))
+
+        return rows
+
+    @defer.inlineCallbacks
+    def _get_event_from_row(self, internal_metadata, js, redacted,
+                            rejected_reason=None):
+        with Measure(self._clock, "_get_event_from_row"):
+            d = json.loads(js)
+            internal_metadata = json.loads(internal_metadata)
+
+            if rejected_reason:
+                rejected_reason = yield self._simple_select_one_onecol(
+                    table="rejections",
+                    keyvalues={"event_id": rejected_reason},
+                    retcol="reason",
+                    desc="_get_event_from_row_rejected_reason",
+                )
+
+            original_ev = FrozenEvent(
+                d,
+                internal_metadata_dict=internal_metadata,
+                rejected_reason=rejected_reason,
+            )
+
+            redacted_event = None
+            if redacted:
+                redacted_event = prune_event(original_ev)
+
+                redaction_id = yield self._simple_select_one_onecol(
+                    table="redactions",
+                    keyvalues={"redacts": redacted_event.event_id},
+                    retcol="event_id",
+                    desc="_get_event_from_row_redactions",
+                )
+
+                redacted_event.unsigned["redacted_by"] = redaction_id
+                # Get the redaction event.
+
+                because = yield self.get_event(
+                    redaction_id,
+                    check_redacted=False,
+                    allow_none=True,
+                )
+
+                if because:
+                    # It's fine to add the event directly, since get_pdu_json
+                    # will serialise this field correctly
+                    redacted_event.unsigned["redacted_because"] = because
+
+            cache_entry = _EventCacheEntry(
+                event=original_ev,
+                redacted_event=redacted_event,
+            )
+
+            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+        defer.returnValue(cache_entry)
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 78b1e30945..2d5896c5b4 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from canonicaljson import encode_canonical_json, json
+
 from twisted.internet import defer
 
-from ._base import SQLBaseStore
-from synapse.api.errors import SynapseError, Codes
+from synapse.api.errors import Codes, SynapseError
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
-from canonicaljson import encode_canonical_json
-import simplejson as json
+from ._base import SQLBaseStore
 
 
 class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
             desc="get_user_filter",
         )
 
-        defer.returnValue(json.loads(str(def_json).decode("utf-8")))
+        defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
 
     def add_user_filter(self, user_localpart, user_filter):
         def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
new file mode 100644
index 0000000000..592d1b4c2a
--- /dev/null
+++ b/synapse/storage/group_server.py
@@ -0,0 +1,1252 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+
+from ._base import SQLBaseStore
+
+# The category ID for the "default" category. We don't store as null in the
+# database to avoid the fun of null != null
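+# (NULL = NULL does not evaluate to true in SQL, which would break lookups
+# keyed on category_id; the empty string sidesteps that.)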
+_DEFAULT_CATEGORY_ID = ""
+_DEFAULT_ROLE_ID = ""
+
+
+class GroupServerStore(SQLBaseStore):
+    def set_group_join_policy(self, group_id, join_policy):
+        """Set the join policy of a group.
+
+        join_policy can be one of:
+         * "invite"
+         * "open"
+        """
+        return self._simple_update_one(
+            table="groups",
+            keyvalues={
+                "group_id": group_id,
+            },
+            updatevalues={
+                "join_policy": join_policy,
+            },
+            desc="set_group_join_policy",
+        )
+
+    def get_group(self, group_id):
+        return self._simple_select_one(
+            table="groups",
+            keyvalues={
+                "group_id": group_id,
+            },
+            retcols=(
+                "name", "short_description", "long_description",
+                "avatar_url", "is_public", "join_policy",
+            ),
+            allow_none=True,
+            desc="get_group",
+        )
+
+    def get_users_in_group(self, group_id, include_private=False):
+        # TODO: Pagination
+
+        keyvalues = {
+            "group_id": group_id,
+        }
+        if not include_private:
+            keyvalues["is_public"] = True
+
+        return self._simple_select_list(
+            table="group_users",
+            keyvalues=keyvalues,
+            retcols=("user_id", "is_public", "is_admin",),
+            desc="get_users_in_group",
+        )
+
+    def get_invited_users_in_group(self, group_id):
+        # TODO: Pagination
+
+        return self._simple_select_onecol(
+            table="group_invites",
+            keyvalues={
+                "group_id": group_id,
+            },
+            retcol="user_id",
+            desc="get_invited_users_in_group",
+        )
+
+    def get_rooms_in_group(self, group_id, include_private=False):
+        # TODO: Pagination
+
+        keyvalues = {
+            "group_id": group_id,
+        }
+        if not include_private:
+            keyvalues["is_public"] = True
+
+        return self._simple_select_list(
+            table="group_rooms",
+            keyvalues=keyvalues,
+            retcols=("room_id", "is_public",),
+            desc="get_rooms_in_group",
+        )
+
+    def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+        """Get the rooms and categories that should be included in a summary request
+
+        Returns a tuple of ([room dicts], {category_id: category dict})
+        """
+        def _get_rooms_for_summary_txn(txn):
+            keyvalues = {
+                "group_id": group_id,
+            }
+            if not include_private:
+                keyvalues["is_public"] = True
+
+            sql = """
+                SELECT room_id, is_public, category_id, room_order
+                FROM group_summary_rooms
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            rooms = [
+                {
+                    "room_id": row[0],
+                    "is_public": row[1],
+                    "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None,
+                    "order": row[3],
+                }
+                for row in txn
+            ]
+
+            sql = """
+                SELECT category_id, is_public, profile, cat_order
+                FROM group_summary_room_categories
+                INNER JOIN group_room_categories USING (group_id, category_id)
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            categories = {
+                row[0]: {
+                    "is_public": row[1],
+                    "profile": json.loads(row[2]),
+                    "order": row[3],
+                }
+                for row in txn
+            }
+
+            return rooms, categories
+        return self.runInteraction(
+            "get_rooms_for_summary", _get_rooms_for_summary_txn
+        )
+
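+    # Result shape of the above, for illustration (values hypothetical):
+    #
+    #     rooms = [{"room_id": "!abc:hs", "is_public": True,
+    #               "category_id": None, "order": 1}]
+    #     categories = {"cat1": {"is_public": True, "profile": {}, "order": 1}}
+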
+    def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
+        return self.runInteraction(
+            "add_room_to_summary", self._add_room_to_summary_txn,
+            group_id, room_id, category_id, order, is_public,
+        )
+
+    def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order,
+                                 is_public):
+        """Add (or update) room's entry in summary.
+
+        Args:
+            group_id (str)
+            room_id (str)
+            category_id (str): If not None then adds the category to the end of
+                the summary if it's not already there. [Optional]
+            order (int): If not None inserts the room at that position, e.g.
+                an order of 1 will put the room first. Otherwise, the room gets
+                added to the end.
+            is_public (bool): Whether the entry should be shown to non-members;
+                if None, an existing entry keeps its value and a new entry
+                defaults to True. [Optional]
+        """
+        room_in_group = self._simple_select_one_onecol_txn(
+            txn,
+            table="group_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "room_id": room_id,
+            },
+            retcol="room_id",
+            allow_none=True,
+        )
+        if not room_in_group:
+            raise SynapseError(400, "room not in group")
+
+        if category_id is None:
+            category_id = _DEFAULT_CATEGORY_ID
+        else:
+            cat_exists = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_room_categories",
+                keyvalues={
+                    "group_id": group_id,
+                    "category_id": category_id,
+                },
+                retcol="group_id",
+                allow_none=True,
+            )
+            if not cat_exists:
+                raise SynapseError(400, "Category doesn't exist")
+
+            # TODO: Check category is part of summary already
+            cat_exists = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_summary_room_categories",
+                keyvalues={
+                    "group_id": group_id,
+                    "category_id": category_id,
+                },
+                retcol="group_id",
+                allow_none=True,
+            )
+            if not cat_exists:
+                # If not, add it with an order larger than all others
+                txn.execute("""
+                    INSERT INTO group_summary_room_categories
+                    (group_id, category_id, cat_order)
+                    SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
+                    FROM group_summary_room_categories
+                    WHERE group_id = ? AND category_id = ?
+                """, (group_id, category_id, group_id, category_id))
+
+        existing = self._simple_select_one_txn(
+            txn,
+            table="group_summary_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "room_id": room_id,
+                "category_id": category_id,
+            },
+            retcols=("room_order", "is_public",),
+            allow_none=True,
+        )
+
+        if order is not None:
+            # Shuffle other room orders that come after the given order
+            sql = """
+                UPDATE group_summary_rooms SET room_order = room_order + 1
+                WHERE group_id = ? AND category_id = ? AND room_order >= ?
+            """
+            txn.execute(sql, (group_id, category_id, order,))
+        elif not existing:
+            sql = """
+                SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
+                WHERE group_id = ? AND category_id = ?
+            """
+            txn.execute(sql, (group_id, category_id,))
+            order, = txn.fetchone()
+
+        if existing:
+            to_update = {}
+            if order is not None:
+                to_update["room_order"] = order
+            if is_public is not None:
+                to_update["is_public"] = is_public
+            self._simple_update_txn(
+                txn,
+                table="group_summary_rooms",
+                keyvalues={
+                    "group_id": group_id,
+                    "category_id": category_id,
+                    "room_id": room_id,
+                },
+                values=to_update,
+            )
+        else:
+            if is_public is None:
+                is_public = True
+
+            self._simple_insert_txn(
+                txn,
+                table="group_summary_rooms",
+                values={
+                    "group_id": group_id,
+                    "category_id": category_id,
+                    "room_id": room_id,
+                    "room_order": order,
+                    "is_public": is_public,
+                },
+            )
+
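+    # Ordering sketch, for illustration only: with existing rooms at orders 1
+    # and 2, a hypothetical call such as
+    #
+    #     yield store.add_room_to_summary(group_id, "!new:hs", None, 1, True)
+    #
+    # shifts the existing entries to orders 2 and 3 and puts the new room first.
+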
+    def remove_room_from_summary(self, group_id, room_id, category_id):
+        if category_id is None:
+            category_id = _DEFAULT_CATEGORY_ID
+
+        return self._simple_delete(
+            table="group_summary_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "category_id": category_id,
+                "room_id": room_id,
+            },
+            desc="remove_room_from_summary",
+        )
+
+    @defer.inlineCallbacks
+    def get_group_categories(self, group_id):
+        rows = yield self._simple_select_list(
+            table="group_room_categories",
+            keyvalues={
+                "group_id": group_id,
+            },
+            retcols=("category_id", "is_public", "profile"),
+            desc="get_group_categories",
+        )
+
+        defer.returnValue({
+            row["category_id"]: {
+                "is_public": row["is_public"],
+                "profile": json.loads(row["profile"]),
+            }
+            for row in rows
+        })
+
+    @defer.inlineCallbacks
+    def get_group_category(self, group_id, category_id):
+        category = yield self._simple_select_one(
+            table="group_room_categories",
+            keyvalues={
+                "group_id": group_id,
+                "category_id": category_id,
+            },
+            retcols=("is_public", "profile"),
+            desc="get_group_category",
+        )
+
+        category["profile"] = json.loads(category["profile"])
+
+        defer.returnValue(category)
+
+    def upsert_group_category(self, group_id, category_id, profile, is_public):
+        """Add/update room category for group
+        """
+        insertion_values = {}
+        update_values = {"category_id": category_id}  # This cannot be empty
+
+        if profile is None:
+            insertion_values["profile"] = "{}"
+        else:
+            update_values["profile"] = json.dumps(profile)
+
+        if is_public is None:
+            insertion_values["is_public"] = True
+        else:
+            update_values["is_public"] = is_public
+
+        return self._simple_upsert(
+            table="group_room_categories",
+            keyvalues={
+                "group_id": group_id,
+                "category_id": category_id,
+            },
+            values=update_values,
+            insertion_values=insertion_values,
+            desc="upsert_group_category",
+        )
+
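+    # For illustration only: a hypothetical caller creating a public category
+    # with a profile might run:
+    #
+    #     yield store.upsert_group_category(
+    #         group_id, "projects", {"name": "Projects"}, is_public=True,
+    #     )
+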
+    def remove_group_category(self, group_id, category_id):
+        return self._simple_delete(
+            table="group_room_categories",
+            keyvalues={
+                "group_id": group_id,
+                "category_id": category_id,
+            },
+            desc="remove_group_category",
+        )
+
+    @defer.inlineCallbacks
+    def get_group_roles(self, group_id):
+        rows = yield self._simple_select_list(
+            table="group_roles",
+            keyvalues={
+                "group_id": group_id,
+            },
+            retcols=("role_id", "is_public", "profile"),
+            desc="get_group_roles",
+        )
+
+        defer.returnValue({
+            row["role_id"]: {
+                "is_public": row["is_public"],
+                "profile": json.loads(row["profile"]),
+            }
+            for row in rows
+        })
+
+    @defer.inlineCallbacks
+    def get_group_role(self, group_id, role_id):
+        role = yield self._simple_select_one(
+            table="group_roles",
+            keyvalues={
+                "group_id": group_id,
+                "role_id": role_id,
+            },
+            retcols=("is_public", "profile"),
+            desc="get_group_role",
+        )
+
+        role["profile"] = json.loads(role["profile"])
+
+        defer.returnValue(role)
+
+    def upsert_group_role(self, group_id, role_id, profile, is_public):
+        """Add/remove user role
+        """
+        insertion_values = {}
+        update_values = {"role_id": role_id}  # This cannot be empty
+
+        if profile is None:
+            insertion_values["profile"] = "{}"
+        else:
+            update_values["profile"] = json.dumps(profile)
+
+        if is_public is None:
+            insertion_values["is_public"] = True
+        else:
+            update_values["is_public"] = is_public
+
+        return self._simple_upsert(
+            table="group_roles",
+            keyvalues={
+                "group_id": group_id,
+                "role_id": role_id,
+            },
+            values=update_values,
+            insertion_values=insertion_values,
+            desc="upsert_group_role",
+        )
+
+    def remove_group_role(self, group_id, role_id):
+        return self._simple_delete(
+            table="group_roles",
+            keyvalues={
+                "group_id": group_id,
+                "role_id": role_id,
+            },
+            desc="remove_group_role",
+        )
+
+    def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
+        return self.runInteraction(
+            "add_user_to_summary", self._add_user_to_summary_txn,
+            group_id, user_id, role_id, order, is_public,
+        )
+
+    def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order,
+                                 is_public):
+        """Add (or update) user's entry in summary.
+
+        Args:
+            group_id (str)
+            user_id (str)
+            role_id (str): If not None then adds the role to the end of
+                the summary if it's not already there. [Optional]
+            order (int): If not None inserts the user at that position, e.g.
+                an order of 1 will put the user first. Otherwise, the user gets
+                added to the end.
+            is_public (bool): Whether the entry should be shown to non-members;
+                if None, an existing entry keeps its value and a new entry
+                defaults to True. [Optional]
+        """
+        user_in_group = self._simple_select_one_onecol_txn(
+            txn,
+            table="group_users",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcol="user_id",
+            allow_none=True,
+        )
+        if not user_in_group:
+            raise SynapseError(400, "user not in group")
+
+        if role_id is None:
+            role_id = _DEFAULT_ROLE_ID
+        else:
+            role_exists = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_roles",
+                keyvalues={
+                    "group_id": group_id,
+                    "role_id": role_id,
+                },
+                retcol="group_id",
+                allow_none=True,
+            )
+            if not role_exists:
+                raise SynapseError(400, "Role doesn't exist")
+
+            # TODO: Check role is part of the summary already
+            role_exists = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_summary_roles",
+                keyvalues={
+                    "group_id": group_id,
+                    "role_id": role_id,
+                },
+                retcol="group_id",
+                allow_none=True,
+            )
+            if not role_exists:
+                # If not, add it with an order larger than all others
+                txn.execute("""
+                    INSERT INTO group_summary_roles
+                    (group_id, role_id, role_order)
+                    SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
+                    FROM group_summary_roles
+                    WHERE group_id = ? AND role_id = ?
+                """, (group_id, role_id, group_id, role_id))
+
+        existing = self._simple_select_one_txn(
+            txn,
+            table="group_summary_users",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+                "role_id": role_id,
+            },
+            retcols=("user_order", "is_public",),
+            allow_none=True,
+        )
+
+        if order is not None:
+            # Shuffle other users' orders that come after the given order
+            sql = """
+                UPDATE group_summary_users SET user_order = user_order + 1
+                WHERE group_id = ? AND role_id = ? AND user_order >= ?
+            """
+            txn.execute(sql, (group_id, role_id, order,))
+        elif not existing:
+            sql = """
+                SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
+                WHERE group_id = ? AND role_id = ?
+            """
+            txn.execute(sql, (group_id, role_id,))
+            order, = txn.fetchone()
+
+        if existing:
+            to_update = {}
+            if order is not None:
+                to_update["user_order"] = order
+            if is_public is not None:
+                to_update["is_public"] = is_public
+            self._simple_update_txn(
+                txn,
+                table="group_summary_users",
+                keyvalues={
+                    "group_id": group_id,
+                    "role_id": role_id,
+                    "user_id": user_id,
+                },
+                values=to_update,
+            )
+        else:
+            if is_public is None:
+                is_public = True
+
+            self._simple_insert_txn(
+                txn,
+                table="group_summary_users",
+                values={
+                    "group_id": group_id,
+                    "role_id": role_id,
+                    "user_id": user_id,
+                    "user_order": order,
+                    "is_public": is_public,
+                },
+            )
+
+    def remove_user_from_summary(self, group_id, user_id, role_id):
+        if role_id is None:
+            role_id = _DEFAULT_ROLE_ID
+
+        return self._simple_delete(
+            table="group_summary_users",
+            keyvalues={
+                "group_id": group_id,
+                "role_id": role_id,
+                "user_id": user_id,
+            },
+            desc="remove_user_from_summary",
+        )
+
+    def get_users_for_summary_by_role(self, group_id, include_private=False):
+        """Get the users and roles that should be included in a summary request
+
+        Returns a tuple of ([user dicts], {role_id: role dict})
+        """
+        def _get_users_for_summary_txn(txn):
+            keyvalues = {
+                "group_id": group_id,
+            }
+            if not include_private:
+                keyvalues["is_public"] = True
+
+            sql = """
+                SELECT user_id, is_public, role_id, user_order
+                FROM group_summary_users
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            users = [
+                {
+                    "user_id": row[0],
+                    "is_public": row[1],
+                    "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
+                    "order": row[3],
+                }
+                for row in txn
+            ]
+
+            sql = """
+                SELECT role_id, is_public, profile, role_order
+                FROM group_summary_roles
+                INNER JOIN group_roles USING (group_id, role_id)
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            roles = {
+                row[0]: {
+                    "is_public": row[1],
+                    "profile": json.loads(row[2]),
+                    "order": row[3],
+                }
+                for row in txn
+            }
+
+            return users, roles
+        return self.runInteraction(
+            "get_users_for_summary_by_role", _get_users_for_summary_txn
+        )
+
+    def is_user_in_group(self, user_id, group_id):
+        return self._simple_select_one_onecol(
+            table="group_users",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcol="user_id",
+            allow_none=True,
+            desc="is_user_in_group",
+        ).addCallback(lambda r: bool(r))
+
+    def is_user_admin_in_group(self, group_id, user_id):
+        return self._simple_select_one_onecol(
+            table="group_users",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcol="is_admin",
+            allow_none=True,
+            desc="is_user_admin_in_group",
+        )
+
+    def add_group_invite(self, group_id, user_id):
+        """Record that the group server has invited a user
+        """
+        return self._simple_insert(
+            table="group_invites",
+            values={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            desc="add_group_invite",
+        )
+
+    def is_user_invited_to_local_group(self, group_id, user_id):
+        """Has the group server invited a user?
+        """
+        return self._simple_select_one_onecol(
+            table="group_invites",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcol="user_id",
+            desc="is_user_invited_to_local_group",
+            allow_none=True,
+        )
+
+    def get_users_membership_info_in_group(self, group_id, user_id):
+        """Get a dict describing the membership of a user in a group.
+
+        Example if joined:
+
+            {
+                "membership": "join",
+                "is_public": True,
+                "is_privileged": False,
+            }
+
+        Returns an empty dict if the user is neither joined nor invited.
+        """
+        def _get_users_membership_in_group_txn(txn):
+            row = self._simple_select_one_txn(
+                txn,
+                table="group_users",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+                retcols=("is_admin", "is_public"),
+                allow_none=True,
+            )
+
+            if row:
+                return {
+                    "membership": "join",
+                    "is_public": row["is_public"],
+                    "is_privileged": row["is_admin"],
+                }
+
+            row = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_invites",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+                retcol="user_id",
+                allow_none=True,
+            )
+
+            if row:
+                return {
+                    "membership": "invite",
+                }
+
+            return {}
+
+        return self.runInteraction(
+            "get_users_membership_info_in_group", _get_users_membership_in_group_txn,
+        )
+
+    def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True,
+                          local_attestation=None, remote_attestation=None):
+        """Add a user to the group server.
+
+        Args:
+            group_id (str)
+            user_id (str)
+            is_admin (bool)
+            is_public (bool)
+            local_attestation (dict): The attestation the GS created to give
+                to the remote server. Optional if the user and group are on the
+                same server
+            remote_attestation (dict): The attestation given to GS by remote
+                server. Optional if the user and group are on the same server
+        """
+        def _add_user_to_group_txn(txn):
+            self._simple_insert_txn(
+                txn,
+                table="group_users",
+                values={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "is_admin": is_admin,
+                    "is_public": is_public,
+                },
+            )
+
+            self._simple_delete_txn(
+                txn,
+                table="group_invites",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+
+            if local_attestation:
+                self._simple_insert_txn(
+                    txn,
+                    table="group_attestations_renewals",
+                    values={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                        "valid_until_ms": local_attestation["valid_until_ms"],
+                    },
+                )
+            if remote_attestation:
+                self._simple_insert_txn(
+                    txn,
+                    table="group_attestations_remote",
+                    values={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                        "valid_until_ms": remote_attestation["valid_until_ms"],
+                        "attestation_json": json.dumps(remote_attestation),
+                    },
+                )
+
+        return self.runInteraction(
+            "add_user_to_group", _add_user_to_group_txn
+        )
+
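+    # For illustration only: when the user and group share a server neither
+    # attestation is needed, so a hypothetical caller could simply run:
+    #
+    #     yield store.add_user_to_group(group_id, user_id, is_admin=False)
+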
+    def remove_user_from_group(self, group_id, user_id):
+        def _remove_user_from_group_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="group_users",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="group_invites",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="group_attestations_renewals",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="group_attestations_remote",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="group_summary_users",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+        return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn)
+
+    def add_room_to_group(self, group_id, room_id, is_public):
+        return self._simple_insert(
+            table="group_rooms",
+            values={
+                "group_id": group_id,
+                "room_id": room_id,
+                "is_public": is_public,
+            },
+            desc="add_room_to_group",
+        )
+
+    def update_room_in_group_visibility(self, group_id, room_id, is_public):
+        return self._simple_update(
+            table="group_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "room_id": room_id,
+            },
+            updatevalues={
+                "is_public": is_public,
+            },
+            desc="update_room_in_group_visibility",
+        )
+
+    def remove_room_from_group(self, group_id, room_id):
+        def _remove_room_from_group_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="group_rooms",
+                keyvalues={
+                    "group_id": group_id,
+                    "room_id": room_id,
+                },
+            )
+
+            self._simple_delete_txn(
+                txn,
+                table="group_summary_rooms",
+                keyvalues={
+                    "group_id": group_id,
+                    "room_id": room_id,
+                },
+            )
+        return self.runInteraction(
+            "remove_room_from_group", _remove_room_from_group_txn,
+        )
+
+    def get_publicised_groups_for_user(self, user_id):
+        """Get all groups a user is publicising
+        """
+        return self._simple_select_onecol(
+            table="local_group_membership",
+            keyvalues={
+                "user_id": user_id,
+                "membership": "join",
+                "is_publicised": True,
+            },
+            retcol="group_id",
+            desc="get_publicised_groups_for_user",
+        )
+
+    def update_group_publicity(self, group_id, user_id, publicise):
+        """Update whether the user is publicising their membership of the group
+        """
+        return self._simple_update_one(
+            table="local_group_membership",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            updatevalues={
+                "is_publicised": publicise,
+            },
+            desc="update_group_publicity"
+        )
+
+    @defer.inlineCallbacks
+    def register_user_group_membership(self, group_id, user_id, membership,
+                                       is_admin=False, content=None,
+                                       local_attestation=None,
+                                       remote_attestation=None,
+                                       is_publicised=False,
+                                       ):
+        """Registers that a local user is a member of a (local or remote) group.
+
+        Args:
+            group_id (str)
+            user_id (str)
+            membership (str)
+            is_admin (bool)
+            content (dict): Content of the membership, e.g. includes the inviter
+                if the user has been invited.
+            local_attestation (dict): If remote group then store the fact that we
+                have given out an attestation, else None.
+            remote_attestation (dict): If remote group then store the remote
+                attestation from the group, else None.
+            is_publicised (bool): Whether the user is publicising their
+                membership of the group.
+        """
+        content = content or {}
+        def _register_user_group_membership_txn(txn, next_id):
+            # TODO: Upsert?
+            self._simple_delete_txn(
+                txn,
+                table="local_group_membership",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_insert_txn(
+                txn,
+                table="local_group_membership",
+                values={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "is_admin": is_admin,
+                    "membership": membership,
+                    "is_publicised": is_publicised,
+                    "content": json.dumps(content),
+                },
+            )
+
+            self._simple_insert_txn(
+                txn,
+                table="local_group_updates",
+                values={
+                    "stream_id": next_id,
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "type": "membership",
+                    "content": json.dumps({"membership": membership, "content": content}),
+                }
+            )
+            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+
+            # TODO: Insert profile to ensure it comes down stream if it's a join.
+
+            if membership == "join":
+                if local_attestation:
+                    self._simple_insert_txn(
+                        txn,
+                        table="group_attestations_renewals",
+                        values={
+                            "group_id": group_id,
+                            "user_id": user_id,
+                            "valid_until_ms": local_attestation["valid_until_ms"],
+                        }
+                    )
+                if remote_attestation:
+                    self._simple_insert_txn(
+                        txn,
+                        table="group_attestations_remote",
+                        values={
+                            "group_id": group_id,
+                            "user_id": user_id,
+                            "valid_until_ms": remote_attestation["valid_until_ms"],
+                            "attestation_json": json.dumps(remote_attestation),
+                        }
+                    )
+            else:
+                self._simple_delete_txn(
+                    txn,
+                    table="group_attestations_renewals",
+                    keyvalues={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                    },
+                )
+                self._simple_delete_txn(
+                    txn,
+                    table="group_attestations_remote",
+                    keyvalues={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                    },
+                )
+
+            return next_id
+
+        with self._group_updates_id_gen.get_next() as next_id:
+            res = yield self.runInteraction(
+                "register_user_group_membership",
+                _register_user_group_membership_txn, next_id,
+            )
+        defer.returnValue(res)
+
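+    # For illustration only: recording a remote invite might look like this
+    # (arguments hypothetical):
+    #
+    #     yield store.register_user_group_membership(
+    #         group_id, user_id, "invite", content={"inviter": inviter_id},
+    #     )
+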
+    @defer.inlineCallbacks
+    def create_group(self, group_id, user_id, name, avatar_url, short_description,
+                     long_description,):
+        yield self._simple_insert(
+            table="groups",
+            values={
+                "group_id": group_id,
+                "name": name,
+                "avatar_url": avatar_url,
+                "short_description": short_description,
+                "long_description": long_description,
+                "is_public": True,
+            },
+            desc="create_group",
+        )
+
+    @defer.inlineCallbacks
+    def update_group_profile(self, group_id, profile,):
+        yield self._simple_update_one(
+            table="groups",
+            keyvalues={
+                "group_id": group_id,
+            },
+            updatevalues=profile,
+            desc="update_group_profile",
+        )
+
+    def get_attestations_need_renewals(self, valid_until_ms):
+        """Get all attestations that need to be renewed until givent time
+        """
+        def _get_attestations_need_renewals_txn(txn):
+            sql = """
+                SELECT group_id, user_id FROM group_attestations_renewals
+                WHERE valid_until_ms <= ?
+            """
+            txn.execute(sql, (valid_until_ms,))
+            return self.cursor_to_dict(txn)
+        return self.runInteraction(
+            "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+        )
+
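+    # For illustration only: a hypothetical renewal sweep over attestations
+    # expiring within the next DELAY_MS milliseconds:
+    #
+    #     rows = yield store.get_attestations_need_renewals(now_ms + DELAY_MS)
+    #     for row in rows:
+    #         renew_attestation(row["group_id"], row["user_id"])
+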
+    def update_attestation_renewal(self, group_id, user_id, attestation):
+        """Update an attestation that we have renewed
+        """
+        return self._simple_update_one(
+            table="group_attestations_renewals",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            updatevalues={
+                "valid_until_ms": attestation["valid_until_ms"],
+            },
+            desc="update_attestation_renewal",
+        )
+
+    def update_remote_attestion(self, group_id, user_id, attestation):
+        """Update an attestation that a remote has renewed
+        """
+        return self._simple_update_one(
+            table="group_attestations_remote",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            updatevalues={
+                "valid_until_ms": attestation["valid_until_ms"],
+                "attestation_json": json.dumps(attestation)
+            },
+            desc="update_remote_attestion",
+        )
+
+    def remove_attestation_renewal(self, group_id, user_id):
+        """Remove an attestation that we thought we should renew, but actually
+        shouldn't. Ideally this would never get called, as we should never
+        incorrectly try to do attestations for local users on local groups.
+
+        Args:
+            group_id (str)
+            user_id (str)
+        """
+        return self._simple_delete(
+            table="group_attestations_renewals",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            desc="remove_attestation_renewal",
+        )
+
+    @defer.inlineCallbacks
+    def get_remote_attestation(self, group_id, user_id):
+        """Get the attestation that proves the remote agrees that the user is
+        in the group.
+        """
+        row = yield self._simple_select_one(
+            table="group_attestations_remote",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcols=("valid_until_ms", "attestation_json"),
+            desc="get_remote_attestation",
+            allow_none=True,
+        )
+
+        now = int(self._clock.time_msec())
+        if row and now < row["valid_until_ms"]:
+            defer.returnValue(json.loads(row["attestation_json"]))
+
+        defer.returnValue(None)
+
+    def get_joined_groups(self, user_id):
+        return self._simple_select_onecol(
+            table="local_group_membership",
+            keyvalues={
+                "user_id": user_id,
+                "membership": "join",
+            },
+            retcol="group_id",
+            desc="get_joined_groups",
+        )
+
+    def get_all_groups_for_user(self, user_id, now_token):
+        def _get_all_groups_for_user_txn(txn):
+            sql = """
+                SELECT group_id, type, membership, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND membership != 'leave'
+                    AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, now_token,))
+            return [
+                {
+                    "group_id": row[0],
+                    "type": row[1],
+                    "membership": row[2],
+                    "content": json.loads(row[3]),
+                }
+                for row in txn
+            ]
+        return self.runInteraction(
+            "get_all_groups_for_user", _get_all_groups_for_user_txn,
+        )
+
+    def get_groups_changes_for_user(self, user_id, from_token, to_token):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_entity_changed(
+            user_id, from_token,
+        )
+        if not has_changed:
+            return []
+
+        def _get_groups_changes_for_user_txn(txn):
+            sql = """
+                SELECT group_id, membership, type, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, from_token, to_token,))
+            return [{
+                "group_id": group_id,
+                "membership": membership,
+                "type": gtype,
+                "content": json.loads(content_json),
+            } for group_id, membership, gtype, content_json in txn]
+        return self.runInteraction(
+            "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+        )
+
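+    # For illustration only: the tokens bound a half-open range (from_token
+    # exclusive, to_token inclusive), so a hypothetical sync caller pages with:
+    #
+    #     changes = yield store.get_groups_changes_for_user(
+    #         user_id, from_token, to_token,
+    #     )
+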
+    def get_all_groups_changes(self, from_token, to_token, limit):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+            from_token,
+        )
+        if not has_changed:
+            return []
+
+        def _get_all_groups_changes_txn(txn):
+            sql = """
+                SELECT stream_id, group_id, user_id, type, content
+                FROM local_group_updates
+                WHERE ? < stream_id AND stream_id <= ?
+                LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, limit,))
+            return [(
+                stream_id,
+                group_id,
+                user_id,
+                gtype,
+                json.loads(content_json),
+            ) for stream_id, group_id, user_id, gtype, content_json in txn]
+        return self.runInteraction(
+            "get_all_groups_changes", _get_all_groups_changes_txn,
+        )
+
+    def get_group_stream_token(self):
+        return self._group_updates_id_gen.get_current_token()
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 3b5e0a4fb9..f547977600 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,19 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+import hashlib
+import logging
 
-from twisted.internet import defer
+import six
 
-import OpenSSL
 from signedjson.key import decode_verify_key_bytes
-import hashlib
 
-import logging
+import OpenSSL
+from twisted.internet import defer
+
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
+# py2 sqlite has buffer hardcoded as its only binary type, so we must use it,
+# despite it being deprecated and removed in favor of memoryview.
+if six.PY2:
+    db_binary_type = buffer
+else:
+    db_binary_type = memoryview
+
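+# For illustration: db_binary_type(b"\x00\x01") yields a buffer on py2 and a
+# memoryview on py3, either of which the database driver stores as a BLOB.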
 
 class KeyStore(SQLBaseStore):
     """Persistence for signature verification keys and tls X.509 certificates
@@ -72,7 +82,7 @@ class KeyStore(SQLBaseStore):
             values={
                 "from_server": from_server,
                 "ts_added_ms": time_now_ms,
-                "tls_certificate": buffer(tls_certificate_bytes),
+                "tls_certificate": db_binary_type(tls_certificate_bytes),
             },
             desc="store_server_certificate",
         )
@@ -92,7 +102,7 @@ class KeyStore(SQLBaseStore):
 
         if verify_key_bytes:
             defer.returnValue(decode_verify_key_bytes(
-                key_id, str(verify_key_bytes)
+                key_id, bytes(verify_key_bytes)
             ))
 
     @defer.inlineCallbacks
@@ -113,30 +123,37 @@ class KeyStore(SQLBaseStore):
                 keys[key_id] = key
         defer.returnValue(keys)
 
-    @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
                                 verify_key):
         """Stores a NACL verification key for the given server.
         Args:
             server_name (str): The name of the server.
-            key_id (str): The version of the key for the server.
             from_server (str): Where the verification key was looked up
-            ts_now_ms (int): The time now in milliseconds
-            verification_key (VerifyKey): The NACL verify key.
+            time_now_ms (int): The time now in milliseconds
+            verify_key (nacl.signing.VerifyKey): The NACL verify key.
         """
-        yield self._simple_upsert(
-            table="server_signature_keys",
-            keyvalues={
-                "server_name": server_name,
-                "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
-            },
-            values={
-                "from_server": from_server,
-                "ts_added_ms": time_now_ms,
-                "verify_key": buffer(verify_key.encode()),
-            },
-            desc="store_server_verify_key",
-        )
+        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+        def _txn(txn):
+            self._simple_upsert_txn(
+                txn,
+                table="server_signature_keys",
+                keyvalues={
+                    "server_name": server_name,
+                    "key_id": key_id,
+                },
+                values={
+                    "from_server": from_server,
+                    "ts_added_ms": time_now_ms,
+                    "verify_key": db_binary_type(verify_key.encode()),
+                },
+            )
+            txn.call_after(
+                self._get_server_verify_key.invalidate,
+                (server_name, key_id)
+            )
+
+        return self.runInteraction("store_server_verify_key", _txn)
 
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
@@ -165,7 +182,7 @@ class KeyStore(SQLBaseStore):
                 "from_server": from_server,
                 "ts_added_ms": ts_now_ms,
                 "ts_valid_until_ms": ts_expires_ms,
-                "key_json": buffer(key_json_bytes),
+                "key_json": db_binary_type(key_json_bytes),
             },
             desc="store_server_keys_json",
         )
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 82bb61b811..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -12,15 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from synapse.storage.background_updates import BackgroundUpdateStore
 
-from ._base import SQLBaseStore
 
-
-class MediaRepositoryStore(SQLBaseStore):
+class MediaRepositoryStore(BackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def get_default_thumbnails(self, top_level_type, sub_type):
-        return []
+    def __init__(self, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+
+        self.register_background_index_update(
+            update_name='local_media_repository_url_idx',
+            index_name='local_media_repository_url_idx',
+            table='local_media_repository',
+            columns=['created_ts'],
+            where_clause='url_cache IS NOT NULL',
+        )
 
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
@@ -62,7 +69,7 @@ class MediaRepositoryStore(SQLBaseStore):
         def get_url_cache_txn(txn):
             # get the most recently cached result (relative to the given ts)
             sql = (
-                "SELECT response_code, etag, expires, og, media_id, download_ts"
+                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                 " FROM local_media_repository_url_cache"
                 " WHERE url = ? AND download_ts <= ?"
                 " ORDER BY download_ts DESC LIMIT 1"
@@ -74,7 +81,7 @@ class MediaRepositoryStore(SQLBaseStore):
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
                 sql = (
-                    "SELECT response_code, etag, expires, og, media_id, download_ts"
+                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                     " FROM local_media_repository_url_cache"
                     " WHERE url = ? AND download_ts > ?"
                     " ORDER BY download_ts ASC LIMIT 1"
@@ -86,14 +93,14 @@ class MediaRepositoryStore(SQLBaseStore):
                 return None
 
             return dict(zip((
-                'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
+                'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
             ), row))
 
         return self.runInteraction(
             "get_url_cache", get_url_cache_txn
         )
 
-    def store_url_cache(self, url, response_code, etag, expires, og, media_id,
+    def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
                         download_ts):
         return self._simple_insert(
             "local_media_repository_url_cache",
@@ -101,7 +108,7 @@ class MediaRepositoryStore(SQLBaseStore):
                 "url": url,
                 "response_code": response_code,
                 "etag": etag,
-                "expires": expires,
+                "expires_ts": expires_ts,
                 "og": og,
                 "media_id": media_id,
                 "download_ts": download_ts,
@@ -166,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore):
             desc="store_cached_remote_media",
         )
 
-    def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+        """Updates the last access time of the given media
+
+        Args:
+            local_media (iterable[str]): Set of media_ids
+            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            time_ms (int): Current time in milliseconds
+        """
         def update_cache_txn(txn):
             sql = (
                 "UPDATE remote_media_cache SET last_access_ts = ?"
@@ -174,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore):
             )
 
             txn.executemany(sql, (
-                (time_ts, media_origin, media_id)
-                for media_origin, media_id in origin_id_tuples
+                (time_ms, media_origin, media_id)
+                for media_origin, media_id in remote_media
+            ))
+
+            sql = (
+                "UPDATE local_media_repository SET last_access_ts = ?"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, (
+                (time_ms, media_id)
+                for media_id in local_media
             ))
 
         return self.runInteraction("update_cached_last_access_time", update_cache_txn)
@@ -238,3 +262,70 @@ class MediaRepositoryStore(SQLBaseStore):
                 },
             )
         return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+
+    def get_expired_url_cache(self, now_ts):
+        sql = (
+            "SELECT media_id FROM local_media_repository_url_cache"
+            " WHERE expires_ts < ?"
+            " ORDER BY expires_ts ASC"
+            " LIMIT 500"
+        )
+
+        def _get_expired_url_cache_txn(txn):
+            txn.execute(sql, (now_ts,))
+            return [row[0] for row in txn]
+
+        return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+
+    def delete_url_cache(self, media_ids):
+        if len(media_ids) == 0:
+            return
+
+        sql = (
+            "DELETE FROM local_media_repository_url_cache"
+            " WHERE media_id = ?"
+        )
+
+        def _delete_url_cache_txn(txn):
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+        return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+
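+    # For illustration only: a hypothetical expiry sweep chaining the two
+    # helpers above (now_ts being the current time in milliseconds):
+    #
+    #     expired = yield store.get_expired_url_cache(now_ts)
+    #     yield store.delete_url_cache(expired)
+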
+    def get_url_cache_media_before(self, before_ts):
+        sql = (
+            "SELECT media_id FROM local_media_repository"
+            " WHERE created_ts < ? AND url_cache IS NOT NULL"
+            " ORDER BY created_ts ASC"
+            " LIMIT 500"
+        )
+
+        def _get_url_cache_media_before_txn(txn):
+            txn.execute(sql, (before_ts,))
+            return [row[0] for row in txn]
+
+        return self.runInteraction(
+            "get_url_cache_media_before", _get_url_cache_media_before_txn,
+        )
+
+    def delete_url_cache_media(self, media_ids):
+        if len(media_ids) == 0:
+            return
+
+        def _delete_url_cache_media_txn(txn):
+            sql = (
+                "DELETE FROM local_media_repository"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+            sql = (
+                "DELETE FROM local_media_repository_thumbnails"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+        return self.runInteraction(
+            "delete_url_cache_media", _delete_url_cache_media_txn,
+        )
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 72b670b83b..b290f834b3 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,13 +20,12 @@ import logging
 import os
 import re
 
-
 logger = logging.getLogger(__name__)
 
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 43
+SCHEMA_VERSION = 50
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
 
     If `config` is None then prepare_database will assert that no upgrade is
     necessary, *or* will create a fresh database if the database is empty.
+
+    Args:
+        db_conn:
+        database_engine:
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            application config, or None if we are connecting to an existing
+            database which we expect to be configured already
     """
     try:
         cur = db_conn.cursor()
@@ -64,9 +71,13 @@ def prepare_database(db_conn, database_engine, config):
         else:
             _setup_new_database(cur, database_engine)
 
+        # check if any of our configured dynamic modules want a database
+        if config is not None:
+            _apply_module_schemas(cur, database_engine, config)
+
         cur.close()
         db_conn.commit()
-    except:
+    except Exception:
         db_conn.rollback()
         raise
 
@@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
             )
 
 
+def _apply_module_schemas(txn, database_engine, config):
+    """Apply the module schemas for the dynamic modules, if any
+
+    Args:
+        txn: database cursor
+        database_engine: synapse database engine class
+        config (synapse.config.homeserver.HomeServerConfig):
+            application config
+    """
+    for (mod, _config) in config.password_providers:
+        if not hasattr(mod, 'get_db_schema_files'):
+            continue
+        modname = ".".join((mod.__module__, mod.__name__))
+        _apply_module_schema_files(
+            txn, database_engine, modname, mod.get_db_schema_files(),
+        )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+    """Apply the module schemas for a single module
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        modname (str): fully qualified name of the module
+        names_and_streams (Iterable[(str, file)]): the names and streams of
+            schemas to be applied
+    """
+    cur.execute(
+        database_engine.convert_param_style(
+            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+        ),
+        (modname,)
+    )
+    applied_deltas = set(d for d, in cur)
+    for (name, stream) in names_and_streams:
+        if name in applied_deltas:
+            continue
+
+        root_name, ext = os.path.splitext(name)
+        if ext != '.sql':
+            raise PrepareDatabaseException(
+                "only .sql files are currently supported for module schemas",
+            )
+
+        logger.info("applying schema %s for %s", name, modname)
+        for statement in get_statements(stream):
+            cur.execute(statement)
+
+        # Mark as done.
+        cur.execute(
+            database_engine.convert_param_style(
+                "INSERT INTO applied_module_schemas (module_name, file)"
+                " VALUES (?,?)",
+            ),
+            (modname, name)
+        )
+
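+# For illustration only: a password-provider module opts in by exposing
+# something like the following (hypothetical):
+#
+#     @staticmethod
+#     def get_db_schema_files():
+#         return [("my_tables.sql", StringIO("CREATE TABLE mymod_data (v TEXT);"))]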
+
 def get_statements(f):
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 9e9d3c2591..a0c7a0dc87 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from collections import namedtuple
+
+from twisted.internet import defer
+
 from synapse.api.constants import PresenceState
+from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
-from collections import namedtuple
-from twisted.internet import defer
+from ._base import SQLBaseStore
 
 
 class UserPresenceState(namedtuple("UserPresenceState",
@@ -115,11 +118,7 @@ class PresenceStore(SQLBaseStore):
             " AND user_id IN (%s)"
         )
 
-        batches = (
-            presence_states[i:i + 50]
-            for i in xrange(0, len(presence_states), 50)
-        )
-        for states in batches:
+        for states in batch_iter(presence_states, 50):
             args = [stream_id]
             args.extend(s.user_id for s in states)
             txn.execute(
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..60295da254 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -13,15 +13,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.storage.roommember import ProfileInfo
+
 from ._base import SQLBaseStore
 
 
-class ProfileStore(SQLBaseStore):
-    def create_profile(self, user_localpart):
-        return self._simple_insert(
-            table="profiles",
-            values={"user_id": user_localpart},
-            desc="create_profile",
+class ProfileWorkerStore(SQLBaseStore):
+    @defer.inlineCallbacks
+    def get_profileinfo(self, user_localpart):
+        try:
+            profile = yield self._simple_select_one(
+                table="profiles",
+                keyvalues={"user_id": user_localpart},
+                retcols=("displayname", "avatar_url"),
+                desc="get_profileinfo",
+            )
+        except StoreError as e:
+            if e.code == 404:
+                # no match
+                defer.returnValue(ProfileInfo(None, None))
+                return
+            else:
+                raise
+
+        defer.returnValue(
+            ProfileInfo(
+                avatar_url=profile['avatar_url'],
+                display_name=profile['displayname'],
+            )
         )
 
     def get_profile_displayname(self, user_localpart):
@@ -32,14 +54,6 @@ class ProfileStore(SQLBaseStore):
             desc="get_profile_displayname",
         )
 
-    def set_profile_displayname(self, user_localpart, new_displayname):
-        return self._simple_update_one(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            updatevalues={"displayname": new_displayname},
-            desc="set_profile_displayname",
-        )
-
     def get_profile_avatar_url(self, user_localpart):
         return self._simple_select_one_onecol(
             table="profiles",
@@ -48,6 +62,32 @@ class ProfileStore(SQLBaseStore):
             desc="get_profile_avatar_url",
         )
 
+    def get_from_remote_profile_cache(self, user_id):
+        return self._simple_select_one(
+            table="remote_profile_cache",
+            keyvalues={"user_id": user_id},
+            retcols=("displayname", "avatar_url",),
+            allow_none=True,
+            desc="get_from_remote_profile_cache",
+        )
+
+
+class ProfileStore(ProfileWorkerStore):
+    def create_profile(self, user_localpart):
+        return self._simple_insert(
+            table="profiles",
+            values={"user_id": user_localpart},
+            desc="create_profile",
+        )
+
+    def set_profile_displayname(self, user_localpart, new_displayname):
+        return self._simple_update_one(
+            table="profiles",
+            keyvalues={"user_id": user_localpart},
+            updatevalues={"displayname": new_displayname},
+            desc="set_profile_displayname",
+        )
+
     def set_profile_avatar_url(self, user_localpart, new_avatar_url):
         return self._simple_update_one(
             table="profiles",
@@ -55,3 +95,90 @@ class ProfileStore(SQLBaseStore):
             updatevalues={"avatar_url": new_avatar_url},
             desc="set_profile_avatar_url",
         )
+
+    def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+        """Ensure we are caching the remote user's profiles.
+
+        This should only be called when `is_subscribed_remote_profile_for_user`
+        would return true for the user.
+        """
+        return self._simple_upsert(
+            table="remote_profile_cache",
+            keyvalues={"user_id": user_id},
+            values={
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+                "last_check": self._clock.time_msec(),
+            },
+            desc="add_remote_profile_cache",
+        )
+
+    def update_remote_profile_cache(self, user_id, displayname, avatar_url):
+        return self._simple_update(
+            table="remote_profile_cache",
+            keyvalues={"user_id": user_id},
+            values={
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+                "last_check": self._clock.time_msec(),
+            },
+            desc="update_remote_profile_cache",
+        )
+
+    @defer.inlineCallbacks
+    def maybe_delete_remote_profile_cache(self, user_id):
+        """Check if we still care about the remote user's profile, and if we
+        don't then remove their profile from the cache
+        """
+        subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
+        if not subscribed:
+            yield self._simple_delete(
+                table="remote_profile_cache",
+                keyvalues={"user_id": user_id},
+                desc="delete_remote_profile_cache",
+            )
+
+    def get_remote_profile_cache_entries_that_expire(self, last_checked):
+        """Get all users who haven't been checked since `last_checked`
+        """
+        def _get_remote_profile_cache_entries_that_expire_txn(txn):
+            sql = """
+                SELECT user_id, displayname, avatar_url
+                FROM remote_profile_cache
+                WHERE last_check < ?
+            """
+
+            txn.execute(sql, (last_checked,))
+
+            return self.cursor_to_dict(txn)
+
+        return self.runInteraction(
+            "get_remote_profile_cache_entries_that_expire",
+            _get_remote_profile_cache_entries_that_expire_txn,
+        )
+
+    @defer.inlineCallbacks
+    def is_subscribed_remote_profile_for_user(self, user_id):
+        """Check whether we are interested in a remote user's profile.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="group_users",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            defer.returnValue(True)
+
+        res = yield self._simple_select_one_onecol(
+            table="group_invites",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            defer.returnValue(True)
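
The ProfileWorkerStore/ProfileStore split follows the pattern used throughout
this changeset: read-only accessors move to a *WorkerStore base class that
worker processes can load, while write paths stay on the master-only
subclass. A toy illustration of the shape (all names made up):

    class ExampleWorkerStore(object):
        def __init__(self, db):
            self.db = db

        def get_thing(self, key):
            # Read-only, so safe to run on any worker process.
            return self.db.get(key)

    class ExampleStore(ExampleWorkerStore):
        def set_thing(self, key, value):
            # Writes stay on the master so there is a single writer.
            self.db[key] = value

    store = ExampleStore({})
    store.set_thing("a", 1)
    assert store.get_thing("a") == 1
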
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,14 +14,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes
+import abc
+import logging
+
+from canonicaljson import json
+
 from twisted.internet import defer
 
-import logging
-import simplejson as json
+from synapse.push.baserules import list_with_base_rules
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +57,43 @@ def _load_rules(rawrules, enabled_map):
     return rules
 
 
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+                           ReceiptsWorkerStore,
+                           PusherWorkerStore,
+                           RoomMemberWorkerStore,
+                           SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
@@ -89,6 +134,22 @@ class PushRuleStore(SQLBaseStore):
             r['rule_id']: False if r['enabled'] == 0 else True for r in results
         })
 
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
     @cachedList(cached_method_name="get_push_rules_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules(self, user_ids):
@@ -124,6 +185,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    @defer.inlineCallbacks
     def bulk_get_push_rules_for_room(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -133,9 +195,11 @@ class PushRuleStore(SQLBaseStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._bulk_get_push_rules_for_room(
-            event.room_id, state_group, context.current_state_ids, event=event
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._bulk_get_push_rules_for_room(
+            event.room_id, state_group, current_state_ids, event=event
         )
+        defer.returnValue(result)
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
@@ -185,18 +249,6 @@ class PushRuleStore(SQLBaseStore):
             if uid in local_users_in_room:
                 user_ids.add(uid)
 
-        forgotten = yield self.who_forgot_in_room(
-            event.room_id, on_invalidate=cache_context.invalidate,
-        )
-
-        for row in forgotten:
-            user_id = row["user_id"]
-            event_id = row["event_id"]
-
-            mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
-            if event_id == mem_id:
-                user_ids.discard(user_id)
-
         rules_by_user = yield self.bulk_get_push_rules(
             user_ids, on_invalidate=cache_context.invalidate,
         )
@@ -228,6 +280,8 @@ class PushRuleStore(SQLBaseStore):
             results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
         defer.returnValue(results)
 
+
+class PushRuleStore(PushRulesWorkerStore):
     @defer.inlineCallbacks
     def add_push_rule(
         self, user_id, rule_id, priority_class, conditions, actions,
@@ -526,21 +580,8 @@ class PushRuleStore(SQLBaseStore):
         room stream ordering it corresponds to."""
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
 
 
 class RuleNotFoundException(Exception):
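
The reason get_max_push_rules_stream_id is declared abstract is that
PushRulesWorkerStore.__init__ calls it, so the class must not be
instantiable until a subclass provides it. A minimal sketch of the same
pattern, shown with Python 3 metaclass syntax (the diff uses the Python 2
__metaclass__ attribute) and made-up names:

    import abc

    class StreamWorkerBase(abc.ABC):
        def __init__(self):
            # Safe: instantiation fails unless a concrete subclass
            # provides get_max_stream_id, so this call always resolves.
            self.stream_position = self.get_max_stream_id()

        @abc.abstractmethod
        def get_max_stream_id(self):
            raise NotImplementedError()

    class ConcreteStore(StreamWorkerBase):
        def get_max_stream_id(self):
            return 42

    assert ConcreteStore().stream_position == 42
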
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 34d2f82b7f..8443bd4c1b 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,21 +14,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from twisted.internet import defer
+import logging
+import types
+
+from canonicaljson import encode_canonical_json, json
 
-from canonicaljson import encode_canonical_json
+from twisted.internet import defer
 
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 
-import logging
-import simplejson as json
-import types
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
 
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
     def _decode_pushers_rows(self, rows):
         for r in rows:
             dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
         rows = yield self.runInteraction("get_all_pushers", get_pushers)
         defer.returnValue(rows)
 
-    def get_pushers_stream_token(self):
-        return self._pushers_id_gen.get_current_token()
-
     def get_all_updated_pushers(self, last_id, current_id, limit):
         if last_id == current_id:
             return defer.succeed(([], []))
@@ -198,56 +196,74 @@ class PusherStore(SQLBaseStore):
 
         defer.returnValue(result)
 
+
+class PusherStore(PusherWorkerStore):
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     @defer.inlineCallbacks
     def add_pusher(self, user_id, access_token, kind, app_id,
                    app_display_name, device_display_name,
                    pushkey, pushkey_ts, lang, data, last_stream_ordering,
                    profile_tag=""):
         with self._pushers_id_gen.get_next() as stream_id:
-            def f(txn):
-                newly_inserted = self._simple_upsert_txn(
-                    txn,
-                    "pushers",
-                    {
-                        "app_id": app_id,
-                        "pushkey": pushkey,
-                        "user_name": user_id,
-                    },
-                    {
-                        "access_token": access_token,
-                        "kind": kind,
-                        "app_display_name": app_display_name,
-                        "device_display_name": device_display_name,
-                        "ts": pushkey_ts,
-                        "lang": lang,
-                        "data": encode_canonical_json(data),
-                        "last_stream_ordering": last_stream_ordering,
-                        "profile_tag": profile_tag,
-                        "id": stream_id,
-                    },
-                )
-                if newly_inserted:
-                    # get_if_user_has_pusher only cares if the user has
-                    # at least *one* pusher.
-                    txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+            # no need to lock because `pushers` has a unique key on
+            # (app_id, pushkey, user_name) so _simple_upsert will retry
+            newly_inserted = yield self._simple_upsert(
+                table="pushers",
+                keyvalues={
+                    "app_id": app_id,
+                    "pushkey": pushkey,
+                    "user_name": user_id,
+                },
+                values={
+                    "access_token": access_token,
+                    "kind": kind,
+                    "app_display_name": app_display_name,
+                    "device_display_name": device_display_name,
+                    "ts": pushkey_ts,
+                    "lang": lang,
+                    "data": encode_canonical_json(data),
+                    "last_stream_ordering": last_stream_ordering,
+                    "profile_tag": profile_tag,
+                    "id": stream_id,
+                },
+                desc="add_pusher",
+                lock=False,
+            )
 
-            yield self.runInteraction("add_pusher", f)
+            if newly_inserted:
+                yield self.runInteraction(
+                    "add_pusher",
+                    self._invalidate_cache_and_stream,
+                    self.get_if_user_has_pusher, (user_id,)
+                )
 
     @defer.inlineCallbacks
     def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
         def delete_pusher_txn(txn, stream_id):
-            txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+            self._invalidate_cache_and_stream(
+                txn, self.get_if_user_has_pusher, (user_id,)
+            )
 
             self._simple_delete_one_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
             )
-            self._simple_upsert_txn(
+
+            # it's possible for us to end up with duplicate rows for
+            # (app_id, pushkey, user_id) at different stream_ids, but that
+            # doesn't really matter.
+            self._simple_insert_txn(
                 txn,
-                "deleted_pushers",
-                {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
-                {"stream_id": stream_id},
+                table="deleted_pushers",
+                values={
+                    "stream_id": stream_id,
+                    "app_id": app_id,
+                    "pushkey": pushkey,
+                    "user_id": user_id,
+                },
             )
 
         with self._pushers_id_gen.get_next() as stream_id:
@@ -310,9 +326,12 @@ class PusherStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def set_throttle_params(self, pusher_id, room_id, params):
+        # no need to lock because `pusher_throttle` has a primary key on
+        # (pusher, room_id) so _simple_upsert will retry
         yield self._simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
             params,
-            desc="set_throttle_params"
+            desc="set_throttle_params",
+            lock=False,
         )
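
The lock=False upserts above rely on a unique constraint for safety: if two
writers race, the loser's INSERT hits the constraint and the operation is
retried as an UPDATE keyed on the same columns. A self-contained sqlite3
illustration of that retry shape (a simplification, not synapse's actual
_simple_upsert):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE pushers (app_id TEXT, pushkey TEXT, user_name TEXT,"
        " kind TEXT, UNIQUE (app_id, pushkey, user_name))"
    )

    def simple_upsert(conn, app_id, pushkey, user_name, kind):
        try:
            with conn:
                conn.execute(
                    "INSERT INTO pushers VALUES (?, ?, ?, ?)",
                    (app_id, pushkey, user_name, kind),
                )
            return True   # newly inserted
        except sqlite3.IntegrityError:
            with conn:
                conn.execute(
                    "UPDATE pushers SET kind = ? WHERE app_id = ?"
                    " AND pushkey = ? AND user_name = ?",
                    (kind, app_id, pushkey, user_name),
                )
            return False  # row already existed; updated it instead

    assert simple_upsert(conn, "app", "key", "@u:hs", "http") is True
    assert simple_upsert(conn, "app", "key", "@u:hs", "email") is False
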
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f42b8014c7..0ac665e967 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,52 +14,52 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+import abc
+import logging
+
+from canonicaljson import json
 
 from twisted.internet import defer
 
-import logging
-import ujson as json
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
 
 logger = logging.getLogger(__name__)
 
 
-class ReceiptsStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(ReceiptsStore, self).__init__(hs)
+class ReceiptsWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_receipt_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
+    @abc.abstractmethod
+    def get_max_receipt_stream_id(self):
+        """Get the current max stream ID for receipts stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks()
     def get_users_with_read_receipts_in_room(self, room_id):
         receipts = yield self.get_receipts_for_room(room_id, "m.read")
         defer.returnValue(set(r['user_id'] for r in receipts))
 
-    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
-                                                    user_id):
-        if receipt_type != "m.read":
-            return
-
-        # Returns an ObservableDeferred
-        res = self.get_users_with_read_receipts_in_room.cache.get(
-            room_id, None, update_metrics=False,
-        )
-
-        if res:
-            if isinstance(res, defer.Deferred) and res.called:
-                res = res.result
-            if user_id in res:
-                # We'd only be adding to the set, so no point invalidating if the
-                # user is already there
-                return
-
-        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
         return self._simple_select_list(
@@ -139,7 +140,9 @@ class ReceiptsStore(SQLBaseStore):
         """
         room_ids = set(room_ids)
 
-        if from_key:
+        if from_key is not None:
+            # Only ask the database about rooms where there have been new
+            # receipts added since `from_key`
             room_ids = yield self._receipts_stream_cache.get_entities_changed(
                 room_ids, from_key
             )
@@ -150,7 +153,6 @@ class ReceiptsStore(SQLBaseStore):
 
         defer.returnValue([ev for res in results.values() for ev in res])
 
-    @cachedInlineCallbacks(num_args=3, tree=True)
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.
 
@@ -161,7 +163,19 @@ class ReceiptsStore(SQLBaseStore):
                 from the start.
 
         Returns:
-            list: A list of receipts.
+            Deferred[list]: A list of receipts.
+        """
+        if from_key is not None:
+            # Check the cache first to see if any new receipts have been added
+            # since `from_key`. If not, we can no-op.
+            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+                return defer.succeed([])
+
+        return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+
+    @cachedInlineCallbacks(num_args=3, tree=True)
+    def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+        """See get_linearized_receipts_for_room
         """
         def f(txn):
             if from_key:
@@ -210,7 +224,7 @@ class ReceiptsStore(SQLBaseStore):
             "content": content,
         }])
 
-    @cachedList(cached_method_name="get_linearized_receipts_for_room",
+    @cachedList(cached_method_name="_get_linearized_receipts_for_room",
                 list_name="room_ids", num_args=3, inlineCallbacks=True)
     def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
@@ -270,11 +284,97 @@ class ReceiptsStore(SQLBaseStore):
         }
         defer.returnValue(results)
 
+    def get_all_updated_receipts(self, last_id, current_id, limit=None):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_receipts_txn(txn):
+            sql = (
+                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+                " FROM receipts_linearized"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+            )
+            args = [last_id, current_id]
+            if limit is not None:
+                sql += " LIMIT ?"
+                args.append(limit)
+            txn.execute(sql, args)
+
+            return txn.fetchall()
+        return self.runInteraction(
+            "get_all_updated_receipts", get_all_updated_receipts_txn
+        )
+
+    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+                                                    user_id):
+        if receipt_type != "m.read":
+            return
+
+        # Returns either an ObservableDeferred or the raw result
+        res = self.get_users_with_read_receipts_in_room.cache.get(
+            room_id, None, update_metrics=False,
+        )
+
+        # first handle the Deferred case
+        if isinstance(res, defer.Deferred):
+            if res.called:
+                res = res.result
+            else:
+                res = None
+
+        if res and user_id in res:
+            # We'd only be adding to the set, so no point invalidating if the
+            # user is already there
+            return
+
+        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    def __init__(self, db_conn, hs):
+        # We instantiate this first as the ReceiptsWorkerStore constructor
+        # needs to be able to call get_max_receipt_stream_id
+        self._receipts_id_gen = StreamIdGenerator(
+            db_conn, "receipts_linearized", "stream_id"
+        )
+
+        super(ReceiptsStore, self).__init__(db_conn, hs)
+
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
 
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
+        res = self._simple_select_one_txn(
+            txn,
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True
+        )
+
+        stream_ordering = int(res["stream_ordering"]) if res else None
+
+        # We don't want to clobber receipts for more recent events, so we
+        # have to compare orderings of existing receipts
+        if stream_ordering is not None:
+            sql = (
+                "SELECT stream_ordering, event_id FROM events"
+                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+            )
+            txn.execute(sql, (room_id, receipt_type, user_id))
+
+            for so, eid in txn:
+                if int(so) >= stream_ordering:
+                    logger.debug(
+                        "Ignoring new receipt for %s in favour of existing "
+                        "one for later event %s",
+                        event_id, eid,
+                    )
+                    return False
+
         txn.call_after(
             self.get_receipts_for_room.invalidate, (room_id, receipt_type)
         )
@@ -286,7 +386,7 @@ class ReceiptsStore(SQLBaseStore):
             self.get_receipts_for_user.invalidate, (user_id, receipt_type)
         )
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
 
         txn.call_after(
             self._receipts_stream_cache.entity_has_changed,
@@ -298,34 +398,6 @@ class ReceiptsStore(SQLBaseStore):
             (user_id, room_id, receipt_type)
         )
 
-        res = self._simple_select_one_txn(
-            txn,
-            table="events",
-            retcols=["topological_ordering", "stream_ordering"],
-            keyvalues={"event_id": event_id},
-            allow_none=True
-        )
-
-        topological_ordering = int(res["topological_ordering"]) if res else None
-        stream_ordering = int(res["stream_ordering"]) if res else None
-
-        # We don't want to clobber receipts for more recent events, so we
-        # have to compare orderings of existing receipts
-        sql = (
-            "SELECT topological_ordering, stream_ordering, event_id FROM events"
-            " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
-            " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
-        )
-
-        txn.execute(sql, (room_id, receipt_type, user_id))
-
-        if topological_ordering:
-            for to, so, _ in txn:
-                if int(to) > topological_ordering:
-                    return False
-                elif int(to) == topological_ordering and int(so) >= stream_ordering:
-                    return False
-
         self._simple_delete_txn(
             txn,
             table="receipts_linearized",
@@ -349,12 +421,11 @@ class ReceiptsStore(SQLBaseStore):
             }
         )
 
-        if receipt_type == "m.read" and topological_ordering:
+        if receipt_type == "m.read" and stream_ordering is not None:
             self._remove_old_push_actions_before_txn(
                 txn,
                 room_id=room_id,
                 user_id=user_id,
-                topological_ordering=topological_ordering,
                 stream_ordering=stream_ordering,
             )
 
@@ -435,7 +506,7 @@ class ReceiptsStore(SQLBaseStore):
             self.get_receipts_for_user.invalidate, (user_id, receipt_type)
         )
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
 
         self._simple_delete_txn(
             txn,
@@ -457,25 +528,3 @@ class ReceiptsStore(SQLBaseStore):
                 "data": json.dumps(data),
             }
         )
-
-    def get_all_updated_receipts(self, last_id, current_id, limit=None):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_receipts_txn(txn):
-            sql = (
-                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
-                " FROM receipts_linearized"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-            )
-            args = [last_id, current_id]
-            if limit is not None:
-                sql += " LIMIT ?"
-                args.append(limit)
-            txn.execute(sql, args)
-
-            return txn.fetchall()
-        return self.runInteraction(
-            "get_all_updated_receipts", get_all_updated_receipts_txn
-        )
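
The early return added to get_linearized_receipts_for_room leans on the
stream-change cache: if the room has seen no new receipts since from_key,
the database query is skipped entirely. A toy model of that check (the real
StreamChangeCache is more careful, e.g. it answers conservatively for
positions older than anything it remembers):

    class ToyStreamChangeCache(object):
        def __init__(self):
            self._last_change = {}  # entity -> stream id of last change

        def entity_has_changed(self, entity, stream_id):
            self._last_change[entity] = stream_id

        def has_entity_changed(self, entity, stream_id):
            # Simplified: assumes every change has been recorded here.
            return self._last_change.get(entity, 0) > stream_id

    cache = ToyStreamChangeCache()
    cache.entity_has_changed("!room:hs", 10)
    assert cache.has_entity_changed("!room:hs", 5)       # must query the db
    assert not cache.has_entity_changed("!room:hs", 10)  # safe to no-op
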
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 20acd58fcf..07333f777d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -15,17 +15,83 @@
 
 import re
 
+from six.moves import range
+
 from twisted.internet import defer
 
-from synapse.api.errors import StoreError, Codes
+from synapse.api.errors import Codes, StoreError
 from synapse.storage import background_updates
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 
-class RegistrationStore(background_updates.BackgroundUpdateStore):
+class RegistrationWorkerStore(SQLBaseStore):
+    @cached()
+    def get_user_by_id(self, user_id):
+        return self._simple_select_one(
+            table="users",
+            keyvalues={
+                "name": user_id,
+            },
+            retcols=[
+                "name", "password_hash", "is_guest",
+                "consent_version", "consent_server_notice_sent",
+                "appservice_id",
+            ],
+            allow_none=True,
+            desc="get_user_by_id",
+        )
+
+    @cached()
+    def get_user_by_access_token(self, token):
+        """Get a user from the given access token.
+
+        Args:
+            token (str): The access token of a user.
+        Returns:
+            defer.Deferred: None, if the token did not match, otherwise dict
+                including the keys `name`, `is_guest`, `device_id`, `token_id`.
+        """
+        return self.runInteraction(
+            "get_user_by_access_token",
+            self._query_for_auth,
+            token
+        )
+
+    @defer.inlineCallbacks
+    def is_server_admin(self, user):
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user.to_string()},
+            retcol="admin",
+            allow_none=True,
+            desc="is_server_admin",
+        )
+
+        defer.returnValue(res if res else False)
+
+    def _query_for_auth(self, txn, token):
+        sql = (
+            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+            " access_tokens.device_id"
+            " FROM users"
+            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
+            " WHERE token = ?"
+        )
+
+        txn.execute(sql, (token,))
+        rows = self.cursor_to_dict(txn)
+        if rows:
+            return rows[0]
+
+        return None
+
+
+class RegistrationStore(RegistrationWorkerStore,
+                        background_updates.BackgroundUpdateStore):
 
-    def __init__(self, hs):
-        super(RegistrationStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(RegistrationStore, self).__init__(db_conn, hs)
 
         self.clock = hs.get_clock()
 
@@ -37,12 +103,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         )
 
         self.register_background_index_update(
-            "refresh_tokens_device_index",
-            index_name="refresh_tokens_device_id",
-            table="refresh_tokens",
-            columns=["user_id", "device_id"],
+            "users_creation_ts",
+            index_name="users_creation_ts",
+            table="users",
+            columns=["creation_ts"],
         )
 
+        # we no longer use refresh tokens, but it's possible that some people
+        # might have a background update queued to build this index. Just
+        # clear the background update.
+        self.register_noop_background_update("refresh_tokens_device_index")
+
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id=None):
         """Adds an access token for the given user.
@@ -177,9 +248,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             )
 
         if create_profile_with_localpart:
+            # set a default displayname serverside to avoid ugly race
+            # between auto-joins and clients trying to set displaynames
             txn.execute(
-                "INSERT INTO profiles(user_id) VALUES (?)",
-                (create_profile_with_localpart,)
+                "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
+                (create_profile_with_localpart, create_profile_with_localpart)
             )
 
         self._invalidate_cache_and_stream(
@@ -187,18 +260,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         )
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    @cached()
-    def get_user_by_id(self, user_id):
-        return self._simple_select_one(
-            table="users",
-            keyvalues={
-                "name": user_id,
-            },
-            retcols=["name", "password_hash", "is_guest"],
-            allow_none=True,
-            desc="get_user_by_id",
-        )
-
     def get_users_by_id_case_insensitive(self, user_id):
         """Gets users that match user_id case insensitively.
         Returns a mapping of user_id -> password_hash.
@@ -236,12 +297,57 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             "user_set_password_hash", user_set_password_hash_txn
         )
 
-    @defer.inlineCallbacks
+    def user_set_consent_version(self, user_id, consent_version):
+        """Updates the user table to record privacy policy consent
+
+        Args:
+            user_id (str): full mxid of the user to update
+            consent_version (str): version of the policy the user has consented
+                to
+
+        Raises:
+            StoreError(404) if user not found
+        """
+        def f(txn):
+            self._simple_update_one_txn(
+                txn,
+                table='users',
+                keyvalues={'name': user_id, },
+                updatevalues={'consent_version': consent_version, },
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_user_by_id, (user_id,)
+            )
+        return self.runInteraction("user_set_consent_version", f)
+
+    def user_set_consent_server_notice_sent(self, user_id, consent_version):
+        """Updates the user table to record that we have sent the user a server
+        notice about privacy policy consent
+
+        Args:
+            user_id (str): full mxid of the user to update
+            consent_version (str): version of the policy we have notified the
+                user about
+
+        Raises:
+            StoreError(404) if user not found
+        """
+        def f(txn):
+            self._simple_update_one_txn(
+                txn,
+                table='users',
+                keyvalues={'name': user_id, },
+                updatevalues={'consent_server_notice_sent': consent_version, },
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_user_by_id, (user_id,)
+            )
+        return self.runInteraction("user_set_consent_server_notice_sent", f)
+
     def user_delete_access_tokens(self, user_id, except_token_id=None,
-                                  device_id=None,
-                                  delete_refresh_tokens=False):
+                                  device_id=None):
         """
-        Invalidate access/refresh tokens belonging to a user
+        Invalidate access tokens belonging to a user
 
         Args:
             user_id (str):  ID of user the tokens belong to
@@ -250,10 +356,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             device_id (str|None):  ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
                 be deleted
-            delete_refresh_tokens (bool):  True to delete refresh tokens as
-                well as access tokens.
         Returns:
-            defer.Deferred:
+            defer.Deferred[list[(str, int, str|None)]]: a list of
+                (token, token id, device id) for each of the deleted tokens
         """
         def f(txn):
             keyvalues = {
@@ -262,13 +367,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             if device_id is not None:
                 keyvalues["device_id"] = device_id
 
-            if delete_refresh_tokens:
-                self._simple_delete_txn(
-                    txn,
-                    table="refresh_tokens",
-                    keyvalues=keyvalues,
-                )
-
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
             values = [v for _, v in items]
@@ -277,14 +375,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 values.append(except_token_id)
 
             txn.execute(
-                "SELECT token FROM access_tokens WHERE %s" % where_clause,
+                "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
                 values
             )
-            rows = self.cursor_to_dict(txn)
+            tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 
-            for row in rows:
+            for token, _, _ in tokens_and_devices:
                 self._invalidate_cache_and_stream(
-                    txn, self.get_user_by_access_token, (row["token"],)
+                    txn, self.get_user_by_access_token, (token,)
                 )
 
             txn.execute(
@@ -292,7 +390,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 values
             )
 
-        yield self.runInteraction(
+            return tokens_and_devices
+
+        return self.runInteraction(
             "user_delete_access_tokens", f,
         )
 
@@ -312,34 +412,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
         return self.runInteraction("delete_access_token", f)
 
-    @cached()
-    def get_user_by_access_token(self, token):
-        """Get a user from the given access token.
-
-        Args:
-            token (str): The access token of a user.
-        Returns:
-            defer.Deferred: None, if the token did not match, otherwise dict
-                including the keys `name`, `is_guest`, `device_id`, `token_id`.
-        """
-        return self.runInteraction(
-            "get_user_by_access_token",
-            self._query_for_auth,
-            token
-        )
-
-    @defer.inlineCallbacks
-    def is_server_admin(self, user):
-        res = yield self._simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user.to_string()},
-            retcol="admin",
-            allow_none=True,
-            desc="is_server_admin",
-        )
-
-        defer.returnValue(res if res else False)
-
     @cachedInlineCallbacks()
     def is_guest(self, user_id):
         res = yield self._simple_select_one_onecol(
@@ -352,22 +424,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
         defer.returnValue(res if res else False)
 
-    def _query_for_auth(self, txn, token):
-        sql = (
-            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
-            " access_tokens.device_id"
-            " FROM users"
-            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
-            " WHERE token = ?"
-        )
-
-        txn.execute(sql, (token,))
-        rows = self.cursor_to_dict(txn)
-        if rows:
-            return rows[0]
-
-        return None
-
     @defer.inlineCallbacks
     def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
         yield self._simple_upsert("user_threepids", {
@@ -404,15 +460,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             defer.returnValue(ret['user_id'])
         defer.returnValue(None)
 
-    def user_delete_threepids(self, user_id):
-        return self._simple_delete(
-            "user_threepids",
-            keyvalues={
-                "user_id": user_id,
-            },
-            desc="user_delete_threepids",
-        )
-
     def user_delete_threepid(self, user_id, medium, address):
         return self._simple_delete(
             "user_threepids",
@@ -437,6 +484,35 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         ret = yield self.runInteraction("count_users", _count_users)
         defer.returnValue(ret)
 
+    def count_daily_user_type(self):
+        """
+        Counts 1) native non-guest users
+               2) native guest users
+               3) bridged users
+        who registered on the homeserver in the past 24 hours
+        """
+        def _count_daily_user_type(txn):
+            yesterday = int(self._clock.time()) - (60 * 60 * 24)
+
+            sql = """
+                SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+                    SELECT
+                    CASE
+                        WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
+                        WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
+                        WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
+                    END AS user_type
+                    FROM users
+                    WHERE creation_ts > ?
+                ) AS t GROUP BY user_type
+            """
+            results = {'native': 0, 'guest': 0, 'bridged': 0}
+            txn.execute(sql, (yesterday,))
+            for row in txn:
+                results[row[0]] = row[1]
+            return results
+        return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+
     @defer.inlineCallbacks
     def count_nonbridged_users(self):
         def _count_users(txn):
@@ -464,18 +540,16 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         """
         def _find_next_generated_user_id(txn):
             txn.execute("SELECT name FROM users")
-            rows = self.cursor_to_dict(txn)
 
             regex = re.compile("^@(\d+):")
 
             found = set()
 
-            for r in rows:
-                user_id = r["name"]
+            for user_id, in txn:
                 match = regex.search(user_id)
                 if match:
                     found.add(int(match.group(1)))
-            for i in xrange(len(found) + 1):
+            for i in range(len(found) + 1):
                 if i not in found:
                     return i
 
@@ -530,3 +604,44 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         except self.database_engine.module.IntegrityError:
             ret = yield self.get_3pid_guest_access_token(medium, address)
             defer.returnValue(ret)
+
+    def add_user_pending_deactivation(self, user_id):
+        """
+        Adds a user to the table of users who need to be parted from all the rooms they're
+        in
+        """
+        return self._simple_insert(
+            "users_pending_deactivation",
+            values={
+                "user_id": user_id,
+            },
+            desc="add_user_pending_deactivation",
+        )
+
+    def del_user_pending_deactivation(self, user_id):
+        """
+        Removes the given user from the table of users who need to be parted from all the
+        rooms they're in, effectively marking that user as fully deactivated.
+        """
+        # XXX: This should be simple_delete_one but we failed to put a unique index on
+        # the table, so somehow duplicate entries have ended up in it.
+        return self._simple_delete(
+            "users_pending_deactivation",
+            keyvalues={
+                "user_id": user_id,
+            },
+            desc="del_user_pending_deactivation",
+        )
+
+    def get_user_pending_deactivation(self):
+        """
+        Gets one user from the table of users waiting to be parted from all the rooms
+        they're in.
+        """
+        return self._simple_select_one_onecol(
+            "users_pending_deactivation",
+            keyvalues={},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_users_pending_deactivation",
+        )
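
The new count_daily_user_type query classifies each recent signup with a
CASE expression over is_guest and appservice_id. A runnable miniature
against sqlite3 (the rows are invented; the CASE expression mirrors the one
above):

    import sqlite3
    import time

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE users (name TEXT, is_guest INTEGER,"
        " appservice_id TEXT, creation_ts INTEGER)"
    )
    now = int(time.time())
    conn.executemany(
        "INSERT INTO users VALUES (?, ?, ?, ?)",
        [
            ("@alice:hs", 0, None, now),            # native
            ("@guest:hs", 1, None, now),            # guest
            ("@bridge:hs", 0, "irc_bridge", now),   # bridged
            ("@old:hs", 0, None, now - 2 * 86400),  # outside the window
        ],
    )

    yesterday = now - 86400
    sql = """
        SELECT user_type, COALESCE(count(*), 0) AS count FROM (
            SELECT CASE
                WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
                WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
                WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
            END AS user_type
            FROM users WHERE creation_ts > ?
        ) AS t GROUP BY user_type
    """
    results = {"native": 0, "guest": 0, "bridged": 0}
    for user_type, count in conn.execute(sql, (yesterday,)):
        results[user_type] = count
    assert results == {"native": 1, "guest": 1, "bridged": 1}
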
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index 40acb5c4ed..880f047adb 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-
 import logging
 
+from ._base import SQLBaseStore
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 23688430b7..3147fb6827 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
+import logging
+import re
+
+from canonicaljson import json
+
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.search import SearchStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
-from ._base import SQLBaseStore
-from .engines import PostgresEngine, Sqlite3Engine
-
-import collections
-import logging
-import ujson as json
-import re
-
 logger = logging.getLogger(__name__)
 
 
@@ -40,7 +40,138 @@ RatelimitOverride = collections.namedtuple(
 )
 
 
-class RoomStore(SQLBaseStore):
+class RoomWorkerStore(SQLBaseStore):
+    def get_public_room_ids(self):
+        return self._simple_select_onecol(
+            table="rooms",
+            keyvalues={
+                "is_public": True,
+            },
+            retcol="room_id",
+            desc="get_public_room_ids",
+        )
+
+    @cached(num_args=2, max_entries=100)
+    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+        """Get pulbic rooms for a particular list, or across all lists.
+
+        Args:
+            stream_id (int)
+            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
+                means the main list, None means all lsits.
+        """
+        return self.runInteraction(
+            "get_public_room_ids_at_stream_id",
+            self.get_public_room_ids_at_stream_id_txn,
+            stream_id, network_tuple=network_tuple
+        )
+
+    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+                                             network_tuple):
+        return {
+            rm
+            for rm, vis in self.get_published_at_stream_id_txn(
+                txn, stream_id, network_tuple=network_tuple
+            ).items()
+            if vis
+        }
+
+    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+        if network_tuple:
+            # We want to get from a particular list. No aggregation required.
+
+            sql = ("""
+                SELECT room_id, visibility FROM public_room_list_stream
+                INNER JOIN (
+                    SELECT room_id, max(stream_id) AS stream_id
+                    FROM public_room_list_stream
+                    WHERE stream_id <= ? %s
+                    GROUP BY room_id
+                ) grouped USING (room_id, stream_id)
+            """)
+
+            if network_tuple.appservice_id is not None:
+                txn.execute(
+                    sql % ("AND appservice_id = ? AND network_id = ?",),
+                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+                )
+            else:
+                txn.execute(
+                    sql % ("AND appservice_id IS NULL",),
+                    (stream_id,)
+                )
+            return dict(txn)
+        else:
+            # We want to get from all lists, so we need to aggregate the results
+
+            logger.info("Executing full list")
+
+            sql = ("""
+                SELECT room_id, visibility
+                FROM public_room_list_stream
+                INNER JOIN (
+                    SELECT
+                        room_id, max(stream_id) AS stream_id, appservice_id,
+                        network_id
+                    FROM public_room_list_stream
+                    WHERE stream_id <= ?
+                    GROUP BY room_id, appservice_id, network_id
+                ) grouped USING (room_id, stream_id)
+            """)
+
+            txn.execute(
+                sql,
+                (stream_id,)
+            )
+
+            results = {}
+            # A room is visible if its visible on any list.
+            for room_id, visibility in txn:
+                results[room_id] = bool(visibility) or results.get(room_id, False)
+
+            return results
+
+    def get_public_room_changes(self, prev_stream_id, new_stream_id,
+                                network_tuple):
+        def get_public_room_changes_txn(txn):
+            then_rooms = self.get_public_room_ids_at_stream_id_txn(
+                txn, prev_stream_id, network_tuple
+            )
+
+            now_rooms_dict = self.get_published_at_stream_id_txn(
+                txn, new_stream_id, network_tuple
+            )
+
+            now_rooms_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if vis
+            )
+            now_rooms_not_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if not vis
+            )
+
+            newly_visible = now_rooms_visible - then_rooms
+            newly_unpublished = now_rooms_not_visible & then_rooms
+
+            return newly_visible, newly_unpublished
+
+        return self.runInteraction(
+            "get_public_room_changes", get_public_room_changes_txn
+        )
+
+    @cached(max_entries=10000)
+    def is_room_blocked(self, room_id):
+        return self._simple_select_one_onecol(
+            table="blocked_rooms",
+            keyvalues={
+                "room_id": room_id,
+            },
+            retcol="1",
+            allow_none=True,
+            desc="is_room_blocked",
+        )
+
+
+class RoomStore(RoomWorkerStore, SearchStore):
 
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
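
The aggregation branch above resolves a room's visibility across lists by
OR-ing the per-list flags. A tiny standalone demo of that fold (rows
invented for illustration):

    # (room_id, visibility) rows as the full-list query would return
    # them, one row per (room, list); visibility is 0 or 1.
    rows = [("!a:hs", 0), ("!a:hs", 1), ("!b:hs", 0)]

    results = {}
    for room_id, visibility in rows:
        # A room is visible if it is visible on any list.
        results[room_id] = bool(visibility) or results.get(room_id, False)

    assert results == {"!a:hs": True, "!b:hs": False}
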
@@ -227,16 +358,6 @@ class RoomStore(SQLBaseStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    def get_public_room_ids(self):
-        return self._simple_select_onecol(
-            table="rooms",
-            keyvalues={
-                "is_public": True,
-            },
-            retcol="room_id",
-            desc="get_public_room_ids",
-        )
-
     def get_room_count(self):
         """Retrieve a list of all rooms
         """
@@ -263,8 +384,8 @@ class RoomStore(SQLBaseStore):
                 },
             )
 
-            self._store_event_search_txn(
-                txn, event, "content.topic", event.content["topic"]
+            self.store_event_search_txn(
+                txn, event, "content.topic", event.content["topic"],
             )
 
     def _store_room_name_txn(self, txn, event):
@@ -279,14 +400,14 @@ class RoomStore(SQLBaseStore):
                 }
             )
 
-            self._store_event_search_txn(
-                txn, event, "content.name", event.content["name"]
+            self.store_event_search_txn(
+                txn, event, "content.name", event.content["name"],
             )
 
     def _store_room_message_txn(self, txn, event):
         if hasattr(event, "content") and "body" in event.content:
-            self._store_event_search_txn(
-                txn, event, "content.body", event.content["body"]
+            self.store_event_search_txn(
+                txn, event, "content.body", event.content["body"],
             )
 
     def _store_history_visibility_txn(self, txn, event):
@@ -308,31 +429,6 @@ class RoomStore(SQLBaseStore):
                 event.content[key]
             ))
 
-    def _store_event_search_txn(self, txn, event, key, value):
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
-            txn.execute(
-                sql,
-                (
-                    event.event_id, event.room_id, key, value,
-                    event.internal_metadata.stream_ordering,
-                    event.origin_server_ts,
-                )
-            )
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            txn.execute(sql, (event.event_id, event.room_id, key, value,))
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
-
     def add_event_report(self, room_id, event_id, user_id, reason, content,
                          received_ts):
         next_id = self._event_reports_id_gen.get_next()
@@ -353,113 +449,6 @@ class RoomStore(SQLBaseStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    @cached(num_args=2, max_entries=100)
-    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
-        """Get pulbic rooms for a particular list, or across all lists.
-
-        Args:
-            stream_id (int)
-            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
-                means the main list, None means all lists.
-        """
-        return self.runInteraction(
-            "get_public_room_ids_at_stream_id",
-            self.get_public_room_ids_at_stream_id_txn,
-            stream_id, network_tuple=network_tuple
-        )
-
-    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
-                                             network_tuple):
-        return {
-            rm
-            for rm, vis in self.get_published_at_stream_id_txn(
-                txn, stream_id, network_tuple=network_tuple
-            ).items()
-            if vis
-        }
-
-    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
-        if network_tuple:
-            # We want to get from a particular list. No aggregation required.
-
-            sql = ("""
-                SELECT room_id, visibility FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT room_id, max(stream_id) AS stream_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ? %s
-                    GROUP BY room_id
-                ) grouped USING (room_id, stream_id)
-            """)
-
-            if network_tuple.appservice_id is not None:
-                txn.execute(
-                    sql % ("AND appservice_id = ? AND network_id = ?",),
-                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
-                )
-            else:
-                txn.execute(
-                    sql % ("AND appservice_id IS NULL",),
-                    (stream_id,)
-                )
-            return dict(txn)
-        else:
-            # We want to get from all lists, so we need to aggregate the results
-
-            logger.info("Executing full list")
-
-            sql = ("""
-                SELECT room_id, visibility
-                FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT
-                        room_id, max(stream_id) AS stream_id, appservice_id,
-                        network_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ?
-                    GROUP BY room_id, appservice_id, network_id
-                ) grouped USING (room_id, stream_id)
-            """)
-
-            txn.execute(
-                sql,
-                (stream_id,)
-            )
-
-            results = {}
-            # A room is visible if it's visible on any list.
-            for room_id, visibility in txn:
-                results[room_id] = bool(visibility) or results.get(room_id, False)
-
-            return results
-
-    def get_public_room_changes(self, prev_stream_id, new_stream_id,
-                                network_tuple):
-        def get_public_room_changes_txn(txn):
-            then_rooms = self.get_public_room_ids_at_stream_id_txn(
-                txn, prev_stream_id, network_tuple
-            )
-
-            now_rooms_dict = self.get_published_at_stream_id_txn(
-                txn, new_stream_id, network_tuple
-            )
-
-            now_rooms_visible = set(
-                rm for rm, vis in now_rooms_dict.items() if vis
-            )
-            now_rooms_not_visible = set(
-                rm for rm, vis in now_rooms_dict.items() if not vis
-            )
-
-            newly_visible = now_rooms_visible - then_rooms
-            newly_unpublished = now_rooms_not_visible & then_rooms
-
-            return newly_visible, newly_unpublished
-
-        return self.runInteraction(
-            "get_public_room_changes", get_public_room_changes_txn
-        )
-
     def get_all_new_public_rooms(self, prev_id, current_id, limit):
         def get_all_new_public_rooms(txn):
             sql = ("""
@@ -509,18 +498,6 @@ class RoomStore(SQLBaseStore):
         else:
             defer.returnValue(None)
 
-    @cached(max_entries=10000)
-    def is_room_blocked(self, room_id):
-        return self._simple_select_one_onecol(
-            table="blocked_rooms",
-            keyvalues={
-                "room_id": room_id,
-            },
-            retcol="1",
-            allow_none=True,
-            desc="is_room_blocked",
-        )
-
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
         yield self._simple_insert(
@@ -531,75 +508,121 @@ class RoomStore(SQLBaseStore):
             },
             desc="block_room",
         )
-        self.is_room_blocked.invalidate((room_id,))
+        yield self.runInteraction(
+            "block_room_invalidation",
+            self._invalidate_cache_and_stream,
+            self.is_room_blocked, (room_id,),
+        )
+
+    def get_media_mxcs_in_room(self, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            A Deferred that resolves to a pair of lists of MXC URI strings:
+            the local media first, then the remote media.
+        """
+        def _get_media_mxcs_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            local_media_mxcs = []
+            remote_media_mxcs = []
+
+            # Convert the IDs to MXC URIs
+            for media_id in local_mxcs:
+                local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
+            for hostname, media_id in remote_mxcs:
+                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+            return local_media_mxcs, remote_media_mxcs
+        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
 
     def quarantine_media_ids_in_room(self, room_id, quarantined_by):
         """For a room loops through all events with media and quarantines
         the associated media
         """
-        def _get_media_ids_in_room(txn):
-            mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+        def _quarantine_media_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            total_media_quarantined = 0
 
-            next_token = self.get_current_events_token() + 1
+            # Now update all the tables to set the quarantined_by flag
 
-            total_media_quarantined = 0
+            txn.executemany("""
+                UPDATE local_media_repository
+                SET quarantined_by = ?
+                WHERE media_id = ?
+            """, ((quarantined_by, media_id) for media_id in local_mxcs))
 
-            while next_token:
-                sql = """
-                    SELECT stream_ordering, content FROM events
-                    WHERE room_id = ?
-                        AND stream_ordering < ?
-                        AND contains_url = ? AND outlier = ?
-                    ORDER BY stream_ordering DESC
-                    LIMIT ?
+            txn.executemany(
                 """
-                txn.execute(sql, (room_id, next_token, True, False, 100))
-
-                next_token = None
-                local_media_mxcs = []
-                remote_media_mxcs = []
-                for stream_ordering, content_json in txn:
-                    next_token = stream_ordering
-                    content = json.loads(content_json)
-
-                    content_url = content.get("url")
-                    thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
-                    for url in (content_url, thumbnail_url):
-                        if not url:
-                            continue
-                        matches = mxc_re.match(url)
-                        if matches:
-                            hostname = matches.group(1)
-                            media_id = matches.group(2)
-                            if hostname == self.hostname:
-                                local_media_mxcs.append(media_id)
-                            else:
-                                remote_media_mxcs.append((hostname, media_id))
-
-                # Now update all the tables to set the quarantined_by flag
-
-                txn.executemany("""
-                    UPDATE local_media_repository
+                    UPDATE remote_media_cache
                     SET quarantined_by = ?
-                    WHERE media_id = ?
-                """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
-                txn.executemany(
-                    """
-                        UPDATE remote_media_cache
-                        SET quarantined_by = ?
-                        WHERE media_origin AND media_id = ?
-                    """,
-                    (
-                        (quarantined_by, origin, media_id)
-                        for origin, media_id in remote_media_mxcs
-                    )
+                    WHERE media_origin = ? AND media_id = ?
+                """,
+                (
+                    (quarantined_by, origin, media_id)
+                    for origin, media_id in remote_mxcs
                 )
+            )
 
-                total_media_quarantined += len(local_media_mxcs)
-                total_media_quarantined += len(remote_media_mxcs)
+            total_media_quarantined += len(local_mxcs)
+            total_media_quarantined += len(remote_mxcs)
 
             return total_media_quarantined
 
-        return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+        return self.runInteraction(
+            "quarantine_media_in_room",
+            _quarantine_media_in_room_txn,
+        )
+
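Both UPDATE statements above hand txn.executemany a generator of parameter tuples rather than a materialised list, so the rows stream straight into the driver. A standalone sqlite3 sketch of the same pattern (table and values are invented):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE local_media_repository (media_id TEXT, quarantined_by TEXT)")
conn.executemany(
    "INSERT INTO local_media_repository VALUES (?, NULL)",
    [("abc",), ("def",)],
)

local_mxcs = ["abc", "def"]
quarantined_by = "@admin:example.com"

# Parameters are supplied lazily, exactly as in the transaction above.
conn.executemany(
    "UPDATE local_media_repository SET quarantined_by = ? WHERE media_id = ?",
    ((quarantined_by, media_id) for media_id in local_mxcs),
)

count = conn.execute(
    "SELECT COUNT(*) FROM local_media_repository WHERE quarantined_by IS NOT NULL"
).fetchone()[0]
print(count)  # 2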
+    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            txn (cursor)
+            room_id (str)
+
+        Returns:
+            A pair of lists: the local media IDs, and the remote media as
+            (hostname, media_id) tuples.
+        """
+        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+        next_token = self.get_current_events_token() + 1
+        local_media_mxcs = []
+        remote_media_mxcs = []
+
+        while next_token:
+            sql = """
+                SELECT stream_ordering, json FROM events
+                JOIN event_json USING (room_id, event_id)
+                WHERE room_id = ?
+                    AND stream_ordering < ?
+                    AND contains_url = ? AND outlier = ?
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """
+            txn.execute(sql, (room_id, next_token, True, False, 100))
+
+            next_token = None
+            for stream_ordering, content_json in txn:
+                next_token = stream_ordering
+                event_json = json.loads(content_json)
+                content = event_json["content"]
+                content_url = content.get("url")
+                thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+                for url in (content_url, thumbnail_url):
+                    if not url:
+                        continue
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hs.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+        return local_media_mxcs, remote_media_mxcs
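The mxc_re pattern above is what splits an MXC URI into its hostname and media ID, and the hostname comparison is what separates local media from remote. A small sketch using the same regex (hostnames are illustrative):

import re

mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
my_hostname = "example.com"  # stands in for self.hs.hostname

local, remote = [], []
for url in ("mxc://example.com/abc123", "mxc://other.org/def456"):
    matches = mxc_re.match(url)
    if matches:
        hostname, media_id = matches.group(1), matches.group(2)
        if hostname == my_hostname:
            local.append(media_id)
        else:
            remote.append((hostname, media_id))

print(local)   # ['abc123']
print(remote)  # [('other.org', 'def456')]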
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 457ca288d0..01697ab2c9 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,22 +14,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
+import logging
 from collections import namedtuple
 
-from ._base import SQLBaseStore
+from six import iteritems, itervalues
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.storage.events import EventsWorkerStore
+from synapse.types import get_domain_from_id
 from synapse.util.async import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.stringutils import to_ascii
 
-from synapse.api.constants import Membership, EventTypes
-from synapse.types import get_domain_from_id
-
-import logging
-import ujson as json
-
 logger = logging.getLogger(__name__)
 
 
@@ -37,6 +39,11 @@ RoomsForUser = namedtuple(
     ("room_id", "sender", "membership", "event_id", "stream_ordering")
 )
 
+GetRoomsForUserWithStreamOrdering = namedtuple(
+    "_GetRoomsForUserWithStreamOrdering",
+    ("room_id", "stream_ordering",)
+)
+
 
 # We store this using a namedtuple so that we save about 3x space over using a
 # dict.
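That space saving is straightforward to sanity-check: a namedtuple instance is a fixed-slot tuple, while an equivalent dict carries a whole hash table. A rough measurement sketch (exact byte counts vary by Python version):

import sys
from collections import namedtuple

ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))

as_tuple = ProfileInfo(avatar_url=None, display_name="Alice")
as_dict = {"avatar_url": None, "display_name": "Alice"}

# The namedtuple instance is several times smaller than the dict, which is
# what matters when caching hundreds of thousands of these entries.
print(sys.getsizeof(as_tuple), sys.getsizeof(as_dict))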
@@ -48,97 +55,7 @@ ProfileInfo = namedtuple(
 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 
 
-class RoomMemberStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(RoomMemberStore, self).__init__(hs)
-        self.register_background_update_handler(
-            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
-        )
-
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="room_memberships",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
-                }
-                for event in events
-            ]
-        )
-
-        for event in events:
-            txn.call_after(
-                self._membership_stream_cache.entity_has_changed,
-                event.state_key, event.internal_metadata.stream_ordering
-            )
-            txn.call_after(
-                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
-            )
-
-            # We update the local_invites table only if the event is "current",
-            # i.e., it's something that has just happened.
-            # The only current event that can also be an outlier is if it's an
-            # invite that has come in across federation.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_invite_from_remote()
-            )
-            is_mine = self.hs.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self._simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        }
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
-
-                    txn.execute(sql, (
-                        event.internal_metadata.stream_ordering,
-                        event.event_id,
-                        event.room_id,
-                        event.state_key,
-                    ))
-
-    @defer.inlineCallbacks
-    def locally_reject_invite(self, user_id, room_id):
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
-        )
-
-        def f(txn, stream_ordering):
-            txn.execute(sql, (
-                stream_ordering,
-                True,
-                room_id,
-                user_id,
-            ))
-
-        with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
+class RoomMemberWorkerStore(EventsWorkerStore):
     @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
     def get_hosts_in_room(self, room_id, cache_context):
         """Returns the set of all hosts currently in the room
@@ -270,12 +187,32 @@ class RoomMemberStore(SQLBaseStore):
         return results
 
     @cachedInlineCallbacks(max_entries=500000, iterable=True)
-    def get_rooms_for_user(self, user_id):
+    def get_rooms_for_user_with_stream_ordering(self, user_id):
         """Returns a set of room_ids the user is currently joined to
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+            the rooms the user is in currently, along with the stream ordering
+            of the most recent join for that user and room.
         """
         rooms = yield self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
+        defer.returnValue(frozenset(
+            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+            for r in rooms
+        ))
+
+    @defer.inlineCallbacks
+    def get_rooms_for_user(self, user_id, on_invalidate=None):
+        """Returns a set of room_ids the user is currently joined to
+        """
+        rooms = yield self.get_rooms_for_user_with_stream_ordering(
+            user_id, on_invalidate=on_invalidate,
+        )
         defer.returnValue(frozenset(r.room_id for r in rooms))
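get_rooms_for_user() is now just a projection over the cached stream-ordering variant, so the two methods share one cache entry per user. A self-contained sketch of the relationship (room IDs and orderings are invented):

from collections import namedtuple

GetRoomsForUserWithStreamOrdering = namedtuple(
    "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
)

rooms = frozenset([
    GetRoomsForUserWithStreamOrdering("!a:example.com", 42),
    GetRoomsForUserWithStreamOrdering("!b:example.com", 57),
])

# The projection performed by get_rooms_for_user():
room_ids = frozenset(r.room_id for r in rooms)
print(sorted(room_ids))  # ['!a:example.com', '!b:example.com']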
 
     @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
@@ -295,89 +232,7 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(user_who_share_room)
 
-    def forget(self, user_id, room_id):
-        """Indicate that user_id wishes to discard history for room_id."""
-        def f(txn):
-            sql = (
-                "UPDATE"
-                "  room_memberships"
-                " SET"
-                "  forgotten = 1"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-            )
-            txn.execute(sql, (user_id, room_id))
-
-            txn.call_after(self.was_forgotten_at.invalidate_all)
-            txn.call_after(self.did_forget.invalidate, (user_id, room_id))
-            self._invalidate_cache_and_stream(
-                txn, self.who_forgot_in_room, (room_id,)
-            )
-        return self.runInteraction("forget_membership", f)
-
-    @cachedInlineCallbacks(num_args=2)
-    def did_forget(self, user_id, room_id):
-        """Returns whether user_id has elected to discard history for room_id.
-
-        Returns False if they have since re-joined."""
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  COUNT(*)"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  forgotten = 0"
-            )
-            txn.execute(sql, (user_id, room_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-        count = yield self.runInteraction("did_forget_membership", f)
-        defer.returnValue(count == 0)
-
-    @cachedInlineCallbacks(num_args=3)
-    def was_forgotten_at(self, user_id, room_id, event_id):
-        """Returns whether user_id has elected to discard history for room_id at
-        event_id.
-
-        event_id must be a membership event."""
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  forgotten"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  event_id = ?"
-            )
-            txn.execute(sql, (user_id, room_id, event_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-        forgot = yield self.runInteraction("did_forget_membership_at", f)
-        defer.returnValue(forgot == 1)
-
-    @cached()
-    def who_forgot_in_room(self, room_id):
-        return self._simple_select_list(
-            table="room_memberships",
-            retcols=("user_id", "event_id"),
-            keyvalues={
-                "room_id": room_id,
-                "forgotten": 1,
-            },
-            desc="who_forgot"
-        )
-
+    @defer.inlineCallbacks
     def get_joined_users_from_context(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -387,11 +242,13 @@ class RoomMemberStore(SQLBaseStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._get_joined_users_from_context(
-            event.room_id, state_group, context.current_state_ids,
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._get_joined_users_from_context(
+            event.room_id, state_group, current_state_ids,
             event=event,
             context=context,
         )
+        defer.returnValue(result)
 
     def get_joined_users_from_state(self, room_id, state_entry):
         state_group = state_entry.state_group
@@ -419,7 +276,7 @@ class RoomMemberStore(SQLBaseStore):
         users_in_room = {}
         member_event_ids = [
             e_id
-            for key, e_id in current_state_ids.iteritems()
+            for key, e_id in iteritems(current_state_ids)
             if key[0] == EventTypes.Member
         ]
 
@@ -436,7 +293,7 @@ class RoomMemberStore(SQLBaseStore):
                     users_in_room = dict(prev_res)
                     member_event_ids = [
                         e_id
-                        for key, e_id in context.delta_ids.iteritems()
+                        for key, e_id in iteritems(context.delta_ids)
                         if key[0] == EventTypes.Member
                     ]
                     for etype, state_key in context.delta_ids:
@@ -533,6 +390,46 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(True)
 
+    @cachedInlineCallbacks()
+    def was_host_joined(self, room_id, host):
+        """Check whether the server is or ever was in the room.
+
+        Args:
+            room_id (str)
+            host (str)
+
+        Returns:
+            Deferred: Resolves to True if the host is/was in the room, otherwise
+            False.
+        """
+        if '%' in host or '_' in host:
+            raise Exception("Invalid host name")
+
+        sql = """
+            SELECT user_id FROM room_memberships
+            WHERE room_id = ?
+                AND user_id LIKE ?
+                AND membership = 'join'
+            LIMIT 1
+        """
+
+        # We do need to be careful to ensure that host doesn't have any wildcards
+        # in it, but we checked above for known ones and we'll check below that
+        # the returned user actually has the correct domain.
+        like_clause = "%:" + host
+
+        rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
+
+        if not rows:
+            defer.returnValue(False)
+
+        user_id = rows[0][0]
+        if get_domain_from_id(user_id) != host:
+            # This can only happen if the host name has something funky in it
+            raise Exception("Invalid host name")
+
+        defer.returnValue(True)
+
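The host check above leans on a SQL LIKE pattern ("%:host") plus a belt-and-braces domain comparison on the returned row, so LIKE metacharacters can never widen the match. A self-contained sketch of that defence (get_domain_from_id is simplified to a split, which matches its behaviour for well-formed user IDs):

def get_domain_from_id(user_id):
    # Simplified: "@alice:example.com" -> "example.com"
    return user_id.split(":", 1)[1]

def check_host_matches(host, returned_user_id):
    if '%' in host or '_' in host:
        raise Exception("Invalid host name")
    like_clause = "%:" + host  # what the SQL matches user_id against
    if not returned_user_id.endswith(like_clause[1:]):
        return False  # the LIKE would not have matched this row
    if get_domain_from_id(returned_user_id) != host:
        # A row matched the LIKE but the real domain differs: reject.
        raise Exception("Invalid host name")
    return True

print(check_host_matches("example.com", "@bob:example.com"))  # True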
     def get_joined_hosts(self, room_id, state_entry):
         state_group = state_entry.state_group
         if not state_group:
@@ -560,6 +457,144 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(joined_hosts)
 
+    @cached(max_entries=10000)
+    def _get_joined_hosts_cache(self, room_id):
+        return _JoinedHostsCache(self, room_id)
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+        )
+
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store a room member in the database.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ]
+        )
+
+        for event in events:
+            txn.call_after(
+                self._membership_stream_cache.entity_has_changed,
+                event.state_key, event.internal_metadata.stream_ordering
+            )
+            txn.call_after(
+                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+            )
+
+            # We update the local_invites table only if the event is "current",
+            # i.e., it's something that has just happened.
+            # The only current event that can also be an outlier is if it's an
+            # invite that has come in across federation.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_invite_from_remote()
+            )
+            is_mine = self.hs.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self._simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        }
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
+
+                    txn.execute(sql, (
+                        event.internal_metadata.stream_ordering,
+                        event.event_id,
+                        event.room_id,
+                        event.state_key,
+                    ))
+
+    @defer.inlineCallbacks
+    def locally_reject_invite(self, user_id, room_id):
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+            " AND replaced_by is NULL"
+        )
+
+        def f(txn, stream_ordering):
+            txn.execute(sql, (
+                stream_ordering,
+                True,
+                room_id,
+                user_id,
+            ))
+
+        with self._stream_id_gen.get_next() as stream_ordering:
+            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
+    def forget(self, user_id, room_id):
+        """Indicate that user_id wishes to discard history for room_id."""
+        def f(txn):
+            sql = (
+                "UPDATE"
+                "  room_memberships"
+                " SET"
+                "  forgotten = 1"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+            )
+            txn.execute(sql, (user_id, room_id))
+
+            txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+        return self.runInteraction("forget_membership", f)
+
+    @cachedInlineCallbacks(num_args=2)
+    def did_forget(self, user_id, room_id):
+        """Returns whether user_id has elected to discard history for room_id.
+
+        Returns False if they have since re-joined."""
+        def f(txn):
+            sql = (
+                "SELECT"
+                "  COUNT(*)"
+                " FROM"
+                "  room_memberships"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+                " AND"
+                "  forgotten = 0"
+            )
+            txn.execute(sql, (user_id, room_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+        count = yield self.runInteraction("did_forget_membership", f)
+        defer.returnValue(count == 0)
+
     @defer.inlineCallbacks
     def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(
@@ -573,8 +608,9 @@ class RoomMemberStore(SQLBaseStore):
 
         def add_membership_profile_txn(txn):
             sql = ("""
-                SELECT stream_ordering, event_id, events.room_id, content
+                SELECT stream_ordering, event_id, events.room_id, event_json.json
                 FROM events
+                INNER JOIN event_json USING (event_id)
                 INNER JOIN room_memberships USING (event_id)
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                 AND type = 'm.room.member'
@@ -595,8 +631,9 @@ class RoomMemberStore(SQLBaseStore):
                 event_id = row["event_id"]
                 room_id = row["room_id"]
                 try:
-                    content = json.loads(row["content"])
-                except:
+                    event_json = json.loads(row["json"])
+                    content = event_json['content']
+                except Exception:
                     continue
 
                 display_name = content.get("displayname", None)
@@ -635,10 +672,6 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(result)
 
-    @cached(max_entries=10000, iterable=True)
-    def _get_joined_hosts_cache(self, room_id):
-        return _JoinedHostsCache(self, room_id)
-
 
 class _JoinedHostsCache(object):
     """Cache for joined hosts in a room that is optimised to handle updates
@@ -671,7 +704,7 @@ class _JoinedHostsCache(object):
             if state_entry.state_group == self.state_group:
                 pass
             elif state_entry.prev_group == self.state_group:
-                for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+                for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
                     if typ != EventTypes.Member:
                         continue
 
@@ -701,7 +734,7 @@ class _JoinedHostsCache(object):
                 self.state_group = state_entry.state_group
             else:
                 self.state_group = object()
-            self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+            self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
         defer.returnValue(frozenset(self.hosts_to_joined_users))
 
     def __len__(self):
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 8755bb2e49..4d725b92fe 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 
+import simplejson as json
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index 4269ac69ad..4b2ffd35fd 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -14,10 +14,10 @@
 
 import logging
 
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+import simplejson
 
-import ujson
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.prepare_database import get_statements
 
 logger = logging.getLogger(__name__)
 
@@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index 71b12a2731..414f9f5aa0 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -14,9 +14,9 @@
 
 import logging
 
-from synapse.storage.prepare_database import get_statements
+import simplejson
 
-import ujson
+from synapse.storage.prepare_database import get_statements
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index 5b7d8d1ab5..ef7ec34346 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from synapse.config.appservice import load_appservices
 
+from six.moves import range
+
+from synapse.config.appservice import load_appservices
 
 logger = logging.getLogger(__name__)
 
@@ -22,7 +24,7 @@ def run_create(cur, database_engine, *args, **kwargs):
     # NULL indicates user was not registered by an appservice.
     try:
         cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
-    except:
+    except Exception:
         # Maybe we already added the column? Hope so...
         pass
 
@@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
 
     for as_id, user_ids in owned.items():
         n = 100
-        user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n))
+        user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
         for chunk in user_chunks:
             cur.execute(
                 database_engine.convert_param_style(
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
index 470ae0c005..7d8ca5f93f 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
+import simplejson
+
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.prepare_database import get_statements
 
-import logging
-import ujson
-
 logger = logging.getLogger(__name__)
 
 
@@ -49,7 +50,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "rows_inserted": 0,
             "have_added_indexes": False,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
index 83066cccc9..bff1256a7b 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.prepare_database import get_statements
-
 import logging
-import ujson
+
+import simplejson
+
+from synapse.storage.prepare_database import get_statements
 
 logger = logging.getLogger(__name__)
 
@@ -44,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py
index 55ae43f395..9754d3ccfb 100644
--- a/synapse/storage/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/schema/delta/33/remote_media_ts.py
@@ -14,7 +14,6 @@
 
 import time
 
-
 ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT"
 
 
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py
index 3b63a1562d..cf09e43e2b 100644
--- a/synapse/storage/schema/delta/34/cache_stream.py
+++ b/synapse/storage/schema/delta/34/cache_stream.py
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine
-
 import logging
 
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/schema/delta/34/received_txn_purge.py
index 033144341c..67d505e68b 100644
--- a/synapse/storage/schema/delta/34/received_txn_purge.py
+++ b/synapse/storage/schema/delta/34/received_txn_purge.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.engines import PostgresEngine
-
 import logging
 
+from synapse.storage.engines import PostgresEngine
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py
index 81948e3431..0ffab10b6f 100644
--- a/synapse/storage/schema/delta/34/sent_txn_purge.py
+++ b/synapse/storage/schema/delta/34/sent_txn_purge.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.engines import PostgresEngine
-
 import logging
 
+from synapse.storage.engines import PostgresEngine
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py
index 20ad8bd5a6..a377884169 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/schema/delta/37/remove_auth_idx.py
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine
-
 import logging
 
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
 logger = logging.getLogger(__name__)
 
 DROP_INDICES = """
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
index f090a7b75a..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
@@ -13,5 +13,7 @@
  * limitations under the License.
  */
 
- INSERT into background_updates (update_name, progress_json)
-     VALUES ('event_search_postgres_gist', '{}');
+-- We no longer do this given we back it out again in schema 47
+
+-- INSERT into background_updates (update_name, progress_json)
+--     VALUES ('event_search_postgres_gist', '{}');
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py
index ea6a18196d..506f326f4d 100644
--- a/synapse/storage/schema/delta/42/user_dir.py
+++ b/synapse/storage/schema/delta/42/user_dir.py
@@ -14,8 +14,8 @@
 
 import logging
 
-from synapse.storage.prepare_database import get_statements
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.prepare_database import get_statements
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql
index 4501d90cbb..ee7062abe4 100644
--- a/synapse/storage/schema/delta/43/user_share.sql
+++ b/synapse/storage/schema/delta/43/user_share.sql
@@ -29,5 +29,5 @@ CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
 CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
 
 
--- Make sure that we popualte the table initially
+-- Make sure that we populate the table initially
 UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql
new file mode 100644
index 0000000000..b12f9b2ebf
--- /dev/null
+++ b/synapse/storage/schema/delta/44/expire_url_cache.sql
@@ -0,0 +1,41 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
+-- removed and replaced with 46/local_media_repository_url_idx.sql.
+--
+-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
+
+-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
+-- indices on expressions until 3.9.
+CREATE TABLE local_media_repository_url_cache_new(
+    url TEXT,
+    response_code INTEGER,
+    etag TEXT,
+    expires_ts BIGINT,
+    og TEXT,
+    media_id TEXT,
+    download_ts BIGINT
+);
+
+INSERT INTO local_media_repository_url_cache_new
+    SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
+
+DROP TABLE local_media_repository_url_cache;
+ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
+
+CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
+CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
+CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql
new file mode 100644
index 0000000000..b2333848a0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/group_server.sql
@@ -0,0 +1,167 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups (
+    group_id TEXT NOT NULL,
+    name TEXT,  -- the display name of the group
+    avatar_url TEXT,
+    short_description TEXT,
+    long_description TEXT
+);
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
+
+
+-- list of users the group server thinks are joined
+CREATE TABLE group_users (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    is_admin BOOLEAN NOT NULL,
+    is_public BOOLEAN NOT NULL  -- whether the user's membership can be seen by everyone
+);
+
+
+CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
+CREATE INDEX groups_users_u_idx ON group_users(user_id);
+
+-- list of users the group server thinks are invited
+CREATE TABLE group_invites (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL
+);
+
+CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
+CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
+
+
+CREATE TABLE group_rooms (
+    group_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL  -- whether the room can be seen by everyone
+);
+
+CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
+CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
+
+
+-- Rooms to include in the summary
+CREATE TABLE group_summary_rooms (
+    group_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    room_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the room should be shown to everyone
+    UNIQUE (group_id, category_id, room_id, room_order),
+    CHECK (room_order > 0)
+);
+
+CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
+
+
+-- Categories to include in the summary
+CREATE TABLE group_summary_room_categories (
+    group_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    cat_order BIGINT NOT NULL,
+    UNIQUE (group_id, category_id, cat_order),
+    CHECK (cat_order > 0)
+);
+
+-- The categories in the group
+CREATE TABLE group_room_categories (
+    group_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the category should be shown to everyone
+    UNIQUE (group_id, category_id)
+);
+
+-- The users to include in the group summary
+CREATE TABLE group_summary_users (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    user_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL  -- whether the user should be shown to everyone
+);
+
+CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
+
+-- The roles to include in the group summary
+CREATE TABLE group_summary_roles (
+    group_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    role_order BIGINT NOT NULL,
+    UNIQUE (group_id, role_id, role_order),
+    CHECK (role_order > 0)
+);
+
+
+-- The roles in a group
+CREATE TABLE group_roles (
+    group_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL,  -- whether the role should be shown to everyone
+    UNIQUE (group_id, role_id)
+);
+
+
+-- List of attestations we've given out and need to renew
+CREATE TABLE group_attestations_renewals (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    valid_until_ms BIGINT NOT NULL
+);
+
+CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
+CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
+CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
+
+
+-- List of attestations we've received from remotes and are interested in.
+CREATE TABLE group_attestations_remote (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    valid_until_ms BIGINT NOT NULL,
+    attestation_json TEXT NOT NULL
+);
+
+CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
+CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
+CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
+
+
+-- The group membership for the HS's users
+CREATE TABLE local_group_membership (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    is_admin BOOLEAN NOT NULL,
+    membership TEXT NOT NULL,
+    is_publicised BOOLEAN NOT NULL,  -- if the user is publicising their membership
+    content TEXT NOT NULL
+);
+
+CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
+CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
+
+
+CREATE TABLE local_group_updates (
+    stream_id BIGINT NOT NULL,
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    content TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql
new file mode 100644
index 0000000000..e5ddc84df0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/profile_cache.sql
@@ -0,0 +1,28 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A subset of remote users whose profiles we have cached.
+-- Whether a user is in this table or not is defined by the storage function
+-- `is_subscribed_remote_profile_for_user`
+CREATE TABLE remote_profile_cache (
+    user_id TEXT NOT NULL,
+    displayname TEXT,
+    avatar_url TEXT,
+    last_check BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
+CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
index bb225dafbf..68c48a89a9 100644
--- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
+++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2017 New Vector Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,5 +13,5 @@
  * limitations under the License.
  */
 
-INSERT INTO background_updates (update_name, progress_json) VALUES
-  ('refresh_tokens_device_index', '{}');
+/* we no longer use (or create) the refresh_tokens table */
+DROP TABLE IF EXISTS refresh_tokens;
diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
new file mode 100644
index 0000000000..bb307889c1
--- /dev/null
+++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- drop the unique constraint on deleted_pushers so that we can just insert
+-- into it rather than upserting.
+
+CREATE TABLE deleted_pushers2 (
+    stream_id BIGINT NOT NULL,
+    app_id TEXT NOT NULL,
+    pushkey TEXT NOT NULL,
+    user_id TEXT NOT NULL
+);
+
+INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
+    SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
+
+DROP TABLE deleted_pushers;
+ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
+
+-- create the index after doing the inserts because that's more efficient.
+-- it also means we can give it the same name as the old one without renaming.
+CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
+
diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql
new file mode 100644
index 0000000000..097679bc9a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/group_server.sql
@@ -0,0 +1,32 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups_new (
+    group_id TEXT NOT NULL,
+    name TEXT,  -- the display name of the group
+    avatar_url TEXT,
+    short_description TEXT,
+    long_description TEXT,
+    is_public BOOL NOT NULL -- whether non-members can access group APIs
+);
+
+-- NB: awful hack to get the default to be true on postgres and 1 on sqlite
+INSERT INTO groups_new
+    SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
+
+DROP TABLE groups;
+ALTER TABLE groups_new RENAME TO groups;
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
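
A quick demonstration of why "(1=1)" works as the cross-engine TRUE literal mentioned in the comment above (this sketch exercises sqlite3 directly; the postgres half of the claim is stated in a comment):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    # sqlite has no boolean type, so comparisons yield integers...
    print(conn.execute("SELECT (1=1), (1=0)").fetchone())  # (1, 0)
    # ...whereas postgres evaluates the same expressions to true and false,
    # which is what makes "(1=1)" usable as a portable TRUE above.
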
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
new file mode 100644
index 0000000000..bbfc7f5d1a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will recreate the
+-- local_media_repository_url_idx index.
+--
+-- We do this as a bg update not because it is a particularly onerous
+-- operation, but because we'd like it to be a partial index if possible, and
+-- the background_index_update code will understand whether we are on
+-- postgres or sqlite and behave accordingly.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+    ('local_media_repository_url_idx', '{}');
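
For context, a row in background_updates does nothing by itself; a handler registered under the same name picks it up later and runs the work in batches. A heavily simplified, self-contained sketch of that dispatch pattern (assumptions throughout, not Synapse's BackgroundUpdateStore):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE background_updates (update_name TEXT, progress_json TEXT)")
    conn.execute("CREATE TABLE local_media_repository (media_id TEXT, url_cache TEXT)")
    conn.execute(
        "INSERT INTO background_updates (update_name, progress_json)"
        " VALUES ('local_media_repository_url_idx', '{}')"
    )

    handlers = {}

    def build_url_idx():
        # the real handler builds a partial index on postgres and a plain one
        # on sqlite; the column used here is an assumption for illustration
        conn.execute(
            "CREATE INDEX local_media_repository_url_idx"
            " ON local_media_repository (media_id)"
        )

    handlers["local_media_repository_url_idx"] = build_url_idx

    for (name,) in conn.execute("SELECT update_name FROM background_updates").fetchall():
        handlers[name]()  # real code batches the work and records progress_json
        conn.execute("DELETE FROM background_updates WHERE update_name = ?", (name,))
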
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
new file mode 100644
index 0000000000..cb0d5a2576
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- change the user_directory table to also cover global local user profiles
+-- rather than just profiles within specific rooms.
+
+CREATE TABLE user_directory2 (
+    user_id TEXT NOT NULL,
+    room_id TEXT,
+    display_name TEXT,
+    avatar_url TEXT
+);
+
+INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
+    SELECT user_id, room_id, display_name, avatar_url from user_directory;
+
+DROP TABLE user_directory;
+ALTER TABLE user_directory2 RENAME TO user_directory;
+
+-- create indexes after doing the inserts because that's more efficient.
+-- it also means we can give them the same names as the old ones without renaming.
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql
new file mode 100644
index 0000000000..d9505f8da1
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_typos.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this is just embarrassing :|
+ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
+
+-- this is only 300K rows on matrix.org and takes ~3s to generate the index,
+-- so is hopefully not going to block anyone else for that long...
+CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
+CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
+DROP INDEX users_in_pubic_room_room_idx;
+DROP INDEX users_in_pubic_room_user_idx;
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/47/last_access_media.sql
index 290bd6da86..f505fb22b5 100644
--- a/synapse/storage/schema/delta/33/refreshtoken_device.sql
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2018 New Vector Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,4 +13,4 @@
  * limitations under the License.
  */
 
-ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
index 34db0cf12b..31d7a817eb 100644
--- a/synapse/storage/schema/delta/23/refresh_tokens.sql
+++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
@@ -1,4 +1,4 @@
-/* Copyright 2015, 2016 OpenMarket Ltd
+/* Copyright 2018 New Vector Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,9 +13,5 @@
  * limitations under the License.
  */
 
-CREATE TABLE IF NOT EXISTS refresh_tokens(
-    id INTEGER PRIMARY KEY,
-    token TEXT NOT NULL,
-    user_id TEXT NOT NULL,
-    UNIQUE (token)
-);
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('event_search_postgres_gin', '{}');
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql
new file mode 100644
index 0000000000..edccf4a96f
--- /dev/null
+++ b/synapse/storage/schema/delta/47/push_actions_staging.sql
@@ -0,0 +1,28 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Temporary staging area for push actions that have been calculated for an
+-- event, but the event hasn't yet been persisted.
+-- When the event is persisted the rows are moved over to the
+-- event_push_actions table.
+CREATE TABLE event_push_actions_staging (
+    event_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    actions TEXT NOT NULL,
+    notif SMALLINT NOT NULL,
+    highlight SMALLINT NOT NULL
+);
+
+CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
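
A hedged sketch of the move the comment above describes, run once the event is persisted; the column list for event_push_actions is illustrative, not the actual schema:

    def move_push_actions_from_staging(txn, event_id):
        # copy the staged rows over (illustrative column list), then clear them
        txn.execute(
            "INSERT INTO event_push_actions"
            " (event_id, user_id, actions, notif, highlight)"
            " SELECT event_id, user_id, actions, notif, highlight"
            " FROM event_push_actions_staging WHERE event_id = ?",
            (event_id,),
        )
        txn.execute(
            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
            (event_id,),
        )
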
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if isinstance(database_engine, PostgresEngine):
+        # if we already have some state groups, we want to start making new
+        # ones with a higher id.
+        cur.execute("SELECT max(id) FROM state_groups")
+        row = cur.fetchone()
+
+        if row[0] is None:
+            start_val = 1
+        else:
+            start_val = row[0] + 1
+
+        cur.execute(
+            "CREATE SEQUENCE state_group_id_seq START WITH %s",
+            (start_val, ),
+        )
+
+
+def run_upgrade(*args, **kwargs):
+    pass
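
On postgres, new ids are then presumably drawn from this sequence; a hedged sketch of the accessor (the name mirrors get_next_state_group_id, which is called near the end of this diff, but the body here is an assumption):

    def get_next_state_group_id(txn):
        # nextval() is atomic, so concurrent writers get distinct ids
        txn.execute("SELECT nextval('state_group_id_seq')")
        return txn.fetchone()[0]
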
diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/schema/delta/48/add_user_consent.sql
new file mode 100644
index 0000000000..5237491506
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_consent.sql
@@ -0,0 +1,18 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* record the version of the privacy policy the user has consented to
+ */
+ALTER TABLE users ADD COLUMN consent_version TEXT;
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
new file mode 100644
index 0000000000..9248b0b24a
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('user_ips_last_seen_index', '{}');
diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/schema/delta/48/deactivated_users.sql
new file mode 100644
index 0000000000..e9013a6969
--- /dev/null
+++ b/synapse/storage/schema/delta/48/deactivated_users.sql
@@ -0,0 +1,25 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Store any accounts that have been requested to be deactivated.
+ * We part the account from all the rooms it's in when it's
+ * deactivated. This can take some time and synapse may be restarted
+ * before it completes, so store the user IDs here until the process
+ * is complete.
+ */
+CREATE TABLE users_pending_deactivation (
+    user_id TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py
new file mode 100644
index 0000000000..2233af87d7
--- /dev/null
+++ b/synapse/storage/schema/delta/48/group_unique_indexes.py
@@ -0,0 +1,57 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
+FIX_INDEXES = """
+-- rebuild indexes as uniques
+DROP INDEX groups_invites_g_idx;
+CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id);
+DROP INDEX groups_users_g_idx;
+CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id);
+
+-- rename other indexes to actually match their table names.
+DROP INDEX groups_users_u_idx;
+CREATE INDEX group_users_u_idx ON group_users(user_id);
+DROP INDEX groups_invites_u_idx;
+CREATE INDEX group_invites_u_idx ON group_invites(user_id);
+DROP INDEX groups_rooms_g_idx;
+CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
+DROP INDEX groups_rooms_r_idx;
+CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
+
+    # remove duplicates from group_users & group_invites tables
+    cur.execute("""
+        DELETE FROM group_users WHERE %s NOT IN (
+           SELECT min(%s) FROM group_users GROUP BY group_id, user_id
+        );
+    """ % (rowid, rowid))
+    cur.execute("""
+        DELETE FROM group_invites WHERE %s NOT IN (
+           SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
+        );
+    """ % (rowid, rowid))
+
+    for statement in get_statements(FIX_INDEXES.splitlines()):
+        cur.execute(statement)
+
+
+def run_upgrade(*args, **kwargs):
+    pass
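
The DELETE statements in run_create keep, for each (group_id, user_id) pair, only the row with the smallest rowid (ctid on postgres). A self-contained sqlite demonstration of the dedupe pattern:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE group_users (group_id TEXT, user_id TEXT)")
    conn.executemany(
        "INSERT INTO group_users VALUES (?, ?)",
        [("g1", "u1"), ("g1", "u1"), ("g1", "u2")],
    )
    # keep only the first physical row for each (group_id, user_id) pair
    conn.execute("""
        DELETE FROM group_users WHERE rowid NOT IN (
            SELECT min(rowid) FROM group_users GROUP BY group_id, user_id
        )
    """)
    print(conn.execute("SELECT count(*) FROM group_users").fetchone())  # (2,)
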
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql
new file mode 100644
index 0000000000..ce26eaf0c9
--- /dev/null
+++ b/synapse/storage/schema/delta/48/groups_joinable.sql
@@ -0,0 +1,22 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This isn't a real ENUM because sqlite doesn't support it;
+ * instead we use a TEXT column. The default of 'invite' below
+ * ensures that pre-existing rows are given the correct default
+ * policy.
+ */
+ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite';
diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
new file mode 100644
index 0000000000..14dcf18d73
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
@@ -0,0 +1,20 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* record whether we have sent a server notice about consenting to the
+ * privacy policy. Specifically records the version of the policy we sent
+ * a message about.
+ */
+ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT;
diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/schema/delta/49/add_user_daily_visits.sql
new file mode 100644
index 0000000000..3dd478196f
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_daily_visits.sql
@@ -0,0 +1,21 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL,
+                                 device_id TEXT,
+                                 timestamp BIGINT NOT NULL );
+CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
+CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
new file mode 100644
index 0000000000..3a4ed59b5b
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('user_ips_last_seen_only_index', '{}');
diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
new file mode 100644
index 0000000000..c93ae47532
--- /dev/null
+++ b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('users_creation_ts', '{}');
diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/schema/delta/50/erasure_store.sql
new file mode 100644
index 0000000000..5d8641a9ab
--- /dev/null
+++ b/synapse/storage/schema/delta/50/erasure_store.sql
@@ -0,0 +1,21 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a table of users who have requested that their details be erased
+CREATE TABLE erased_users (
+    user_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id);
diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql
index a7ade69986..42e5cb6df5 100644
--- a/synapse/storage/schema/schema_version.sql
+++ b/synapse/storage/schema/schema_version.sql
@@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
     file TEXT NOT NULL,
     UNIQUE(version, file)
 );
+
+-- a list of schema files we have loaded on behalf of dynamic modules
+CREATE TABLE IF NOT EXISTS applied_module_schemas(
+    module_name TEXT NOT NULL,
+    file TEXT NOT NULL,
+    UNIQUE(module_name, file)
+);
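
A hedged sketch of how a row might land in this table when a dynamic module's schema file is applied; the function and its callers are assumptions for illustration:

    def record_module_schema(txn, module_name, filename):
        # mark the file as applied so it is skipped on the next startup
        txn.execute(
            "INSERT INTO applied_module_schemas (module_name, file) VALUES (?, ?)",
            (module_name, filename),
        )
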
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 8f2b3c4435..d5b5df93e6 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,28 +13,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+import re
+from collections import namedtuple
+
+from six import string_types
+
+from canonicaljson import json
+
 from twisted.internet import defer
 
-from .background_updates import BackgroundUpdateStore
 from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-import logging
-import re
-import ujson as json
-
+from .background_updates import BackgroundUpdateStore
 
 logger = logging.getLogger(__name__)
 
+SearchEntry = namedtuple('SearchEntry', [
+    'key', 'value', 'event_id', 'room_id', 'stream_ordering',
+    'origin_server_ts',
+])
+
 
 class SearchStore(BackgroundUpdateStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
+    EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, hs):
-        super(SearchStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(SearchStore, self).__init__(db_conn, hs)
         self.register_background_update_handler(
             self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
         )
@@ -42,23 +52,35 @@ class SearchStore(BackgroundUpdateStore):
             self.EVENT_SEARCH_ORDER_UPDATE_NAME,
             self._background_reindex_search_order
         )
-        self.register_background_update_handler(
+
+        # we used to have a background update to turn the GIN index into a
+        # GIST one; we no longer do that (obviously) because we actually want
+        # a GIN index. However, it's possible that some people might still have
+        # the background update queued, so we register a handler to clear the
+        # background update.
+        self.register_noop_background_update(
             self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
-            self._background_reindex_gist_search
+        )
+
+        self.register_background_update_handler(
+            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
+            self._background_reindex_gin_search
         )
 
     @defer.inlineCallbacks
     def _background_reindex_search(self, progress, batch_size):
+        # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
         def reindex_search_txn(txn):
             sql = (
-                "SELECT stream_ordering, event_id, room_id, type, content FROM events"
+                "SELECT stream_ordering, event_id, room_id, type, json, "
+                " origin_server_ts FROM events"
+                " JOIN event_json USING (room_id, event_id)"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
                 " AND (%s)"
                 " ORDER BY stream_ordering DESC"
@@ -67,6 +89,10 @@ class SearchStore(BackgroundUpdateStore):
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
+            # we could stream straight from the results into
+            # store_search_entries_txn with a generator function, but that
+            # would mean having two cursors open on the database at once.
+            # Instead we just build a list of results.
             rows = self.cursor_to_dict(txn)
             if not rows:
                 return 0
@@ -79,9 +105,12 @@ class SearchStore(BackgroundUpdateStore):
                     event_id = row["event_id"]
                     room_id = row["room_id"]
                     etype = row["type"]
+                    stream_ordering = row["stream_ordering"]
+                    origin_server_ts = row["origin_server_ts"]
                     try:
-                        content = json.loads(row["content"])
-                    except:
+                        event_json = json.loads(row["json"])
+                        content = event_json["content"]
+                    except Exception:
                         continue
 
                     if etype == "m.room.message":
@@ -93,35 +122,28 @@ class SearchStore(BackgroundUpdateStore):
                     elif etype == "m.room.name":
                         key = "content.name"
                         value = content["name"]
+                    else:
+                        raise Exception("unexpected event type %s" % etype)
                 except (KeyError, AttributeError):
                     # If the event is missing a necessary field then
                     # skip over it.
                     continue
 
-                if not isinstance(value, basestring):
+                if not isinstance(value, string_types):
                     # If the event body, name or topic isn't a string
                     # then skip over it
                     continue
 
-                event_search_rows.append((event_id, room_id, key, value))
-
-            if isinstance(self.database_engine, PostgresEngine):
-                sql = (
-                    "INSERT INTO event_search (event_id, room_id, key, vector)"
-                    " VALUES (?,?,?,to_tsvector('english', ?))"
-                )
-            elif isinstance(self.database_engine, Sqlite3Engine):
-                sql = (
-                    "INSERT INTO event_search (event_id, room_id, key, value)"
-                    " VALUES (?,?,?,?)"
-                )
-            else:
-                # This should be unreachable.
-                raise Exception("Unrecognized database engine")
+                event_search_rows.append(SearchEntry(
+                    key=key,
+                    value=value,
+                    event_id=event_id,
+                    room_id=room_id,
+                    stream_ordering=stream_ordering,
+                    origin_server_ts=origin_server_ts,
+                ))
 
-            for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
-                clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            self.store_search_entries_txn(txn, event_search_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -145,25 +167,48 @@ class SearchStore(BackgroundUpdateStore):
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _background_reindex_gist_search(self, progress, batch_size):
+    def _background_reindex_gin_search(self, progress, batch_size):
+        """This handles old synapses which used GIST indexes, if any;
+        converting them back to be GIN as per the actual schema.
+        """
+
         def create_index(conn):
             conn.rollback()
-            conn.set_session(autocommit=True)
-            c = conn.cursor()
 
-            c.execute(
-                "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
-                " ON event_search USING GIST (vector)"
-            )
+            # we have to set autocommit, because postgres refuses to
+            # CREATE INDEX CONCURRENTLY without it.
+            conn.set_session(autocommit=True)
 
-            c.execute("DROP INDEX event_search_fts_idx")
+            try:
+                c = conn.cursor()
 
-            conn.set_session(autocommit=False)
+                # if we skipped the conversion to GIST, we may already/still
+                # have an event_search_fts_idx; unfortunately postgres 9.4
+                # doesn't support CREATE INDEX IF NOT EXISTS so we just catch the
+                # exception and ignore it.
+                import psycopg2
+                try:
+                    c.execute(
+                        "CREATE INDEX CONCURRENTLY event_search_fts_idx"
+                        " ON event_search USING GIN (vector)"
+                    )
+                except psycopg2.ProgrammingError as e:
+                    logger.warn(
+                        "Ignoring error %r when trying to switch from GIST to GIN",
+                        e
+                    )
+
+                # we should now be able to delete the GIST index.
+                c.execute(
+                    "DROP INDEX IF EXISTS event_search_fts_idx_gist"
+                )
+            finally:
+                conn.set_session(autocommit=False)
 
         if isinstance(self.database_engine, PostgresEngine):
             yield self.runWithConnection(create_index)
 
-        yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+        yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
         defer.returnValue(1)
 
     @defer.inlineCallbacks
@@ -242,6 +287,85 @@ class SearchStore(BackgroundUpdateStore):
 
         defer.returnValue(num_rows)
 
+    def store_event_search_txn(self, txn, event, key, value):
+        """Add event to the search table
+
+        Args:
+            txn (cursor):
+            event (EventBase):
+            key (str):
+            value (str):
+        """
+        self.store_search_entries_txn(
+            txn,
+            (SearchEntry(
+                key=key,
+                value=value,
+                event_id=event.event_id,
+                room_id=event.room_id,
+                stream_ordering=event.internal_metadata.stream_ordering,
+                origin_server_ts=event.origin_server_ts,
+            ),),
+        )
+
+    def store_search_entries_txn(self, txn, entries):
+        """Add entries to the search table
+
+        Args:
+            txn (cursor):
+            entries (iterable[SearchEntry]):
+                entries to be added to the table
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+
+            args = ((
+                entry.event_id, entry.room_id, entry.key, entry.value,
+                entry.stream_ordering, entry.origin_server_ts,
+            ) for entry in entries)
+
+            # inserts to a GIN index are normally batched up into a pending
+            # list, and then all committed together once the list gets to a
+            # certain size. The trouble with that is that postgres (pre-9.5)
+            # uses work_mem to determine the length of the list, and work_mem
+            # is typically very large.
+            #
+            # We therefore reduce work_mem while we do the insert.
+            #
+            # (postgres 9.5 uses the separate gin_pending_list_limit setting,
+            # so doesn't suffer the same problem, but changing work_mem will
+            # be harmless)
+            #
+            # Note that we don't need to worry about restoring it on
+            # exception, because exceptions will cause the transaction to be
+            # rolled back, including the effects of the SET command.
+            #
+            # Also: we use SET rather than SET LOCAL because there's lots of
+            # other stuff going on in this transaction, which we want to run
+            # with the normal work_mem setting.
+
+            txn.execute("SET work_mem='256kB'")
+            txn.executemany(sql, args)
+            txn.execute("RESET work_mem")
+
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "INSERT INTO event_search (event_id, room_id, key, value)"
+                " VALUES (?,?,?,?)"
+            )
+            args = ((
+                entry.event_id, entry.room_id, entry.key, entry.value,
+            ) for entry in entries)
+
+            txn.executemany(sql, args)
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
     @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
@@ -326,7 +450,7 @@ class SearchStore(BackgroundUpdateStore):
             "search_msgs", self.cursor_to_dict, sql, *args
         )
 
-        results = filter(lambda row: row["room_id"] in room_ids, results)
+        results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
         events = yield self._get_events([r["event_id"] for r in results])
 
@@ -407,7 +531,7 @@ class SearchStore(BackgroundUpdateStore):
                 origin_server_ts, stream = pagination_token.split(",")
                 origin_server_ts = int(origin_server_ts)
                 stream = int(stream)
-            except:
+            except Exception:
                 raise SynapseError(400, "Invalid pagination token")
 
             clauses.append(
@@ -481,7 +605,7 @@ class SearchStore(BackgroundUpdateStore):
             "search_rooms", self.cursor_to_dict, sql, *args
         )
 
-        results = filter(lambda row: row["room_id"] in room_ids, results)
+        results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
         events = yield self._get_events([r["event_id"] for r in results])
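
The two list(filter(...)) changes above matter because on Python 3 filter returns a one-shot iterator; anything that iterates the results more than once would silently see nothing the second time. A small demonstration:

    rows = [{"room_id": "a"}, {"room_id": "b"}]

    results = filter(lambda r: r["room_id"] == "a", rows)
    assert list(results) == [{"room_id": "a"}]
    assert list(results) == []  # the iterator is already exhausted on py3
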
 
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 67d5d9969a..470212aa2a 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -13,21 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
-from ._base import SQLBaseStore
+import six
 
 from unpaddedbase64 import encode_base64
+
+from twisted.internet import defer
+
 from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.util.caches.descriptors import cached, cachedList
 
+from ._base import SQLBaseStore
+
+# py2 sqlite has buffer hardcoded as the only binary type, so we must use it,
+# even though it is deprecated and was removed in py3 in favor of memoryview
+if six.PY2:
+    db_binary_type = buffer
+else:
+    db_binary_type = memoryview
 
-class SignatureStore(SQLBaseStore):
-    """Persistence for event signatures and hashes"""
 
+class SignatureWorkerStore(SQLBaseStore):
     @cached()
     def get_event_reference_hash(self, event_id):
-        return self._get_event_reference_hashes_txn(event_id)
+        # This is a dummy function to allow get_event_reference_hashes
+        # to use its cache
+        raise NotImplementedError()
 
     @cachedList(cached_method_name="get_event_reference_hash",
                 list_name="event_ids", num_args=1)
@@ -56,7 +66,7 @@ class SignatureStore(SQLBaseStore):
             for e_id, h in hashes.items()
         }
 
-        defer.returnValue(hashes.items())
+        defer.returnValue(list(hashes.items()))
 
     def _get_event_reference_hashes_txn(self, txn, event_id):
         """Get all the hashes for a given PDU.
@@ -74,6 +84,10 @@ class SignatureStore(SQLBaseStore):
         txn.execute(query, (event_id, ))
         return {k: v for k, v in txn}
 
+
+class SignatureStore(SignatureWorkerStore):
+    """Persistence for event signatures and hashes"""
+
     def _store_event_reference_hashes_txn(self, txn, events):
         """Store a hash for a PDU
         Args:
@@ -87,7 +101,7 @@ class SignatureStore(SQLBaseStore):
             vals.append({
                 "event_id": event.event_id,
                 "algorithm": ref_alg,
-                "hash": buffer(ref_hash_bytes),
+                "hash": db_binary_type(ref_hash_bytes),
             })
 
         self._simple_insert_many_txn(
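
A short illustration of the db_binary_type switch above: both wrappers expose the underlying bytes to the DB driver without copying, memoryview being the py3 replacement for py2's buffer:

    import six

    if six.PY2:
        db_binary_type = buffer  # noqa: F821 (py2-only builtin)
    else:
        db_binary_type = memoryview

    wrapped = db_binary_type(b"\x00\x01")
    assert bytes(wrapped) == b"\x00\x01"
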
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5673e4aa96..89a05c4618 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,16 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches import intern_string
-from synapse.util.stringutils import to_ascii
-from synapse.storage.engines import PostgresEngine
+import logging
+from collections import namedtuple
+
+from six import iteritems, itervalues
+from six.moves import range
 
 from twisted.internet import defer
-from collections import namedtuple
 
-import logging
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches import get_cache_factor_for, intern_string
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.stringutils import to_ascii
+
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -40,45 +46,19 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
         return len(self.delta_ids) if self.delta_ids else 0
 
 
-class StateStore(SQLBaseStore):
-    """ Keeps track of the state at a given event.
-
-    This is done by the concept of `state groups`. Every event is a assigned
-    a state group (identified by an arbitrary string), which references a
-    collection of state events. The current state of an event is then the
-    collection of state events referenced by the event's state group.
-
-    Hence, every change in the current state causes a new state group to be
-    generated. However, if no change happens (e.g., if we get a message event
-    with only one parent it inherits the state group from its parent.)
-
-    There are three tables:
-      * `state_groups`: Stores group name, first event with in the group and
-        room id.
-      * `event_to_state_groups`: Maps events to state groups.
-      * `state_groups_state`: Maps state group to state events.
+class StateGroupWorkerStore(SQLBaseStore):
+    """The parts of StateGroupStore that can be called from workers.
     """
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
-    def __init__(self, hs):
-        super(StateStore, self).__init__(hs)
-        self.register_background_update_handler(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
-            self._background_deduplicate_state,
-        )
-        self.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME,
-            self._background_index_state,
-        )
-        self.register_background_index_update(
-            self.CURRENT_STATE_INDEX_UPDATE_NAME,
-            index_name="current_state_events_member_index",
-            table="current_state_events",
-            columns=["state_key"],
-            where_clause="type='m.room.member'",
+    def __init__(self, db_conn, hs):
+        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+
+        self._state_group_cache = DictionaryCache(
+            "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
         )
 
     @cached(max_entries=100000, iterable=True)
@@ -158,12 +138,26 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.itervalues())
+        groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups)
 
         defer.returnValue(group_to_state)
 
     @defer.inlineCallbacks
+    def get_state_ids_for_group(self, state_group):
+        """Get the state IDs for the given state group
+
+        Args:
+            state_group (int)
+
+        Returns:
+            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+        """
+        group_to_state = yield self._get_state_for_groups((state_group,))
+
+        defer.returnValue(group_to_state[state_group])
+
+    @defer.inlineCallbacks
     def get_state_groups(self, room_id, event_ids):
         """ Get the state groups for the given list of event_ids
 
@@ -176,199 +170,27 @@ class StateStore(SQLBaseStore):
 
         state_event_map = yield self.get_events(
             [
-                ev_id for group_ids in group_to_ids.itervalues()
-                for ev_id in group_ids.itervalues()
+                ev_id for group_ids in itervalues(group_to_ids)
+                for ev_id in itervalues(group_ids)
             ],
             get_prev_content=False
         )
 
         defer.returnValue({
             group: [
-                state_event_map[v] for v in event_id_map.itervalues()
+                state_event_map[v] for v in itervalues(event_id_map)
                 if v in state_event_map
             ]
-            for group, event_id_map in group_to_ids.iteritems()
+            for group, event_id_map in iteritems(group_to_ids)
         })
 
-    def _have_persisted_state_group_txn(self, txn, state_group):
-        txn.execute(
-            "SELECT count(*) FROM state_groups WHERE id = ?",
-            (state_group,)
-        )
-        row = txn.fetchone()
-        return row and row[0]
-
-    def _store_mult_state_groups_txn(self, txn, events_and_contexts):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
-
-            if context.current_state_ids is None:
-                # AFAIK, this can never happen
-                logger.error(
-                    "Non-outlier event %s had current_state_ids==None",
-                    event.event_id)
-                continue
-
-            # if the event was rejected, just give it the same state as its
-            # predecessor.
-            if context.rejected:
-                state_groups[event.event_id] = context.prev_group
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-            if self._have_persisted_state_group_txn(txn, context.state_group):
-                continue
-
-            self._simple_insert_txn(
-                txn,
-                table="state_groups",
-                values={
-                    "id": context.state_group,
-                    "room_id": event.room_id,
-                    "event_id": event.event_id,
-                },
-            )
-
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't tooo long, as otherwise read performance degrades.
-            if context.prev_group:
-                is_in_db = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="state_groups",
-                    keyvalues={"id": context.prev_group},
-                    retcol="id",
-                    allow_none=True,
-                )
-                if not is_in_db:
-                    raise Exception(
-                        "Trying to persist state with unpersisted prev_group: %r"
-                        % (context.prev_group,)
-                    )
-
-                potential_hops = self._count_state_group_hops_txn(
-                    txn, context.prev_group
-                )
-            if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                self._simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={
-                        "state_group": context.state_group,
-                        "prev_state_group": context.prev_group,
-                    },
-                )
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in context.delta_ids.iteritems()
-                    ],
-                )
-            else:
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in context.current_state_ids.iteritems()
-                    ],
-                )
-
-            # Prefill the state group cache with this group.
-            # It's fine to use the sequence like this as the state group map
-            # is immutable. (If the map wasn't immutable then this prefill could
-            # race with another update)
-            txn.call_after(
-                self._state_group_cache.update,
-                self._state_group_cache.sequence,
-                key=context.state_group,
-                value=dict(context.current_state_ids),
-                full=True,
-            )
-
-        self._simple_insert_many_txn(
-            txn,
-            table="event_to_state_groups",
-            values=[
-                {
-                    "state_group": state_group_id,
-                    "event_id": event_id,
-                }
-                for event_id, state_group_id in state_groups.iteritems()
-            ],
-        )
-
-        for event_id, state_group_id in state_groups.iteritems():
-            txn.call_after(
-                self._get_state_group_for_event.prefill,
-                (event_id,), state_group_id
-            )
-
-    def _count_state_group_hops_txn(self, txn, state_group):
-        """Given a state group, count how many hops there are in the tree.
-
-        This is used to ensure the delta chains don't get too long.
-        """
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = ("""
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT count(*) FROM state;
-            """)
-
-            txn.execute(sql, (state_group,))
-            row = txn.fetchone()
-            if row and row[0]:
-                return row[0]
-            else:
-                return 0
-        else:
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
-            count = 0
-
-            while next_group:
-                next_group = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="state_group_edges",
-                    keyvalues={"state_group": next_group},
-                    retcol="prev_state_group",
-                    allow_none=True,
-                )
-                if next_group:
-                    count += 1
-
-            return count
-
     @defer.inlineCallbacks
     def _get_state_groups_from_groups(self, groups, types):
         """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
         """
         results = {}
 
-        chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
+        chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
@@ -422,6 +244,9 @@ class StateStore(SQLBaseStore):
                     (
                         "AND type = ? AND state_key = ?",
                         (etype, state_key)
+                    ) if state_key is not None else (
+                        "AND type = ?",
+                        (etype,)
                     )
                     for etype, state_key in types
                 ]
@@ -441,10 +266,19 @@ class StateStore(SQLBaseStore):
                         key = (typ, state_key)
                         results[group][key] = event_id
         else:
+            where_args = []
+            where_clauses = []
+            wildcard_types = False
             if types is not None:
-                where_clause = "AND (%s)" % (
-                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-                )
+                for typ in types:
+                    if typ[1] is None:
+                        where_clauses.append("(type = ?)")
+                        where_args.append(typ[0])
+                        wildcard_types = True
+                    else:
+                        where_clauses.append("(type = ? AND state_key = ?)")
+                        where_args.extend([typ[0], typ[1]])
+                where_clause = "AND (%s)" % (" OR ".join(where_clauses))
             else:
                 where_clause = ""
 
@@ -461,7 +295,7 @@ class StateStore(SQLBaseStore):
                     # after we finish deduping state, which requires this func)
                     args = [next_group]
                     if types:
-                        args.extend(i for typ in types for i in typ)
+                        args.extend(where_args)
 
                     txn.execute(
                         "SELECT type, state_key, event_id FROM state_groups_state"
@@ -474,9 +308,17 @@ class StateStore(SQLBaseStore):
                         if (typ, state_key) not in results[group]
                     )
 
-                    # If the lengths match then we must have all the types,
-                    # so no need to go walk further down the tree.
-                    if types is not None and len(results[group]) == len(types):
+                    # If the number of entries in the (type,state_key)->event_id dict
+                    # matches the number of (type,state_keys) types we were searching
+                    # for, then we must have found them all, so no need to go walk
+                    # further down the tree... UNLESS our types filter contained
+                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
+                    # search
+                    if (
+                        types is not None and
+                        not wildcard_types and
+                        len(results[group]) == len(types)
+                    ):
                         break
 
                     next_group = self._simple_select_one_onecol_txn(
@@ -509,21 +351,21 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.itervalues())
+        groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         state_event_map = yield self.get_events(
-            [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
+            [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
             get_prev_content=False
         )
 
         event_to_state = {
             event_id: {
                 k: state_event_map[v]
-                for k, v in group_to_state[group].iteritems()
+                for k, v in iteritems(group_to_state[group])
                 if v in state_event_map
             }
-            for event_id, group in event_to_groups.iteritems()
+            for event_id, group in iteritems(event_to_groups)
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -546,12 +388,12 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.itervalues())
+        groups = set(itervalues(event_to_groups))
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         event_to_state = {
             event_id: group_to_state[group]
-            for event_id, group in event_to_groups.iteritems()
+            for event_id, group in iteritems(event_to_groups)
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -665,7 +507,7 @@ class StateStore(SQLBaseStore):
         got_all = is_all or not missing_types
 
         return {
-            k: v for k, v in state_dict_ids.iteritems()
+            k: v for k, v in iteritems(state_dict_ids)
             if include(k[0], k[1])
         }, missing_types, got_all
 
@@ -685,10 +527,23 @@ class StateStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def _get_state_for_groups(self, groups, types=None):
-        """Given list of groups returns dict of group -> list of state events
-        with matching types. `types` is a list of `(type, state_key)`, where
-        a `state_key` of None matches all state_keys. If `types` is None then
-        all events are returned.
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            types (None|iterable[(str, None|str)]):
+                indicates the state type/keys required. If None, the whole
+                state is fetched and returned.
+
+                Otherwise, each entry should be a `(type, state_key)` tuple to
+                include in the response. A `state_key` of None is a wildcard
+                meaning that we require all state with that type.
+
+        Returns:
+            Deferred[dict[int, dict[(type, state_key), EventBase]]]
+                a dictionary mapping from state group to state dictionary.
         """
         if types:
             types = frozenset(types)
@@ -697,7 +552,7 @@ class StateStore(SQLBaseStore):
         if types is not None:
             for group in set(groups):
                 state_dict_ids, _, got_all = self._get_some_state_from_cache(
-                    group, types
+                    group, types,
                 )
                 results[group] = state_dict_ids
 
@@ -718,32 +573,266 @@ class StateStore(SQLBaseStore):
             # Okay, so we have some missing_types, lets fetch them.
             cache_seq_num = self._state_group_cache.sequence
 
+            # the DictionaryCache knows if it has *all* the state, but
+            # does not know if it has all of the keys of a particular type,
+            # which makes wildcard lookups expensive unless we have a complete
+            # cache. Hence, if we are doing a wildcard lookup, populate the
+            # cache fully so that we can do an efficient lookup next time.
+
+            if types and any(k is None for (t, k) in types):
+                types_to_fetch = None
+            else:
+                types_to_fetch = types
+
             group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types
+                missing_groups, types_to_fetch,
             )
 
-            # Now we want to update the cache with all the things we fetched
-            # from the database.
-            for group, group_state_dict in group_to_state_dict.iteritems():
+            for group, group_state_dict in iteritems(group_to_state_dict):
                 state_dict = results[group]
 
-                state_dict.update(
-                    ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
-                    for k, v in group_state_dict.iteritems()
-                )
-
+                # update the result, filtering by `types`.
+                if types:
+                    for k, v in iteritems(group_state_dict):
+                        (typ, _) = k
+                        if k in types or (typ, None) in types:
+                            state_dict[k] = v
+                else:
+                    state_dict.update(group_state_dict)
+
+                # update the cache with all the things we fetched from the
+                # database.
                 self._state_group_cache.update(
                     cache_seq_num,
                     key=group,
-                    value=state_dict,
-                    full=(types is None),
-                    known_absent=types,
+                    value=group_state_dict,
+                    fetched_keys=types_to_fetch,
                 )
 
         defer.returnValue(results)
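
To make the filtering above concrete, here is a minimal pure-Python sketch of
the `types` semantics described in the docstring (the dict below is an
illustrative stand-in for a group's state map, not the real cache):

    def filter_state(group_state_dict, types):
        """Filter a {(type, state_key): event_id} map by `types`, where a
        None state_key is a wildcard matching every key of that type."""
        if not types:
            return dict(group_state_dict)
        types = frozenset(types)
        return {
            k: v for k, v in group_state_dict.items()
            if k in types or (k[0], None) in types
        }

    state = {
        ("m.room.member", "@a:hs"): "$ev1",
        ("m.room.member", "@b:hs"): "$ev2",
        ("m.room.name", ""): "$ev3",
    }
    # Wildcard over members, plus the room name:
    print(filter_state(state, [("m.room.member", None), ("m.room.name", "")]))

Note how a single wildcard entry forces `types_to_fetch` to None above: the
cache can only be marked complete for a type if every key of that type was
fetched.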
 
-    def get_next_state_group(self):
-        return self._state_groups_id_gen.get_next()
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        """Store a new set of state, returning a newly assigned state group.
+
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
+
+        Returns:
+            Deferred[int]: The state group ID
+        """
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
+                # AFAIK, this can never happen
+                raise Exception("current_state_ids cannot be None")
+
+            state_group = self.database_engine.get_next_state_group_id(txn)
+
+            self._simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={
+                    "id": state_group,
+                    "room_id": room_id,
+                    "event_id": event_id,
+                },
+            )
+
+            # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't too long, as otherwise read performance degrades.
+            if prev_group:
+                is_in_db = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_groups",
+                    keyvalues={"id": prev_group},
+                    retcol="id",
+                    allow_none=True,
+                )
+                if not is_in_db:
+                    raise Exception(
+                        "Trying to persist state with unpersisted prev_group: %r"
+                        % (prev_group,)
+                    )
+
+                potential_hops = self._count_state_group_hops_txn(
+                    txn, prev_group
+                )
+            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_group_edges",
+                    values={
+                        "state_group": state_group,
+                        "prev_state_group": prev_group,
+                    },
+                )
+
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in iteritems(delta_ids)
+                    ],
+                )
+            else:
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in iteritems(current_state_ids)
+                    ],
+                )
+
+            # Prefill the state group cache with this group.
+            # It's fine to use the sequence like this as the state group map
+            # is immutable. (If the map wasn't immutable then this prefill could
+            # race with another update)
+            txn.call_after(
+                self._state_group_cache.update,
+                self._state_group_cache.sequence,
+                key=state_group,
+                value=dict(current_state_ids),
+            )
+
+            return state_group
+
+        return self.runInteraction("store_state_group", _store_state_group_txn)
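
The branching above reduces to: persist a delta row when there is a previous
group whose chain is still short, otherwise persist a full snapshot. A
self-contained sketch of that decision (the in-memory `edges` dict and the
constant's value are assumptions for illustration; the real hop count comes
from _count_state_group_hops_txn below):

    MAX_STATE_DELTA_HOPS = 100  # assumed value of the module constant

    def count_hops(edges, group):
        # Walk prev_state_group pointers, like the sqlite fallback below.
        count = 0
        while group in edges:
            group = edges[group]
            count += 1
        return count

    def should_store_delta(edges, prev_group):
        return (
            prev_group is not None
            and count_hops(edges, prev_group) < MAX_STATE_DELTA_HOPS
        )

    edges = {3: 2, 2: 1}  # state_group -> prev_state_group
    print(should_store_delta(edges, 3))  # True: only two hops deep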
+
+    def _count_state_group_hops_txn(self, txn, state_group):
+        """Given a state group, count how many hops there are in the tree.
+
+        This is used to ensure the delta chains don't get too long.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = ("""
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """)
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                return row[0]
+            else:
+                return 0
+        else:
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
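
For illustration, the recursive query can be exercised end-to-end with the
standard-library sqlite3 module (modern SQLite does support WITH RECURSIVE;
the comment above is only about old distribution versions). The fixture below
is a toy three-hop chain:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE state_group_edges (state_group, prev_state_group)")
    conn.executemany(
        "INSERT INTO state_group_edges VALUES (?, ?)",
        [(4, 3), (3, 2), (2, 1)],  # delta chain 4 -> 3 -> 2 -> 1
    )
    count = conn.execute("""
        WITH RECURSIVE state(state_group) AS (
            VALUES(?)
            UNION ALL
            SELECT prev_state_group FROM state_group_edges e, state s
            WHERE s.state_group = e.state_group
        )
        SELECT count(*) FROM state
    """, (4,)).fetchone()[0]
    print(count)  # 4: the starting group plus three ancestors

Note that the CTE counts the starting group itself, so it reports one more
than the loop-based fallback; since the result only guards the rough
MAX_STATE_DELTA_HOPS bound, the off-by-one is harmless.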
+
+
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+    """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is assigned
+    a state group (identified by an arbitrary string), which references a
+    collection of state events. The current state of an event is then the
+    collection of state events referenced by the event's state group.
+
+    Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g. if we get a message event
+    with only one parent), it inherits the state group from its parent.
+
+    There are three tables:
+      * `state_groups`: Stores group name, first event within the group and
+        room id.
+      * `event_to_state_groups`: Maps events to state groups.
+      * `state_groups_state`: Maps state group to state events.
+    """
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+    def __init__(self, db_conn, hs):
+        super(StateStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+            self._background_deduplicate_state,
+        )
+        self.register_background_update_handler(
+            self.STATE_GROUP_INDEX_UPDATE_NAME,
+            self._background_index_state,
+        )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
+
+    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
+                continue
+
+            state_groups[event.event_id] = context.state_group
+
+        self._simple_insert_many_txn(
+            txn,
+            table="event_to_state_groups",
+            values=[
+                {
+                    "state_group": state_group_id,
+                    "event_id": event_id,
+                }
+                for event_id, state_group_id in iteritems(state_groups)
+            ],
+        )
+
+        for event_id, state_group_id in iteritems(state_groups):
+            txn.call_after(
+                self._get_state_group_for_event.prefill,
+                (event_id,), state_group_id
+            )
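
The mapping rule implemented above is: outliers get no entry, rejected events
inherit their predecessor's group, and everything else maps to its own group.
A compact sketch with illustrative stand-ins for FrozenEvent and EventContext:

    from collections import namedtuple

    Event = namedtuple("Event", ("event_id", "outlier"))
    Context = namedtuple("Context", ("rejected", "state_group", "prev_group"))

    def map_state_groups(events_and_contexts):
        state_groups = {}
        for event, context in events_and_contexts:
            if event.outlier:
                continue  # outliers carry no state of their own
            if context.rejected:
                state_groups[event.event_id] = context.prev_group
                continue
            state_groups[event.event_id] = context.state_group
        return state_groups

    print(map_state_groups([
        (Event("$a", False), Context(False, 7, 6)),
        (Event("$b", False), Context(True, 8, 7)),   # rejected -> group 7
        (Event("$c", True), Context(False, 9, 8)),   # outlier -> skipped
    ]))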
 
     @defer.inlineCallbacks
     def _background_deduplicate_state(self, progress, batch_size):
@@ -767,7 +856,7 @@ class StateStore(SQLBaseStore):
 
         def reindex_txn(txn):
             new_last_state_group = last_state_group
-            for count in xrange(batch_size):
+            for count in range(batch_size):
                 txn.execute(
                     "SELECT id, room_id FROM state_groups"
                     " WHERE ? < id AND id <= ?"
@@ -825,7 +914,7 @@ class StateStore(SQLBaseStore):
                         # of keys
 
                         delta_state = {
-                            key: value for key, value in curr_state.iteritems()
+                            key: value for key, value in iteritems(curr_state)
                             if prev_state.get(key, None) != value
                         }
 
@@ -865,7 +954,7 @@ class StateStore(SQLBaseStore):
                                     "state_key": key[1],
                                     "event_id": state_id,
                                 }
-                                for key, state_id in delta_state.iteritems()
+                                for key, state_id in iteritems(delta_state)
                             ],
                         )
 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index dddd5fc0e7..66856342f0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -33,17 +33,20 @@ what sort order was used:
       and stream ordering columns respectively.
 """
 
-from twisted.internet import defer
+import abc
+import logging
+from collections import namedtuple
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
-from synapse.api.constants import EventTypes
-from synapse.types import RoomStreamToken
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from six.moves import range
 
-import logging
+from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.events import EventsWorkerStore
+from synapse.types import RoomStreamToken
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 
 logger = logging.getLogger(__name__)
 
@@ -55,6 +58,12 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
 
+# Used as return values for pagination APIs
+_EventDictReturn = namedtuple("_EventDictReturn", (
+    "event_id", "topological_ordering", "stream_ordering",
+))
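
The namedtuple just gives positional cursor rows readable field names. A
rough illustration of how the stream-ordered queries below build and consume
it (the cursor rows here are fabricated):

    from collections import namedtuple

    _EventDictReturn = namedtuple("_EventDictReturn", (
        "event_id", "topological_ordering", "stream_ordering",
    ))

    cursor_rows = [("$ev1", 10), ("$ev2", 11)]  # (event_id, stream_ordering)
    rows = [_EventDictReturn(r[0], None, r[1]) for r in cursor_rows]
    key = "s%d" % min(r.stream_ordering for r in rows)
    print(key)  # s10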
+
+
 def lower_bound(token, engine, inclusive=False):
     inclusive = "=" if inclusive else ""
     if token.topological is None:
@@ -143,81 +152,41 @@ def filter_to_clause(event_filter):
     return " AND ".join(clauses), args
 
 
-class StreamStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
-        # NB this lives here instead of appservice.py so we can reuse the
-        # 'private' StreamToken class in this file.
-        if limit:
-            limit = max(limit, MAX_STREAM_SIZE)
-        else:
-            limit = MAX_STREAM_SIZE
-
-        # From and to keys should be integers from ordering.
-        from_id = RoomStreamToken.parse_stream_token(from_key)
-        to_id = RoomStreamToken.parse_stream_token(to_key)
-
-        if from_key == to_key:
-            defer.returnValue(([], to_key))
-            return
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
+    which can be called in the initializer.
+    """
 
-        # select all the events between from/to with a sensible limit
-        sql = (
-            "SELECT e.event_id, e.room_id, e.type, s.state_key, "
-            "e.stream_ordering FROM events AS e "
-            "LEFT JOIN state_events as s ON "
-            "e.event_id = s.event_id "
-            "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
-            "ORDER BY stream_ordering ASC LIMIT %(limit)d "
-        ) % {
-            "limit": limit
-        }
+    __metaclass__ = abc.ABCMeta
 
-        def f(txn):
-            # pull out all the events between the tokens
-            txn.execute(sql, (from_id.stream, to_id.stream,))
-            rows = self.cursor_to_dict(txn)
-
-            # Logic:
-            #  - We want ALL events which match the AS room_id regex
-            #  - We want ALL events which match the rooms represented by the AS
-            #    room_alias regex
-            #  - We want ALL events for rooms that AS users have joined.
-            # This is currently supported via get_app_service_rooms (which is
-            # used for the Notifier listener rooms). We can't reasonably make a
-            # SQL query for these room IDs, so we'll pull all the events between
-            # from/to and filter in python.
-            rooms_for_as = self._get_app_service_rooms_txn(txn, service)
-            room_ids_for_as = [r.room_id for r in rooms_for_as]
-
-            def app_service_interested(row):
-                if row["room_id"] in room_ids_for_as:
-                    return True
-
-                if row["type"] == EventTypes.Member:
-                    if service.is_interested_in_user(row.get("state_key")):
-                        return True
-                return False
-
-            return [r for r in rows if app_service_interested(r)]
-
-        rows = yield self.runInteraction("get_appservice_room_stream", f)
+    def __init__(self, db_conn, hs):
+        super(StreamWorkerStore, self).__init__(db_conn, hs)
 
-        ret = yield self._get_events(
-            [r["event_id"] for r in rows],
-            get_prev_content=True
+        events_max = self.get_room_max_stream_ordering()
+        event_cache_prefill, min_event_val = self._get_cache_dict(
+            db_conn, "events",
+            entity_column="room_id",
+            stream_column="stream_ordering",
+            max_value=events_max,
+        )
+        self._events_stream_cache = StreamChangeCache(
+            "EventsRoomStreamChangeCache", min_event_val,
+            prefilled_cache=event_cache_prefill,
+        )
+        self._membership_stream_cache = StreamChangeCache(
+            "MembershipStreamChangeCache", events_max,
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._stream_order_on_start = self.get_room_max_stream_ordering()
 
-        if rows:
-            key = "s%d" % max(r["stream_ordering"] for r in rows)
-        else:
-            # Assume we didn't get anything because there was nothing to
-            # get.
-            key = to_key
+    @abc.abstractmethod
+    def get_room_max_stream_ordering(self):
+        raise NotImplementedError()
 
-        defer.returnValue((ret, key))
+    @abc.abstractmethod
+    def get_room_min_stream_ordering(self):
+        raise NotImplementedError()
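
Because of the ABCMeta metaclass, a subclass that forgets to implement either
method cannot be instantiated. A minimal shape (py3 spelling; the diff uses
the py2-compatible `__metaclass__` attribute, and the real constructor takes
more arguments):

    import abc

    class Worker(abc.ABC):
        @abc.abstractmethod
        def get_room_max_stream_ordering(self):
            raise NotImplementedError()

    class Concrete(Worker):
        def get_room_max_stream_ordering(self):
            return 42

    Concrete()  # fine
    try:
        Worker()  # abstract method unimplemented
    except TypeError as e:
        print(e)

The StreamStore subclass at the bottom of this file provides the real
implementations on top of the ID generators.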
 
     @defer.inlineCallbacks
     def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
@@ -233,13 +202,14 @@ class StreamStore(SQLBaseStore):
 
         results = {}
         room_ids = list(room_ids)
-        for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
-            res = yield preserve_context_over_deferred(defer.gatherResults([
-                preserve_fn(self.get_room_events_stream_for_room)(
+        for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
+            res = yield make_deferred_yieldable(defer.gatherResults([
+                run_in_background(
+                    self.get_room_events_stream_for_room,
                     room_id, from_key, to_key, limit, order=order,
                 )
                 for room_id in rm_ids
-            ]))
+            ], consumeErrors=True))
             results.update(dict(zip(rm_ids, res)))
 
         defer.returnValue(results)
@@ -261,54 +231,55 @@ class StreamStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
                                         order='DESC'):
-        # Note: If from_key is None then we return in topological order. This
-        # is because in that case we're using this as a "get the last few messages
-        # in a room" function, rather than "get new messages since last sync"
-        if from_key is not None:
-            from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        else:
-            from_id = None
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
+        """Get new room events in stream ordering since `from_key`.
+
+        Args:
+            room_id (str)
+            from_key (str): Token from which no events are returned before
+            to_key (str): Token from which no events are returned after. (This
+                is typically the current stream token)
+            limit (int): Maximum number of events to return
+            order (str): Either "DESC" or "ASC". Determines which events are
+                returned when the result is limited. If "DESC" then the most
+                recent `limit` events are returned, otherwise returns the
+                oldest `limit` events.
+
+        Returns:
+            Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
+            events (in ascending order) and the token from the start of
+            the chunk of events returned.
+        """
         if from_key == to_key:
             defer.returnValue(([], from_key))
 
-        if from_id:
-            has_changed = yield self._events_stream_cache.has_entity_changed(
-                room_id, from_id
-            )
-
-            if not has_changed:
-                defer.returnValue(([], from_key))
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
-        def f(txn):
-            if from_id is not None:
-                sql = (
-                    "SELECT event_id, stream_ordering FROM events WHERE"
-                    " room_id = ?"
-                    " AND not outlier"
-                    " AND stream_ordering > ? AND stream_ordering <= ?"
-                    " ORDER BY stream_ordering %s LIMIT ?"
-                ) % (order,)
-                txn.execute(sql, (room_id, from_id, to_id, limit))
-            else:
-                sql = (
-                    "SELECT event_id, stream_ordering FROM events WHERE"
-                    " room_id = ?"
-                    " AND not outlier"
-                    " AND stream_ordering <= ?"
-                    " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?"
-                ) % (order, order,)
-                txn.execute(sql, (room_id, to_id, limit))
+        has_changed = yield self._events_stream_cache.has_entity_changed(
+            room_id, from_id
+        )
 
-            rows = self.cursor_to_dict(txn)
+        if not has_changed:
+            defer.returnValue(([], from_key))
 
+        def f(txn):
+            sql = (
+                "SELECT event_id, stream_ordering FROM events WHERE"
+                " room_id = ?"
+                " AND not outlier"
+                " AND stream_ordering > ? AND stream_ordering <= ?"
+                " ORDER BY stream_ordering %s LIMIT ?"
+            ) % (order,)
+            txn.execute(sql, (room_id, from_id, to_id, limit))
+
+            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
             return rows
 
         rows = yield self.runInteraction("get_room_events_stream_for_room", f)
 
         ret = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
 
@@ -318,7 +289,7 @@ class StreamStore(SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = "s%d" % min(r["stream_ordering"] for r in rows)
+            key = "s%d" % min(r.stream_ordering for r in rows)
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
@@ -328,10 +299,7 @@ class StreamStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_membership_changes_for_user(self, user_id, from_key, to_key):
-        if from_key is not None:
-            from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        else:
-            from_id = None
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
         if from_key == to_key:
@@ -345,34 +313,24 @@ class StreamStore(SQLBaseStore):
                 defer.returnValue([])
 
         def f(txn):
-            if from_id is not None:
-                sql = (
-                    "SELECT m.event_id, stream_ordering FROM events AS e,"
-                    " room_memberships AS m"
-                    " WHERE e.event_id = m.event_id"
-                    " AND m.user_id = ?"
-                    " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                    " ORDER BY e.stream_ordering ASC"
-                )
-                txn.execute(sql, (user_id, from_id, to_id,))
-            else:
-                sql = (
-                    "SELECT m.event_id, stream_ordering FROM events AS e,"
-                    " room_memberships AS m"
-                    " WHERE e.event_id = m.event_id"
-                    " AND m.user_id = ?"
-                    " AND stream_ordering <= ?"
-                    " ORDER BY stream_ordering ASC"
-                )
-                txn.execute(sql, (user_id, to_id,))
-            rows = self.cursor_to_dict(txn)
+            sql = (
+                "SELECT m.event_id, stream_ordering FROM events AS e,"
+                " room_memberships AS m"
+                " WHERE e.event_id = m.event_id"
+                " AND m.user_id = ?"
+                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+                " ORDER BY e.stream_ordering ASC"
+            )
+            txn.execute(sql, (user_id, from_id, to_id,))
+
+            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
 
             return rows
 
         rows = yield self.runInteraction("get_membership_changes_for_user", f)
 
         ret = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
 
@@ -381,96 +339,28 @@ class StreamStore(SQLBaseStore):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def paginate_room_events(self, room_id, from_key, to_key=None,
-                             direction='b', limit=-1, event_filter=None):
-        # Tokens really represent positions between elements, but we use
-        # the convention of pointing to the event before the gap. Hence
-        # we have a bit of asymmetry when it comes to equalities.
-        args = [False, room_id]
-        if direction == 'b':
-            order = "DESC"
-            bounds = upper_bound(
-                RoomStreamToken.parse(from_key), self.database_engine
-            )
-            if to_key:
-                bounds = "%s AND %s" % (bounds, lower_bound(
-                    RoomStreamToken.parse(to_key), self.database_engine
-                ))
-        else:
-            order = "ASC"
-            bounds = lower_bound(
-                RoomStreamToken.parse(from_key), self.database_engine
-            )
-            if to_key:
-                bounds = "%s AND %s" % (bounds, upper_bound(
-                    RoomStreamToken.parse(to_key), self.database_engine
-                ))
-
-        filter_clause, filter_args = filter_to_clause(event_filter)
-
-        if filter_clause:
-            bounds += " AND " + filter_clause
-            args.extend(filter_args)
-
-        if int(limit) > 0:
-            args.append(int(limit))
-            limit_str = " LIMIT ?"
-        else:
-            limit_str = ""
-
-        sql = (
-            "SELECT * FROM events"
-            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
-            " ORDER BY topological_ordering %(order)s,"
-            " stream_ordering %(order)s %(limit)s"
-        ) % {
-            "bounds": bounds,
-            "order": order,
-            "limit": limit_str
-        }
-
-        def f(txn):
-            txn.execute(sql, args)
-
-            rows = self.cursor_to_dict(txn)
-
-            if rows:
-                topo = rows[-1]["topological_ordering"]
-                toke = rows[-1]["stream_ordering"]
-                if direction == 'b':
-                    # Tokens are positions between events.
-                    # This token points *after* the last event in the chunk.
-                    # We need it to point to the event before it in the chunk
-                    # when we are going backwards so we subtract one from the
-                    # stream part.
-                    toke -= 1
-                next_token = str(RoomStreamToken(topo, toke))
-            else:
-                # TODO (erikj): We should work out what to do here instead.
-                next_token = to_key if to_key else from_key
+    def get_recent_events_for_room(self, room_id, limit, end_token):
+        """Get the most recent events in the room in topological ordering.
 
-            return rows, next_token,
-
-        rows, token = yield self.runInteraction("paginate_room_events", f)
-
-        events = yield self._get_events(
-            [r["event_id"] for r in rows],
-            get_prev_content=True
-        )
-
-        self._set_before_and_after(events, rows)
+        Args:
+            room_id (str)
+            limit (int)
+            end_token (str): The stream token representing now.
 
-        defer.returnValue((events, token))
+        Returns:
+            Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
+            events and a token pointing to the start of the returned
+            events.
+            The events returned are in ascending order.
+        """
 
-    @defer.inlineCallbacks
-    def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
         rows, token = yield self.get_recent_event_ids_for_room(
-            room_id, limit, end_token, from_token
+            room_id, limit, end_token,
         )
 
         logger.debug("stream before")
         events = yield self._get_events(
-            [r["event_id"] for r in rows],
+            [r.event_id for r in rows],
             get_prev_content=True
         )
         logger.debug("stream after")
@@ -479,59 +369,62 @@ class StreamStore(SQLBaseStore):
 
         defer.returnValue((events, token))
 
-    @cached(num_args=4)
-    def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None):
-        end_token = RoomStreamToken.parse_stream_token(end_token)
+    @defer.inlineCallbacks
+    def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+        """Get the most recent events in the room in topological ordering.
 
-        if from_token is None:
-            sql = (
-                "SELECT stream_ordering, topological_ordering, event_id"
-                " FROM events"
-                " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?"
-                " ORDER BY topological_ordering DESC, stream_ordering DESC"
-                " LIMIT ?"
-            )
-        else:
-            from_token = RoomStreamToken.parse_stream_token(from_token)
-            sql = (
-                "SELECT stream_ordering, topological_ordering, event_id"
-                " FROM events"
-                " WHERE room_id = ? AND stream_ordering > ?"
-                " AND stream_ordering <= ? AND outlier = ?"
-                " ORDER BY topological_ordering DESC, stream_ordering DESC"
-                " LIMIT ?"
-            )
+        Args:
+            room_id (str)
+            limit (int)
+            end_token (str): The stream token representing now.
 
-        def get_recent_events_for_room_txn(txn):
-            if from_token is None:
-                txn.execute(sql, (room_id, end_token.stream, False, limit,))
-            else:
-                txn.execute(sql, (
-                    room_id, from_token.stream, end_token.stream, False, limit
-                ))
+        Returns:
+            Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
+            _EventDictReturn and a token pointing to the start of the returned
+            events.
+            The events returned are in ascending order.
+        """
+        # Allow a zero limit here, and treat it as a no-op.
+        if limit == 0:
+            defer.returnValue(([], end_token))
 
-            rows = self.cursor_to_dict(txn)
+        end_token = RoomStreamToken.parse(end_token)
 
-            rows.reverse()  # As we selected with reverse ordering
+        rows, token = yield self.runInteraction(
+            "get_recent_event_ids_for_room", self._paginate_room_events_txn,
+            room_id, from_token=end_token, limit=limit,
+        )
 
-            if rows:
-                # Tokens are positions between events.
-                # This token points *after* the last event in the chunk.
-                # We need it to point to the event before it in the chunk
-                # since we are going backwards so we subtract one from the
-                # stream part.
-                topo = rows[0]["topological_ordering"]
-                toke = rows[0]["stream_ordering"] - 1
-                start_token = str(RoomStreamToken(topo, toke))
+        # We want to return the results in ascending order.
+        rows.reverse()
 
-                token = (start_token, str(end_token))
-            else:
-                token = (str(end_token), str(end_token))
+        defer.returnValue((rows, token))
+
+    def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
+        """Gets details of the first event in a room at or after a stream ordering
+
+        Args:
+            room_id (str):
+            stream_ordering (int):
 
-            return rows, token
+        Returns:
+            Deferred[(int, int, str)]:
+                (stream ordering, topological ordering, event_id)
+        """
+        def _f(txn):
+            sql = (
+                "SELECT stream_ordering, topological_ordering, event_id"
+                " FROM events"
+                " WHERE room_id = ? AND stream_ordering >= ?"
+                " AND NOT outlier"
+                " ORDER BY stream_ordering"
+                " LIMIT 1"
+            )
+            txn.execute(sql, (room_id, stream_ordering, ))
+            return txn.fetchone()
 
         return self.runInteraction(
-            "get_recent_events_for_room", get_recent_events_for_room_txn
+            "get_room_event_after_stream_ordering", _f,
         )
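
The query's semantics are easy to check in isolation with sqlite3 (toy
`events` table carrying only the columns the query touches):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE events "
        "(room_id, stream_ordering, topological_ordering, event_id, outlier)"
    )
    conn.executemany("INSERT INTO events VALUES (?, ?, ?, ?, ?)", [
        ("!r", 5, 1, "$a", 0),
        ("!r", 7, 2, "$b", 1),  # outlier: must be skipped
        ("!r", 9, 3, "$c", 0),
    ])
    row = conn.execute(
        "SELECT stream_ordering, topological_ordering, event_id"
        " FROM events"
        " WHERE room_id = ? AND stream_ordering >= ?"
        " AND NOT outlier"
        " ORDER BY stream_ordering"
        " LIMIT 1",
        ("!r", 6),
    ).fetchone()
    print(row)  # (9, 3, '$c'): first non-outlier at or after ordering 6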
 
     @defer.inlineCallbacks
@@ -542,7 +435,7 @@ class StreamStore(SQLBaseStore):
         `room_id` causes it to return the current room specific topological
         token.
         """
-        token = yield self._stream_id_gen.get_current_token()
+        token = yield self.get_room_max_stream_ordering()
         if room_id is None:
             defer.returnValue("s%d" % (token,))
         else:
@@ -552,12 +445,6 @@ class StreamStore(SQLBaseStore):
             )
             defer.returnValue("t%d-%d" % (topo, token))
 
-    def get_room_max_stream_ordering(self):
-        return self._stream_id_gen.get_current_token()
-
-    def get_room_min_stream_ordering(self):
-        return self._backfill_id_gen.get_current_token()
-
     def get_stream_token_for_event(self, event_id):
         """The stream token for an event
         Args:
@@ -615,10 +502,20 @@ class StreamStore(SQLBaseStore):
 
     @staticmethod
     def _set_before_and_after(events, rows, topo_order=True):
+        """Inserts ordering information to events' internal metadata from
+        the DB rows.
+
+        Args:
+            events (list[FrozenEvent])
+            rows (list[_EventDictReturn])
+            topo_order (bool): Whether the events were ordered topologically
+                or by stream ordering. If true then all rows should have a non
+                null topological_ordering.
+        """
         for event, row in zip(events, rows):
-            stream = row["stream_ordering"]
-            if topo_order:
-                topo = event.depth
+            stream = row.stream_ordering
+            if topo_order and row.topological_ordering:
+                topo = row.topological_ordering
             else:
                 topo = None
             internal = event.internal_metadata
@@ -690,87 +587,27 @@ class StreamStore(SQLBaseStore):
             retcols=["stream_ordering", "topological_ordering"],
         )
 
-        token = RoomStreamToken(
-            results["topological_ordering"],
+        # Paginating backwards includes the event at the token, but paginating
+        # forward doesn't.
+        before_token = RoomStreamToken(
+            results["topological_ordering"] - 1,
             results["stream_ordering"],
         )
 
-        if isinstance(self.database_engine, Sqlite3Engine):
-            # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
-            # So we give pass it to SQLite3 as the UNION ALL of the two queries.
-
-            query_before = (
-                "SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND topological_ordering < ?"
-                " UNION ALL"
-                " SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
-                " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
-            )
-            before_args = (
-                room_id, token.topological,
-                room_id, token.topological, token.stream,
-                before_limit,
-            )
-
-            query_after = (
-                "SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND topological_ordering > ?"
-                " UNION ALL"
-                " SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
-                " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
-            )
-            after_args = (
-                room_id, token.topological,
-                room_id, token.topological, token.stream,
-                after_limit,
-            )
-        else:
-            query_before = (
-                "SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND %s"
-                " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
-            ) % (upper_bound(token, self.database_engine, inclusive=False),)
-
-            before_args = (room_id, before_limit)
-
-            query_after = (
-                "SELECT topological_ordering, stream_ordering, event_id FROM events"
-                " WHERE room_id = ? AND %s"
-                " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
-            ) % (lower_bound(token, self.database_engine, inclusive=False),)
-
-            after_args = (room_id, after_limit)
-
-        txn.execute(query_before, before_args)
-
-        rows = self.cursor_to_dict(txn)
-        events_before = [r["event_id"] for r in rows]
-
-        if rows:
-            start_token = str(RoomStreamToken(
-                rows[0]["topological_ordering"],
-                rows[0]["stream_ordering"] - 1,
-            ))
-        else:
-            start_token = str(RoomStreamToken(
-                token.topological,
-                token.stream - 1,
-            ))
-
-        txn.execute(query_after, after_args)
+        after_token = RoomStreamToken(
+            results["topological_ordering"],
+            results["stream_ordering"],
+        )
 
-        rows = self.cursor_to_dict(txn)
-        events_after = [r["event_id"] for r in rows]
+        rows, start_token = self._paginate_room_events_txn(
+            txn, room_id, before_token, direction='b', limit=before_limit,
+        )
+        events_before = [r.event_id for r in rows]
 
-        if rows:
-            end_token = str(RoomStreamToken(
-                rows[-1]["topological_ordering"],
-                rows[-1]["stream_ordering"],
-            ))
-        else:
-            end_token = str(token)
+        rows, end_token = self._paginate_room_events_txn(
+            txn, room_id, after_token, direction='f', limit=after_limit,
+        )
+        events_after = [r.event_id for r in rows]
 
         return {
             "before": {
@@ -832,3 +669,139 @@ class StreamStore(SQLBaseStore):
 
     def has_room_changed_since(self, room_id, stream_id):
         return self._events_stream_cache.has_entity_changed(room_id, stream_id)
+
+    def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
+                                  direction='b', limit=-1, event_filter=None):
+        """Returns list of events before or after a given token.
+
+        Args:
+            txn
+            room_id (str)
+            from_token (RoomStreamToken): The token used to stream from
+            to_token (RoomStreamToken|None): A token which if given limits the
+                results to only those before
+            direction (char): Either 'b' or 'f' to indicate whether we are
+                paginating forwards or backwards from `from_token`.
+            limit (int): The maximum number of events to return.
+            event_filter (Filter|None): If provided filters the events to
+                those that match the filter.
+
+        Returns:
+            tuple[list[_EventDictReturn], str]: Returns the results
+            as a list of _EventDictReturn and a token that points to the end
+            of the result set.
+        """
+
+        assert int(limit) >= 0
+
+        # Tokens really represent positions between elements, but we use
+        # the convention of pointing to the event before the gap. Hence
+        # we have a bit of asymmetry when it comes to equalities.
+        args = [False, room_id]
+        if direction == 'b':
+            order = "DESC"
+            bounds = upper_bound(
+                from_token, self.database_engine
+            )
+            if to_token:
+                bounds = "%s AND %s" % (bounds, lower_bound(
+                    to_token, self.database_engine
+                ))
+        else:
+            order = "ASC"
+            bounds = lower_bound(
+                from_token, self.database_engine
+            )
+            if to_token:
+                bounds = "%s AND %s" % (bounds, upper_bound(
+                    to_token, self.database_engine
+                ))
+
+        filter_clause, filter_args = filter_to_clause(event_filter)
+
+        if filter_clause:
+            bounds += " AND " + filter_clause
+            args.extend(filter_args)
+
+        args.append(int(limit))
+
+        sql = (
+            "SELECT event_id, topological_ordering, stream_ordering"
+            " FROM events"
+            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
+            " ORDER BY topological_ordering %(order)s,"
+            " stream_ordering %(order)s LIMIT ?"
+        ) % {
+            "bounds": bounds,
+            "order": order,
+        }
+
+        txn.execute(sql, args)
+
+        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+
+        if rows:
+            topo = rows[-1].topological_ordering
+            toke = rows[-1].stream_ordering
+            if direction == 'b':
+                # Tokens are positions between events.
+                # This token points *after* the last event in the chunk.
+                # We need it to point to the event before it in the chunk
+                # when we are going backwards so we subtract one from the
+                # stream part.
+                toke -= 1
+            next_token = RoomStreamToken(topo, toke)
+        else:
+            # TODO (erikj): We should work out what to do here instead.
+            next_token = to_token if to_token else from_token
+
+        return rows, str(next_token),
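
The token arithmetic at the end is worth a worked example. Tokens name
positions between events and, by convention, point to the event before the
gap; when paginating backwards the returned token must therefore sit just
before the last event handed out, hence the `toke -= 1`. A toy sketch using
the `t<topological>-<stream>` rendering from elsewhere in this file:

    def next_token(rows, direction):
        # rows: (topological_ordering, stream_ordering) pairs in the order
        # the query returned them (DESC for 'b', ASC for 'f').
        topo, stream = rows[-1]
        if direction == 'b':
            stream -= 1  # point *before* the last returned event
        return "t%d-%d" % (topo, stream)

    # A backwards page ending at (topo=3, stream=10) yields t3-9, so the
    # next request does not return the same event again.
    print(next_token([(4, 12), (3, 10)], 'b'))  # t3-9
    print(next_token([(3, 10), (4, 12)], 'f'))  # t4-12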
+
+    @defer.inlineCallbacks
+    def paginate_room_events(self, room_id, from_key, to_key=None,
+                             direction='b', limit=-1, event_filter=None):
+        """Returns list of events before or after a given token.
+
+        Args:
+            room_id (str)
+            from_key (str): The token used to stream from
+            to_key (str|None): A token which if given limits the results to
+                only those before
+            direction (char): Either 'b' or 'f' to indicate whether we are
+                paginating forwards or backwards from `from_key`.
+            limit (int): The maximum number of events to return. Zero or less
+                means no limit.
+            event_filter (Filter|None): If provided filters the events to
+                those that match the filter.
+
+        Returns:
+            Deferred[tuple[list[FrozenEvent], str]]: Returns the results as a
+            list of events and a token that points to the end of the result
+            set.
+        """
+
+        from_key = RoomStreamToken.parse(from_key)
+        if to_key:
+            to_key = RoomStreamToken.parse(to_key)
+
+        rows, token = yield self.runInteraction(
+            "paginate_room_events", self._paginate_room_events_txn,
+            room_id, from_key, to_key, direction, limit, event_filter,
+        )
+
+        events = yield self._get_events(
+            [r.event_id for r in rows],
+            get_prev_content=True
+        )
+
+        self._set_before_and_after(events, rows)
+
+        defer.returnValue((events, token))
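
A typical consumer feeds the returned token straight back in to walk a room's
history page by page. A self-contained sketch (the FakeStore below is a toy
honouring only the (events, token) contract; a real caller would use the
store itself):

    from twisted.internet import defer

    class FakeStore(object):
        def __init__(self):
            self._pages = [(["$c", "$b"], "t1-9"), (["$a"], "t0-0")]

        def paginate_room_events(self, room_id, from_key,
                                 direction='b', limit=50):
            return defer.succeed(self._pages.pop(0))

    @defer.inlineCallbacks
    def walk_backwards(store, room_id, token, page_size=2):
        seen = []
        while True:
            events, token = yield store.paginate_room_events(
                room_id, token, direction='b', limit=page_size,
            )
            seen.extend(events)
            if len(events) < page_size:  # short page: hit the room's start
                break
        defer.returnValue(seen)

    walk_backwards(FakeStore(), "!room", "t2-12").addCallback(print)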
+
+
+class StreamStore(StreamWorkerStore):
+    def get_room_max_stream_ordering(self):
+        return self._stream_id_gen.get_current_token()
+
+    def get_room_min_stream_ordering(self):
+        return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index bff73f3f04..0f657b2bd3 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,25 +14,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
-from twisted.internet import defer
-
-import ujson as json
 import logging
 
-logger = logging.getLogger(__name__)
+from six.moves import range
 
+from canonicaljson import json
 
-class TagsStore(SQLBaseStore):
-    def get_max_account_data_stream_id(self):
-        """Get the current max stream id for the private user data stream
+from twisted.internet import defer
+
+from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.util.caches.descriptors import cached
+
+logger = logging.getLogger(__name__)
 
-        Returns:
-            A deferred int.
-        """
-        return self._account_data_id_gen.get_current_token()
 
+class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
     def get_tags_for_user(self, user_id):
         """Get all the tags for a user.
@@ -104,7 +101,7 @@ class TagsStore(SQLBaseStore):
 
         batch_size = 50
         results = []
-        for i in xrange(0, len(tag_ids), batch_size):
+        for i in range(0, len(tag_ids), batch_size):
             tags = yield self.runInteraction(
                 "get_all_updated_tag_content",
                 get_tag_content,
@@ -170,6 +167,8 @@ class TagsStore(SQLBaseStore):
             row["tag"]: json.loads(row["content"]) for row in rows
         })
 
+
+class TagsStore(TagsWorkerStore):
     @defer.inlineCallbacks
     def add_tag_to_room(self, user_id, room_id, tag, content):
         """Add a tag to a room for a user.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 809fdd311f..c3bc94f56d 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,17 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+import logging
+from collections import namedtuple
+
+import six
+
+from canonicaljson import encode_canonical_json, json
 
 from twisted.internet import defer
 
-from canonicaljson import encode_canonical_json
+from synapse.util.caches.descriptors import cached
 
-from collections import namedtuple
+from ._base import SQLBaseStore
 
-import logging
-import ujson as json
+# py2 sqlite has buffer hardcoded as its only binary type, so we must use it,
+# despite being deprecated and removed in favor of memoryview
+if six.PY2:
+    db_binary_type = buffer
+else:
+    db_binary_type = memoryview
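
The point of the wrapper is binding: py2's sqlite driver stores a str
parameter as text and only a buffer as a BLOB, and memoryview is the py3
object that plays the same role (py3 bytes would also bind as a BLOB). A
short py3 illustration with the standard-library driver:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (blob_col)")
    payload = b'{"canonical":"json"}'
    # memoryview marks the parameter as binary data for the driver.
    conn.execute("INSERT INTO t VALUES (?)", (memoryview(payload),))
    stored, = conn.execute("SELECT blob_col FROM t").fetchone()
    print(type(stored), bytes(stored) == payload)  # <class 'bytes'> True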
 
 logger = logging.getLogger(__name__)
 
@@ -46,8 +54,8 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """
 
-    def __init__(self, hs):
-        super(TransactionStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(TransactionStore, self).__init__(db_conn, hs)
 
         self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
 
@@ -110,7 +118,7 @@ class TransactionStore(SQLBaseStore):
                 "transaction_id": transaction_id,
                 "origin": origin,
                 "response_code": code,
-                "response_json": buffer(encode_canonical_json(response_dict)),
+                "response_json": db_binary_type(encode_canonical_json(response_dict)),
                 "ts": self._clock.time_msec(),
             },
             or_ignore=True,
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 2a4db3f03c..a8781b0e5d 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -13,17 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+import logging
+import re
 
-from ._base import SQLBaseStore
+from six import iteritems
+
+from twisted.internet import defer
 
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id, get_localpart_from_id
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
-import re
-import logging
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +65,7 @@ class UserDirectoryStore(SQLBaseStore):
             user_ids (list(str)): Users to add
         """
         yield self._simple_insert_many(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             values=[
                 {
                     "user_id": user_id,
@@ -100,7 +102,7 @@ class UserDirectoryStore(SQLBaseStore):
                     user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
                     profile.display_name,
                 )
-                for user_id, profile in users_with_profile.iteritems()
+                for user_id, profile in iteritems(users_with_profile)
             )
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = """
@@ -112,7 +114,7 @@ class UserDirectoryStore(SQLBaseStore):
                     user_id,
                     "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
                 )
-                for user_id, p in users_with_profile.iteritems()
+                for user_id, p in iteritems(users_with_profile)
             )
         else:
             # This should be unreachable.
@@ -130,7 +132,7 @@ class UserDirectoryStore(SQLBaseStore):
                         "display_name": profile.display_name,
                         "avatar_url": profile.avatar_url,
                     }
-                    for user_id, profile in users_with_profile.iteritems()
+                    for user_id, profile in iteritems(users_with_profile)
                 ]
             )
             for user_id in users_with_profile:
@@ -164,7 +166,7 @@ class UserDirectoryStore(SQLBaseStore):
             )
 
             if isinstance(self.database_engine, PostgresEngine):
-                # We weight the loclpart most highly, then display name and finally
+                # We weight the localpart most highly, then display name and finally
                 # server name
                 if new_entry:
                     sql = """
@@ -219,7 +221,7 @@ class UserDirectoryStore(SQLBaseStore):
     @defer.inlineCallbacks
     def update_user_in_public_user_list(self, user_id, room_id):
         yield self._simple_update_one(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             updatevalues={"room_id": room_id},
             desc="update_user_in_public_user_list",
@@ -240,7 +242,7 @@ class UserDirectoryStore(SQLBaseStore):
             )
             self._simple_delete_txn(
                 txn,
-                table="users_in_pubic_room",
+                table="users_in_public_rooms",
                 keyvalues={"user_id": user_id},
             )
             txn.call_after(
@@ -256,18 +258,18 @@ class UserDirectoryStore(SQLBaseStore):
     @defer.inlineCallbacks
     def remove_from_user_in_public_room(self, user_id):
         yield self._simple_delete(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             desc="remove_from_user_in_public_room",
         )
         self.get_user_in_public_room.invalidate((user_id,))
 
     def get_users_in_public_due_to_room(self, room_id):
-        """Get all user_ids that are in the room directory becuase they're
+        """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
         return self._simple_select_onecol(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_public_due_to_room",
@@ -275,7 +277,7 @@ class UserDirectoryStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_users_in_dir_due_to_room(self, room_id):
-        """Get all user_ids that are in the room directory becuase they're
+        """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
         user_ids_dir = yield self._simple_select_onecol(
@@ -286,7 +288,7 @@ class UserDirectoryStore(SQLBaseStore):
         )
 
         user_ids_pub = yield self._simple_select_onecol(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_dir_due_to_room",
@@ -317,6 +319,16 @@ class UserDirectoryStore(SQLBaseStore):
         rows = yield self._execute("get_all_rooms", None, sql)
         defer.returnValue([room_id for room_id, in rows])
 
+    @defer.inlineCallbacks
+    def get_all_local_users(self):
+        """Get all local users
+        """
+        sql = """
+            SELECT name FROM users
+        """
+        rows = yield self._execute("get_all_local_users", None, sql)
+        defer.returnValue([name for name, in rows])
+
     def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
         """Insert entries into the users_who_share_rooms table. The first
         user should be a local user.
@@ -514,7 +526,7 @@ class UserDirectoryStore(SQLBaseStore):
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
-            txn.execute("DELETE FROM users_in_pubic_room")
+            txn.execute("DELETE FROM users_in_public_rooms")
             txn.execute("DELETE FROM users_who_share_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
             txn.call_after(self.get_user_in_public_room.invalidate_all)
@@ -537,7 +549,7 @@ class UserDirectoryStore(SQLBaseStore):
     @cached()
     def get_user_in_public_room(self, user_id):
         return self._simple_select_one(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             retcols=("room_id",),
             allow_none=True,
@@ -629,6 +641,25 @@ class UserDirectoryStore(SQLBaseStore):
                     ]
                 }
         """
+
+        if self.hs.config.user_directory_search_all_users:
+            # make s.user_id null to keep the ordering algorithm happy
+            join_clause = """
+                CROSS JOIN (SELECT NULL as user_id) AS s
+            """
+            join_args = ()
+            where_clause = "1=1"
+        else:
+            join_clause = """
+                LEFT JOIN users_in_public_rooms AS p USING (user_id)
+                LEFT JOIN (
+                    SELECT other_user_id AS user_id FROM users_who_share_rooms
+                    WHERE user_id = ? AND share_private
+                ) AS s USING (user_id)
+            """
+            join_args = (user_id,)
+            where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
+
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
 
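Splitting the query into `join_clause`/`where_clause` lets the Postgres and SQLite branches below share one filtering policy. Note the division of labour: only these trusted, hard-coded fragments are spliced in with `%`, while anything user-supplied still travels as a bound `?` parameter. A condensed, self-contained sketch of the composition (SQLite branch, table names as in the diff):

    def build_search_sql(search_all_users, user_id):
        """Compose the search SQL as above: trusted fragments via %,
        user input via ? placeholders."""
        if search_all_users:
            # A dummy NULL s.user_id keeps the ORDER BY CASE on s.user_id valid.
            join_clause = "CROSS JOIN (SELECT NULL AS user_id) AS s"
            join_args = ()
            where_clause = "1=1"
        else:
            join_clause = """
                LEFT JOIN users_in_public_rooms AS p USING (user_id)
                LEFT JOIN (
                    SELECT other_user_id AS user_id FROM users_who_share_rooms
                    WHERE user_id = ? AND share_private
                ) AS s USING (user_id)
            """
            join_args = (user_id,)
            where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"

        sql = """
            SELECT d.user_id AS user_id, display_name, avatar_url
            FROM user_directory_search
            INNER JOIN user_directory AS d USING (user_id)
            %s
            WHERE %s AND value MATCH ?
            LIMIT ?
        """ % (join_clause, where_clause)
        return sql, join_args

    sql, join_args = build_search_sql(False, "@alice:example.org")
    # Bind order matches the diff: join args first, then search term and limit.
    params = join_args + ("alice", 11)
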
@@ -638,16 +669,12 @@ class UserDirectoryStore(SQLBaseStore):
             # The array of numbers are the weights for the various part of the
             # search: (domain, _, display name, localpart)
             sql = """
-                SELECT d.user_id, display_name, avatar_url
+                SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
-                LEFT JOIN users_in_pubic_room AS p USING (user_id)
-                LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_rooms
-                    WHERE user_id = ? AND share_private
-                ) AS s USING (user_id)
+                %s
                 WHERE
-                    (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+                    %s
                     AND vector @@ to_tsquery('english', ?)
                 ORDER BY
                     (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -671,30 +698,26 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """
-            args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+            """ % (join_clause, where_clause)
+            args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)
 
             sql = """
-                SELECT d.user_id, display_name, avatar_url
+                SELECT d.user_id AS user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
-                LEFT JOIN users_in_pubic_room AS p USING (user_id)
-                LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_rooms
-                    WHERE user_id = ? AND share_private
-                ) AS s USING (user_id)
+                %s
                 WHERE
-                    (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+                    %s
                     AND value MATCH ?
                 ORDER BY
                     rank(matchinfo(user_directory_search)) DESC,
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """
-            args = (user_id, search_query, limit + 1)
+            """ % (join_clause, where_clause)
+            args = join_args + (search_query, limit + 1)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -723,7 +746,7 @@ def _parse_query_sqlite(search_term):
 
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
-    return " & ".join("(%s* | %s)" % (result, result,) for result in results)
+    return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
 
 
 def _parse_query_postgres(search_term):
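
After the operator change above, a two-word search term compiles to the MATCH expression shown below: each word may match as a prefix or exactly, and all words are required. A quick standalone check of the output shape (the helper reimplements the one-liner from the diff):

    import re

    def parse_query_sqlite(search_term):
        # Same logic as _parse_query_sqlite above.
        results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
        return " & ".join("(%s* OR %s)" % (result, result) for result in results)

    assert parse_query_sqlite("ali bob") == "(ali* OR ali) & (bob* OR bob)"
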
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
new file mode 100644
index 0000000000..be013f4427
--- /dev/null
+++ b/synapse/storage/user_erasure_store.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import operator
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedList
+
+
+class UserErasureWorkerStore(SQLBaseStore):
+    @cached()
+    def is_user_erased(self, user_id):
+        """
+        Check if the given user id has requested erasure
+
+        Args:
+            user_id (str): full user id to check
+
+        Returns:
+            Deferred[bool]: True if the user has requested erasure
+        """
+        return self._simple_select_onecol(
+            table="erased_users",
+            keyvalues={"user_id": user_id},
+            retcol="1",
+            desc="is_user_erased",
+        ).addCallback(operator.truth)
+
+    @cachedList(
+        cached_method_name="is_user_erased",
+        list_name="user_ids",
+        inlineCallbacks=True,
+    )
+    def are_users_erased(self, user_ids):
+        """
+        Checks which users in a list have requested erasure
+
+        Args:
+            user_ids (iterable[str]): full user ids to check
+
+        Returns:
+            Deferred[dict[str, bool]]:
+                for each user, whether the user has requested erasure.
+        """
+        # this serves the dual purpose of (a) making sure we can do len and
+        # iterate it multiple times, and (b) avoiding duplicates.
+        user_ids = tuple(set(user_ids))
+
+        def _get_erased_users(txn):
+            txn.execute(
+                "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
+                    ",".join("?" * len(user_ids))
+                ),
+                user_ids,
+            )
+            return set(r[0] for r in txn)
+
+        erased_users = yield self.runInteraction(
+            "are_users_erased", _get_erased_users,
+        )
+        res = dict((u, u in erased_users) for u in user_ids)
+        defer.returnValue(res)
+
+
+class UserErasureStore(UserErasureWorkerStore):
+    def mark_user_erased(self, user_id):
+        """Indicate that user_id wishes their message history to be erased.
+
+        Args:
+            user_id (str): full user_id to be erased
+        """
+        def f(txn):
+            # first check if they are already in the list
+            txn.execute(
+                "SELECT 1 FROM erased_users WHERE user_id = ?",
+                (user_id, )
+            )
+            if txn.fetchone():
+                return
+
+            # they are not already there: do the insert.
+            txn.execute(
+                "INSERT INTO erased_users (user_id) VALUES (?)",
+                (user_id, )
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.is_user_erased, (user_id,)
+            )
+        return self.runInteraction("mark_user_erased", f)
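
`are_users_erased` is the bulk entry point for consumers of the new store; a plausible (hypothetical, not part of this diff) caller in the same inlineCallbacks style, filtering out events whose sender has requested erasure:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def drop_erased_senders(store, events):
        # Hypothetical helper: omit events whose sender asked for erasure.
        erased = yield store.are_users_erased(e.sender for e in events)
        defer.returnValue([e for e in events if not erased[e.sender]])
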
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 95031dc9ec..d6160d5e4d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import deque
 import contextlib
 import threading
+from collections import deque
 
 
 class IdGenerator(object):