summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py15
-rw-r--r--synapse/storage/_base.py13
-rw-r--r--synapse/storage/deviceinbox.py35
-rw-r--r--synapse/storage/devices.py474
-rw-r--r--synapse/storage/end_to_end_keys.py91
-rw-r--r--synapse/storage/event_federation.py78
-rw-r--r--synapse/storage/events.py397
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/registration.py11
-rw-r--r--synapse/storage/roommember.py30
-rw-r--r--synapse/storage/schema/delta/40/current_state_idx.sql17
-rw-r--r--synapse/storage/schema/delta/40/device_inbox.sql21
-rw-r--r--synapse/storage/schema/delta/40/device_list_streams.sql59
-rw-r--r--synapse/storage/state.py62
-rw-r--r--synapse/storage/stream.py14
15 files changed, 1016 insertions, 303 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index fe936b3e62..b9968debe5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
+        self._device_list_id_gen = StreamIdGenerator(
+            db_conn, "device_lists_stream", "stream_id",
+        )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
         self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@@ -189,7 +192,8 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "device_inbox",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=max_device_inbox_id
+            max_value=max_device_inbox_id,
+            limit=1000,
         )
         self._device_inbox_stream_cache = StreamChangeCache(
             "DeviceInboxStreamChangeCache", min_device_inbox_id,
@@ -202,12 +206,21 @@ class DataStore(RoomMemberStore, RoomStore,
             entity_column="destination",
             stream_column="stream_id",
             max_value=max_device_inbox_id,
+            limit=1000,
         )
         self._device_federation_outbox_stream_cache = StreamChangeCache(
             "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
             prefilled_cache=device_outbox_prefill,
         )
 
+        device_list_max = self._device_list_id_gen.get_current_token()
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache", device_list_max,
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache", device_list_max,
+        )
+
         cur = LoggingTransaction(
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b62c459d8b..05374682fd 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -169,7 +169,7 @@ class SQLBaseStore(object):
                                       max_entries=hs.config.event_cache_size)
 
         self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
+            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
         )
 
         self._event_fetch_lock = threading.Condition()
@@ -387,6 +387,10 @@ class SQLBaseStore(object):
         Args:
             table : string giving the table name
             values : dict of new column names and values for them
+
+        Returns:
+            bool: Whether the row was inserted or not. Only useful when
+            `or_ignore` is True
         """
         try:
             yield self.runInteraction(
@@ -398,6 +402,8 @@ class SQLBaseStore(object):
             # a cursor after we receive an error from the db.
             if not or_ignore:
                 raise
+            defer.returnValue(False)
+        defer.returnValue(True)
 
     @staticmethod
     def _simple_insert_txn(txn, table, values):
@@ -838,18 +844,19 @@ class SQLBaseStore(object):
         return txn.execute(sql, keyvalues.values())
 
     def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
-                        max_value):
+                        max_value, limit=100000):
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
         # do the right thing to ensure it respects the max size of cache.
         sql = (
             "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - 100000"
+            " WHERE %(stream)s > ? - %(limit)s"
             " GROUP BY %(entity)s"
         ) % {
             "table": table,
             "entity": entity_column,
             "stream": stream_column,
+            "limit": limit,
         }
 
         sql = self.database_engine.convert_param_style(sql)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 2821eb89c9..bde3b5cbbc 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -18,13 +18,29 @@ import ujson
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore
+from .background_updates import BackgroundUpdateStore
 
 
 logger = logging.getLogger(__name__)
 
 
-class DeviceInboxStore(SQLBaseStore):
+class DeviceInboxStore(BackgroundUpdateStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, hs):
+        super(DeviceInboxStore, self).__init__(hs)
+
+        self.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID,
+            self._background_drop_index_device_inbox,
+        )
 
     @defer.inlineCallbacks
     def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
@@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore):
             "delete_device_msgs_for_remote",
             delete_messages_for_remote_destination_txn
         )
+
+    @defer.inlineCallbacks
+    def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute(
+                "DROP INDEX IF EXISTS device_inbox_stream_id"
+            )
+            txn.close()
+
+        yield self.runWithConnection(reindex_txn)
+
+        yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        defer.returnValue(1)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 17920d4480..8e17800364 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import ujson as json
 
 from twisted.internet import defer
 
@@ -23,27 +24,29 @@ logger = logging.getLogger(__name__)
 
 
 class DeviceStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(DeviceStore, self).__init__(hs)
+
+        self._clock.looping_call(
+            self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+        )
+
     @defer.inlineCallbacks
     def store_device(self, user_id, device_id,
-                     initial_device_display_name,
-                     ignore_if_known=True):
+                     initial_device_display_name):
         """Ensure the given device is known; add it to the store if not
 
         Args:
             user_id (str): id of user associated with the device
             device_id (str): id of device
             initial_device_display_name (str): initial displayname of the
-               device
-            ignore_if_known (bool): ignore integrity errors which mean the
-               device is already known
+               device. Ignored if device exists.
         Returns:
-            defer.Deferred
-        Raises:
-            StoreError: if ignore_if_known is False and the device was already
-               known
+            defer.Deferred: boolean whether the device was inserted or an
+                existing device existed with that ID.
         """
         try:
-            yield self._simple_insert(
+            inserted = yield self._simple_insert(
                 "devices",
                 values={
                     "user_id": user_id,
@@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore):
                     "display_name": initial_device_display_name
                 },
                 desc="store_device",
-                or_ignore=ignore_if_known,
+                or_ignore=True,
             )
+            defer.returnValue(inserted)
         except Exception as e:
             logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
                          " display_name=%s(%r) failed: %s",
@@ -139,3 +143,451 @@ class DeviceStore(SQLBaseStore):
         )
 
         defer.returnValue({d["device_id"]: d for d in devices})
+
+    def get_device_list_last_stream_id_for_remote(self, user_id):
+        """Get the last stream_id we got for a user. May be None if we haven't
+        got any information for them.
+        """
+        return self._simple_select_one_onecol(
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+            retcol="stream_id",
+            desc="get_device_list_remote_extremity",
+            allow_none=True,
+        )
+
+    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+        """Mark that we no longer track device lists for remote user.
+        """
+        return self._simple_delete(
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            desc="mark_remote_user_device_list_as_unsubscribed",
+        )
+
+    def update_remote_device_list_cache_entry(self, user_id, device_id, content,
+                                              stream_id):
+        """Updates a single user's device in the cache.
+        """
+        return self.runInteraction(
+            "update_remote_device_list_cache_entry",
+            self._update_remote_device_list_cache_entry_txn,
+            user_id, device_id, content, stream_id,
+        )
+
+    def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
+                                                   content, stream_id):
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            values={
+                "content": json.dumps(content),
+            }
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def update_remote_device_list_cache(self, user_id, devices, stream_id):
+        """Replace the cache of the remote user's devices.
+        """
+        return self.runInteraction(
+            "update_remote_device_list_cache",
+            self._update_remote_device_list_cache_txn,
+            user_id, devices, stream_id,
+        )
+
+    def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
+                                             stream_id):
+        self._simple_delete_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_remote_cache",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": content["device_id"],
+                    "content": json.dumps(content),
+                }
+                for content in devices
+            ]
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def get_devices_by_remote(self, destination, from_stream_id):
+        """Get stream of updates to send to remote servers
+
+        Returns:
+            (now_stream_id, [ { updates }, .. ])
+        """
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+            destination, int(from_stream_id)
+        )
+        if not has_changed:
+            return (now_stream_id, [])
+
+        return self.runInteraction(
+            "get_devices_by_remote", self._get_devices_by_remote_txn,
+            destination, from_stream_id, now_stream_id,
+        )
+
+    def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+                                   now_stream_id):
+        sql = """
+            SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+            WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+            GROUP BY user_id, device_id
+        """
+        txn.execute(
+            sql, (destination, from_stream_id, now_stream_id, False)
+        )
+        rows = txn.fetchall()
+
+        if not rows:
+            return (now_stream_id, [])
+
+        # maps (user_id, device_id) -> stream_id
+        query_map = {(r[0], r[1]): r[2] for r in rows}
+        devices = self._get_e2e_device_keys_txn(
+            txn, query_map.keys(), include_all_devices=True
+        )
+
+        prev_sent_id_sql = """
+            SELECT coalesce(max(stream_id), 0) as stream_id
+            FROM device_lists_outbound_pokes
+            WHERE destination = ? AND user_id = ? AND stream_id <= ?
+        """
+
+        results = []
+        for user_id, user_devices in devices.iteritems():
+            # The prev_id for the first row is always the last row before
+            # `from_stream_id`
+            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+            rows = txn.fetchall()
+            prev_id = rows[0][0]
+            for device_id, device in user_devices.iteritems():
+                stream_id = query_map[(user_id, device_id)]
+                result = {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "prev_id": [prev_id] if prev_id else [],
+                    "stream_id": stream_id,
+                }
+
+                prev_id = stream_id
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+        return (now_stream_id, results)
+
+    def get_user_devices_from_cache(self, query_list):
+        """Get the devices (and keys if any) for remote users from the cache.
+
+        Args:
+            query_list(list): List of (user_id, device_ids), if device_ids is
+                falsey then return all device ids for that user.
+
+        Returns:
+            (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+            a set of user_ids and results_map is a mapping of
+            user_id -> device_id -> device_info
+        """
+        return self.runInteraction(
+            "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
+            query_list,
+        )
+
+    def _get_user_devices_from_cache_txn(self, txn, query_list):
+        user_ids = {user_id for user_id, _ in query_list}
+
+        user_ids_in_cache = set()
+        for user_id in user_ids:
+            stream_ids = self._simple_select_onecol_txn(
+                txn,
+                table="device_lists_remote_extremeties",
+                keyvalues={
+                    "user_id": user_id,
+                },
+                retcol="stream_id",
+            )
+            if stream_ids:
+                user_ids_in_cache.add(user_id)
+
+        user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+        results = {}
+        for user_id, device_id in query_list:
+            if user_id not in user_ids_in_cache:
+                continue
+
+            if device_id:
+                content = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
+                    retcol="content",
+                )
+                results.setdefault(user_id, {})[device_id] = json.loads(content)
+            else:
+                devices = self._simple_select_list_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                    },
+                    retcols=("device_id", "content"),
+                )
+                results[user_id] = {
+                    device["device_id"]: json.loads(device["content"])
+                    for device in devices
+                }
+                user_ids_in_cache.discard(user_id)
+
+        return user_ids_not_in_cache, results
+
+    def get_devices_with_keys_by_user(self, user_id):
+        """Get all devices (with any device keys) for a user
+
+        Returns:
+            (stream_id, devices)
+        """
+        return self.runInteraction(
+            "get_devices_with_keys_by_user",
+            self._get_devices_with_keys_by_user_txn, user_id,
+        )
+
+    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, [(user_id, None)], include_all_devices=True
+        )
+
+        if devices:
+            user_devices = devices[user_id]
+            results = []
+            for device_id, device in user_devices.iteritems():
+                result = {
+                    "device_id": device_id,
+                }
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
+    def mark_as_sent_devices_by_remote(self, destination, stream_id):
+        """Mark that updates have successfully been sent to the destination.
+        """
+        return self.runInteraction(
+            "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+            destination, stream_id,
+        )
+
+    def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+        # First we DELETE all rows such that only the latest row for each
+        # (destination, user_id is left. We do this by selecting first and
+        # deleting.
+        sql = """
+            SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
+            WHERE destination = ? AND stream_id <= ?
+            GROUP BY user_id
+            HAVING count(*) > 1
+        """
+        txn.execute(sql, (destination, stream_id,))
+        rows = txn.fetchall()
+
+        sql = """
+            DELETE FROM device_lists_outbound_pokes
+            WHERE destination = ? AND user_id = ? AND stream_id < ?
+        """
+        txn.executemany(
+            sql, ((destination, row[0], row[1],) for row in rows)
+        )
+
+        # Mark everything that is left as sent
+        sql = """
+            UPDATE device_lists_outbound_pokes SET sent = ?
+            WHERE destination = ? AND stream_id <= ?
+        """
+        txn.execute(sql, (True, destination, stream_id,))
+
+    @defer.inlineCallbacks
+    def get_user_whose_devices_changed(self, from_key):
+        """Get set of users whose devices have changed since `from_key`.
+        """
+        from_key = int(from_key)
+        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+        if changed is not None:
+            defer.returnValue(set(changed))
+
+        sql = """
+            SELECT user_id FROM device_lists_stream WHERE stream_id > ?
+        """
+        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+        defer.returnValue(set(row[0] for row in rows))
+
+    def get_all_device_list_changes_for_remotes(self, from_key):
+        """Return a list of `(stream_id, user_id, destination)` which is the
+        combined list of changes to devices, and which destinations need to be
+        poked. `destination` may be None if no destinations need to be poked.
+        """
+        sql = """
+            SELECT stream_id, user_id, destination FROM device_lists_stream
+            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            WHERE stream_id > ?
+        """
+        return self._execute(
+            "get_users_and_hosts_device_list", None,
+            sql, from_key,
+        )
+
+    @defer.inlineCallbacks
+    def add_device_change_to_streams(self, user_id, device_ids, hosts):
+        """Persist that a user's devices have been updated, and which hosts
+        (if any) should be poked.
+        """
+        with self._device_list_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_device_change_to_streams", self._add_device_change_txn,
+                user_id, device_ids, hosts, stream_id,
+            )
+        defer.returnValue(stream_id)
+
+    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
+        now = self._clock.time_msec()
+
+        txn.call_after(
+            self._device_list_stream_cache.entity_has_changed,
+            user_id, stream_id,
+        )
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host, stream_id,
+            )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_stream",
+            values=[
+                {
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+                for device_id in device_ids
+            ]
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_outbound_pokes",
+            values=[
+                {
+                    "destination": destination,
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "sent": False,
+                    "ts": now,
+                }
+                for destination in hosts
+                for device_id in device_ids
+            ]
+        )
+
+    def get_device_stream_token(self):
+        return self._device_list_id_gen.get_current_token()
+
+    def _prune_old_outbound_device_pokes(self):
+        """Delete old entries out of the device_lists_outbound_pokes to ensure
+        that we don't fill up due to dead servers. We keep one entry per
+        (destination, user_id) tuple to ensure that the prev_ids remain correct
+        if the server does come back.
+        """
+        yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+
+        def _prune_txn(txn):
+            select_sql = """
+                SELECT destination, user_id, max(stream_id) as stream_id
+                FROM device_lists_outbound_pokes
+                GROUP BY destination, user_id
+                HAVING min(ts) < ? AND count(*) > 1
+            """
+
+            txn.execute(select_sql, (yesterday,))
+            rows = txn.fetchall()
+
+            if not rows:
+                return
+
+            delete_sql = """
+                DELETE FROM device_lists_outbound_pokes
+                WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+            """
+
+            txn.executemany(
+                delete_sql,
+                (
+                    (yesterday, row[0], row[1], row[2])
+                    for row in rows
+                )
+            )
+
+            logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+
+        return self.runInteraction(
+            "_prune_old_outbound_device_pokes", _prune_txn
+        )
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 385d607056..2040e022fa 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,74 +12,111 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections
+from twisted.internet import defer
 
-import twisted.internet.defer
+from canonicaljson import encode_canonical_json
+import ujson as json
 
 from ._base import SQLBaseStore
 
 
 class EndToEndKeyStore(SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
-        return self._simple_upsert(
-            table="e2e_device_keys_json",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            values={
-                "ts_added_ms": time_now,
-                "key_json": json_bytes,
-            }
+    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+        def _set_e2e_device_keys_txn(txn):
+            old_key_json = self._simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            new_key_json = encode_canonical_json(device_keys)
+            if old_key_json == new_key_json:
+                return False
+
+            self._simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "ts_added_ms": time_now,
+                    "key_json": new_key_json,
+                }
+            )
+
+            return True
+
+        return self.runInteraction(
+            "set_e2e_device_keys", _set_e2e_device_keys_txn
         )
 
-    def get_e2e_device_keys(self, query_list):
+    @defer.inlineCallbacks
+    def get_e2e_device_keys(self, query_list, include_all_devices=False):
         """Fetch a list of device keys.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
+            include_all_devices (bool): whether to include entries for devices
+                that don't have device keys
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
         """
         if not query_list:
-            return {}
+            defer.returnValue({})
 
-        return self.runInteraction(
-            "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+        results = yield self.runInteraction(
+            "get_e2e_device_keys", self._get_e2e_device_keys_txn,
+            query_list, include_all_devices,
         )
 
-    def _get_e2e_device_keys_txn(self, txn, query_list):
+        for user_id, device_keys in results.iteritems():
+            for device_id, device_info in device_keys.iteritems():
+                device_info["keys"] = json.loads(device_info.pop("key_json"))
+
+        defer.returnValue(results)
+
+    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
         query_clauses = []
         query_params = []
 
         for (user_id, device_id) in query_list:
-            query_clause = "k.user_id = ?"
+            query_clause = "user_id = ?"
             query_params.append(user_id)
 
             if device_id:
-                query_clause += " AND k.device_id = ?"
+                query_clause += " AND device_id = ?"
                 query_params.append(device_id)
 
             query_clauses.append(query_clause)
 
         sql = (
-            "SELECT k.user_id, k.device_id, "
+            "SELECT user_id, device_id, "
             "    d.display_name AS device_display_name, "
             "    k.key_json"
-            " FROM e2e_device_keys_json k"
-            "    LEFT JOIN devices d ON d.user_id = k.user_id"
-            "      AND d.device_id = k.device_id"
+            " FROM devices d"
+            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
             " WHERE %s"
         ) % (
+            "LEFT" if include_all_devices else "INNER",
             " OR ".join("(" + q + ")" for q in query_clauses)
         )
 
         txn.execute(sql, query_params)
         rows = self.cursor_to_dict(txn)
 
-        result = collections.defaultdict(dict)
+        result = {}
         for row in rows:
-            result[row["user_id"]][row["device_id"]] = row
+            result.setdefault(row["user_id"], {})[row["device_id"]] = row
 
         return result
 
@@ -152,7 +189,7 @@ class EndToEndKeyStore(SQLBaseStore):
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    @twisted.internet.defer.inlineCallbacks
+    @defer.inlineCallbacks
     def delete_e2e_keys_by_device(self, user_id, device_id):
         yield self._simple_delete(
             table="e2e_device_keys_json",
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 53feaa1960..ee88c61954 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )
 
-    @cached()
+    @cached(max_entries=5000, iterable=True)
     def get_latest_event_ids_in_room(self, room_id):
         return self._simple_select_onecol(
             table="event_forward_extremities",
@@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore):
             ],
         )
 
-        self._update_extremeties(txn, events)
+        self._update_backward_extremeties(txn, events)
 
-    def _update_extremeties(self, txn, events):
-        """Updates the event_*_extremities tables based on the new/updated
+    def _update_backward_extremeties(self, txn, events):
+        """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.
 
         This is called for new events *and* for events that were outliers, but
-        are are now being persisted as non-outliers.
+        are now being persisted as non-outliers.
+
+        Forward extremities are handled when we first start persisting the events.
         """
         events_by_room = {}
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
-        for room_id, room_events in events_by_room.items():
-            prevs = [
-                e_id for ev in room_events for e_id, _ in ev.prev_events
-                if not ev.internal_metadata.is_outlier()
-            ]
-            if prevs:
-                txn.execute(
-                    "DELETE FROM event_forward_extremities"
-                    " WHERE room_id = ?"
-                    " AND event_id in (%s)" % (
-                        ",".join(["?"] * len(prevs)),
-                    ),
-                    [room_id] + prevs,
-                )
-
-        query = (
-            "INSERT INTO event_forward_extremities (event_id, room_id)"
-            " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_edges WHERE prev_event_id = ?"
-            " )"
-        )
-
-        txn.executemany(
-            query,
-            [
-                (ev.event_id, ev.room_id, ev.event_id) for ev in events
-                if not ev.internal_metadata.is_outlier()
-            ]
-        )
-
-        # We now insert into stream_ordering_to_exterm a mapping from room_id,
-        # new stream_ordering to new forward extremeties in the room.
-        # This allows us to later efficiently look up the forward extremeties
-        # for a room before a given stream_ordering
-        max_stream_ord = max(
-            ev.internal_metadata.stream_ordering for ev in events
-        )
-        new_extrem = {}
-        for room_id in events_by_room:
-            event_ids = self._simple_select_onecol_txn(
-                txn,
-                table="event_forward_extremities",
-                keyvalues={"room_id": room_id},
-                retcol="event_id",
-            )
-            new_extrem[room_id] = event_ids
-
-        self._simple_insert_many_txn(
-            txn,
-            table="stream_ordering_to_exterm",
-            values=[
-                {
-                    "room_id": room_id,
-                    "event_id": event_id,
-                    "stream_ordering": max_stream_ord,
-                }
-                for room_id, extrem_evs in new_extrem.items()
-                for event_id in extrem_evs
-            ]
-        )
-
         query = (
             "INSERT INTO event_backward_extremities (event_id, room_id)"
             " SELECT ?, ? WHERE NOT EXISTS ("
@@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore):
             ]
         )
 
-        for room_id in events_by_room:
-            txn.call_after(
-                self.get_latest_event_ids_in_room.invalidate, (room_id,)
-            )
-
     def get_forward_extremeties_for_room(self, room_id, stream_ordering):
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 04dbdac3f8..8659f605a5 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import SQLBaseStore, _RollbackButIsFineException
+from ._base import SQLBaseStore
 
 from twisted.internet import defer, reactor
 
@@ -27,6 +27,8 @@ from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
+from synapse.state import resolve_events
+from synapse.util.caches.descriptors import cached
 
 from canonicaljson import encode_canonical_json
 from collections import deque, namedtuple, OrderedDict
@@ -71,22 +73,19 @@ class _EventPeristenceQueue(object):
     """
 
     _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
-        "events_and_contexts", "current_state", "backfilled", "deferred",
+        "events_and_contexts", "backfilled", "deferred",
     ))
 
     def __init__(self):
         self._event_persist_queues = {}
         self._currently_persisting_rooms = set()
 
-    def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
+    def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
             end_item = queue[-1]
-            if end_item.current_state or current_state:
-                # We perist events with current_state set to True one at a time
-                pass
             if end_item.backfilled == backfilled:
                 end_item.events_and_contexts.extend(events_and_contexts)
                 return end_item.deferred.observe()
@@ -96,7 +95,6 @@ class _EventPeristenceQueue(object):
         queue.append(self._EventPersistQueueItem(
             events_and_contexts=events_and_contexts,
             backfilled=backfilled,
-            current_state=current_state,
             deferred=deferred,
         ))
 
@@ -216,7 +214,6 @@ class EventsStore(SQLBaseStore):
             d = preserve_fn(self._event_persist_queue.add_to_queue)(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
-                current_state=None,
             )
             deferreds.append(d)
 
@@ -229,11 +226,10 @@ class EventsStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, context, current_state=None, backfilled=False):
+    def persist_event(self, event, context, backfilled=False):
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)],
             backfilled=backfilled,
-            current_state=current_state,
         )
 
         self._maybe_start_persisting(event.room_id)
@@ -246,21 +242,10 @@ class EventsStore(SQLBaseStore):
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
         def persisting_queue(item):
-            if item.current_state:
-                for event, context in item.events_and_contexts:
-                    # There should only ever be one item in
-                    # events_and_contexts when current_state is
-                    # not None
-                    yield self._persist_event(
-                        event, context,
-                        current_state=item.current_state,
-                        backfilled=item.backfilled,
-                    )
-            else:
-                yield self._persist_events(
-                    item.events_and_contexts,
-                    backfilled=item.backfilled,
-                )
+            yield self._persist_events(
+                item.events_and_contexts,
+                backfilled=item.backfilled,
+            )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
@@ -294,35 +279,183 @@ class EventsStore(SQLBaseStore):
             for chunk in chunks:
                 # We can't easily parallelize these since different chunks
                 # might contain the same event. :(
+
+                # NB: Assumes that we are only persisting events for one room
+                # at a time.
+                new_forward_extremeties = {}
+                current_state_for_room = {}
+                if not backfilled:
+                    with Measure(self._clock, "_calculate_state_and_extrem"):
+                        # Work out the new "current state" for each room.
+                        # We do this by working out what the new extremities are and then
+                        # calculating the state from that.
+                        events_by_room = {}
+                        for event, context in chunk:
+                            events_by_room.setdefault(event.room_id, []).append(
+                                (event, context)
+                            )
+
+                        for room_id, ev_ctx_rm in events_by_room.items():
+                            # Work out new extremities by recursively adding and removing
+                            # the new events.
+                            latest_event_ids = yield self.get_latest_event_ids_in_room(
+                                room_id
+                            )
+                            new_latest_event_ids = yield self._calculate_new_extremeties(
+                                room_id, [ev for ev, _ in ev_ctx_rm]
+                            )
+
+                            if new_latest_event_ids == set(latest_event_ids):
+                                # No change in extremities, so no change in state
+                                continue
+
+                            new_forward_extremeties[room_id] = new_latest_event_ids
+
+                            state = yield self._calculate_state_delta(
+                                room_id, ev_ctx_rm, new_latest_event_ids
+                            )
+                            if state:
+                                current_state_for_room[room_id] = state
+
                 yield self.runInteraction(
                     "persist_events",
                     self._persist_events_txn,
                     events_and_contexts=chunk,
                     backfilled=backfilled,
                     delete_existing=delete_existing,
+                    current_state_for_room=current_state_for_room,
+                    new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))
 
-    @_retry_on_integrity_error
     @defer.inlineCallbacks
-    @log_function
-    def _persist_event(self, event, context, current_state=None, backfilled=False,
-                       delete_existing=False):
-        try:
-            with self._stream_id_gen.get_next() as stream_ordering:
-                event.internal_metadata.stream_ordering = stream_ordering
-                yield self.runInteraction(
-                    "persist_event",
-                    self._persist_event_txn,
-                    event=event,
-                    context=context,
-                    current_state=current_state,
-                    backfilled=backfilled,
-                    delete_existing=delete_existing,
-                )
-                persist_event_counter.inc()
-        except _RollbackButIsFineException:
-            pass
+    def _calculate_new_extremeties(self, room_id, events):
+        """Calculates the new forward extremeties for a room given events to
+        persist.
+
+        Assumes that we are only persisting events for one room at a time.
+        """
+        latest_event_ids = yield self.get_latest_event_ids_in_room(
+            room_id
+        )
+        new_latest_event_ids = set(latest_event_ids)
+        # First, add all the new events to the list
+        new_latest_event_ids.update(
+            event.event_id for event in events
+            if not event.internal_metadata.is_outlier()
+        )
+        # Now remove all events that are referenced by the to-be-added events
+        new_latest_event_ids.difference_update(
+            e_id
+            for event in events
+            for e_id, _ in event.prev_events
+            if not event.internal_metadata.is_outlier()
+        )
+
+        # And finally remove any events that are referenced by previously added
+        # events.
+        rows = yield self._simple_select_many_batch(
+            table="event_edges",
+            column="prev_event_id",
+            iterable=list(new_latest_event_ids),
+            retcols=["prev_event_id"],
+            keyvalues={
+                "room_id": room_id,
+                "is_state": False,
+            },
+            desc="_calculate_new_extremeties",
+        )
+
+        new_latest_event_ids.difference_update(
+            row["prev_event_id"] for row in rows
+        )
+
+        defer.returnValue(new_latest_event_ids)
+
+    @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            2-tuple (to_delete, to_insert) where both are state dicts, i.e.
+            (type, state_key) -> event_id. `to_delete` are the entries to
+            first be deleted from current_state_events, `to_insert` are entries
+            to insert.
+            May return None if there are no changes to be applied.
+        """
+        # Now we need to work out the different state sets for
+        # each state extremities
+        state_sets = []
+        missing_event_ids = []
+        was_updated = False
+        for event_id in new_latest_event_ids:
+            # First search in the list of new events we're adding,
+            # and then use the current state from that
+            for ev, ctx in events_context:
+                if event_id == ev.event_id:
+                    if ctx.current_state_ids is None:
+                        raise Exception("Unknown current state")
+                    state_sets.append(ctx.current_state_ids)
+                    if ctx.delta_ids or hasattr(ev, "state_key"):
+                        was_updated = True
+                    break
+            else:
+                # If we couldn't find it, then we'll need to pull
+                # the state from the database
+                was_updated = True
+                missing_event_ids.append(event_id)
+
+        if missing_event_ids:
+            # Now pull out the state for any missing events from DB
+            event_to_groups = yield self._get_state_group_for_events(
+                missing_event_ids,
+            )
+
+            groups = set(event_to_groups.values())
+            group_to_state = yield self._get_state_for_groups(groups)
+
+            state_sets.extend(group_to_state.values())
+
+        if not new_latest_event_ids:
+            current_state = {}
+        elif was_updated:
+            current_state = yield resolve_events(
+                state_sets,
+                state_map_factory=lambda ev_ids: self.get_events(
+                    ev_ids, get_prev_content=False, check_redacted=False,
+                ),
+            )
+        else:
+            return
+
+        existing_state_rows = yield self._simple_select_list(
+            table="current_state_events",
+            keyvalues={"room_id": room_id},
+            retcols=["event_id", "type", "state_key"],
+            desc="_calculate_state_delta",
+        )
+
+        existing_events = set(row["event_id"] for row in existing_state_rows)
+        new_events = set(ev_id for ev_id in current_state.itervalues())
+        changed_events = existing_events ^ new_events
+
+        if not changed_events:
+            return
+
+        to_delete = {
+            (row["type"], row["state_key"]): row["event_id"]
+            for row in existing_state_rows
+            if row["event_id"] in changed_events
+        }
+        events_to_insert = (new_events - existing_events)
+        to_insert = {
+            key: ev_id for key, ev_id in current_state.iteritems()
+            if ev_id in events_to_insert
+        }
+
+        defer.returnValue((to_delete, to_insert))
 
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
@@ -381,52 +514,9 @@ class EventsStore(SQLBaseStore):
         defer.returnValue({e.event_id: e for e in events})
 
     @log_function
-    def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
-                           delete_existing=False):
-        # We purposefully do this first since if we include a `current_state`
-        # key, we *want* to update the `current_state_events` table
-        if current_state:
-            txn.call_after(self._get_current_state_for_key.invalidate_all)
-            txn.call_after(self.get_rooms_for_user.invalidate_all)
-            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
-
-            # Add an entry to the current_state_resets table to record the point
-            # where we clobbered the current state
-            stream_order = event.internal_metadata.stream_ordering
-            self._simple_insert_txn(
-                txn,
-                table="current_state_resets",
-                values={"event_stream_ordering": stream_order}
-            )
-
-            self._simple_delete_txn(
-                txn,
-                table="current_state_events",
-                keyvalues={"room_id": event.room_id},
-            )
-
-            for s in current_state:
-                self._simple_insert_txn(
-                    txn,
-                    "current_state_events",
-                    {
-                        "event_id": s.event_id,
-                        "room_id": s.room_id,
-                        "type": s.type,
-                        "state_key": s.state_key,
-                    }
-                )
-
-        return self._persist_events_txn(
-            txn,
-            [(event, context)],
-            backfilled=backfilled,
-            delete_existing=delete_existing,
-        )
-
-    @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False):
+                            delete_existing=False, current_state_for_room={},
+                            new_forward_extremeties={}):
         """Insert some number of room events into the necessary database tables.
 
         Rejected events are only inserted into the events table, the events_json table,
@@ -436,6 +526,93 @@ class EventsStore(SQLBaseStore):
         If delete_existing is True then existing events will be purged from the
         database before insertion. This is useful when retrying due to IntegrityError.
         """
+        max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+        for room_id, current_state_tuple in current_state_for_room.iteritems():
+                to_delete, to_insert = current_state_tuple
+                txn.executemany(
+                    "DELETE FROM current_state_events WHERE event_id = ?",
+                    [(ev_id,) for ev_id in to_delete.itervalues()],
+                )
+
+                self._simple_insert_many_txn(
+                    txn,
+                    table="current_state_events",
+                    values=[
+                        {
+                            "event_id": ev_id,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                        }
+                        for key, ev_id in to_insert.iteritems()
+                    ],
+                )
+
+                # Invalidate the various caches
+
+                # Figure out the changes of membership to invalidate the
+                # `get_rooms_for_user` cache.
+                # We find out which membership events we may have deleted
+                # and which we have added, then we invlidate the caches for all
+                # those users.
+                members_changed = set(
+                    state_key for ev_type, state_key in to_delete.iterkeys()
+                    if ev_type == EventTypes.Member
+                )
+                members_changed.update(
+                    state_key for ev_type, state_key in to_insert.iterkeys()
+                    if ev_type == EventTypes.Member
+                )
+
+                for member in members_changed:
+                    self._invalidate_cache_and_stream(
+                        txn, self.get_rooms_for_user, (member,)
+                    )
+
+                self._invalidate_cache_and_stream(
+                    txn, self.get_users_in_room, (room_id,)
+                )
+
+        for room_id, new_extrem in new_forward_extremeties.items():
+            self._simple_delete_txn(
+                txn,
+                table="event_forward_extremities",
+                keyvalues={"room_id": room_id},
+            )
+            txn.call_after(
+                self.get_latest_event_ids_in_room.invalidate, (room_id,)
+            )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="event_forward_extremities",
+            values=[
+                {
+                    "event_id": ev_id,
+                    "room_id": room_id,
+                }
+                for room_id, new_extrem in new_forward_extremeties.items()
+                for ev_id in new_extrem
+            ],
+        )
+        # We now insert into stream_ordering_to_exterm a mapping from room_id,
+        # new stream_ordering to new forward extremeties in the room.
+        # This allows us to later efficiently look up the forward extremeties
+        # for a room before a given stream_ordering
+        self._simple_insert_many_txn(
+            txn,
+            table="stream_ordering_to_exterm",
+            values=[
+                {
+                    "room_id": room_id,
+                    "event_id": event_id,
+                    "stream_ordering": max_stream_order,
+                }
+                for room_id, new_extrem in new_forward_extremeties.items()
+                for event_id in new_extrem
+            ]
+        )
+
         # Ensure that we don't have the same event twice.
         # Pick the earliest non-outlier if there is one, else the earliest one.
         new_events_and_contexts = OrderedDict()
@@ -550,7 +727,7 @@ class EventsStore(SQLBaseStore):
 
                 # Update the event_backward_extremities table now that this
                 # event isn't an outlier any more.
-                self._update_extremeties(txn, [event])
+                self._update_backward_extremeties(txn, [event])
 
         events_and_contexts = [
             ec for ec in events_and_contexts if ec[0] not in to_remove
@@ -798,29 +975,6 @@ class EventsStore(SQLBaseStore):
             # to update the current state table
             return
 
-        for event, _ in state_events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                # Outlier events shouldn't clobber the current state.
-                continue
-
-            txn.call_after(
-                self._get_current_state_for_key.invalidate,
-                (event.room_id, event.type, event.state_key,)
-            )
-
-            self._simple_upsert_txn(
-                txn,
-                "current_state_events",
-                keyvalues={
-                    "room_id": event.room_id,
-                    "type": event.type,
-                    "state_key": event.state_key,
-                },
-                values={
-                    "event_id": event.event_id,
-                }
-            )
-
         return
 
     def _add_to_cache(self, txn, events_and_contexts):
@@ -1084,10 +1238,10 @@ class EventsStore(SQLBaseStore):
                     self._do_fetch
                 )
 
-        logger.info("Loading %d events", len(events))
+        logger.debug("Loading %d events", len(events))
         with PreserveLoggingContext():
             rows = yield events_d
-        logger.info("Loaded %d events (%d rows)", len(events), len(rows))
+        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
 
         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]
@@ -1418,6 +1572,7 @@ class EventsStore(SQLBaseStore):
         """The current minimum token that backfilled events have reached"""
         return -self._backfill_id_gen.get_current_token()
 
+    @cached(num_args=5, max_entries=10)
     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):
         """Get all the new events that have arrived at the server either as
@@ -1450,15 +1605,6 @@ class EventsStore(SQLBaseStore):
                     upper_bound = current_forward_id
 
                 sql = (
-                    "SELECT event_stream_ordering FROM current_state_resets"
-                    " WHERE ? < event_stream_ordering"
-                    " AND event_stream_ordering <= ?"
-                    " ORDER BY event_stream_ordering ASC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                state_resets = txn.fetchall()
-
-                sql = (
                     "SELECT event_stream_ordering, event_id, state_group"
                     " FROM ex_outlier_stream"
                     " WHERE ? > event_stream_ordering"
@@ -1469,7 +1615,6 @@ class EventsStore(SQLBaseStore):
                 forward_ex_outliers = txn.fetchall()
             else:
                 new_forward_events = []
-                state_resets = []
                 forward_ex_outliers = []
 
             sql = (
@@ -1509,7 +1654,6 @@ class EventsStore(SQLBaseStore):
             return AllNewEventsResult(
                 new_forward_events, new_backfill_events,
                 forward_ex_outliers, backward_ex_outliers,
-                state_resets,
             )
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
@@ -1735,5 +1879,4 @@ class EventsStore(SQLBaseStore):
 AllNewEventsResult = namedtuple("AllNewEventsResult", [
     "new_forward_events", "new_backfill_events",
     "forward_ex_outliers", "backward_ex_outliers",
-    "state_resets"
 ])
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e46ae6502e..b357f22be7 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 39
+SCHEMA_VERSION = 40
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 983a8ec52b..26be6060c3 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -413,6 +413,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             desc="user_delete_threepids",
         )
 
+    def user_delete_threepid(self, user_id, medium, address):
+        return self._simple_delete(
+            "user_threepids",
+            keyvalues={
+                "user_id": user_id,
+                "medium": medium,
+                "address": address,
+            },
+            desc="user_delete_threepids",
+        )
+
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 5d18037c7c..545d3d3a99 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore):
         )
 
         for event in events:
-            txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
-            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(
                 self._membership_stream_cache.entity_has_changed,
                 event.state_key, event.internal_metadata.stream_ordering
@@ -131,7 +129,7 @@ class RoomMemberStore(SQLBaseStore):
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.runInteraction("locally_reject_invite", f, stream_ordering)
 
-    @cached(max_entries=5000)
+    @cached(max_entries=500000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
 
@@ -220,7 +218,7 @@ class RoomMemberStore(SQLBaseStore):
                 " ON e.event_id = c.event_id"
                 " AND m.room_id = c.room_id"
                 " AND m.user_id = c.state_key"
-                " WHERE %s"
+                " WHERE c.type = 'm.room.member' AND %s"
             ) % (where_clause,)
 
             txn.execute(sql, args)
@@ -266,7 +264,7 @@ class RoomMemberStore(SQLBaseStore):
             " ON m.event_id = c.event_id "
             " AND m.room_id = c.room_id "
             " AND m.user_id = c.state_key"
-            " WHERE %(where)s"
+            " WHERE c.type = 'm.room.member' AND %(where)s"
         ) % {
             "where": where_clause,
         }
@@ -276,12 +274,29 @@ class RoomMemberStore(SQLBaseStore):
 
         return rows
 
-    @cached(max_entries=5000)
+    @cached(max_entries=500000, iterable=True)
     def get_rooms_for_user(self, user_id):
         return self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
 
+    @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
+    def get_users_who_share_room_with_user(self, user_id, cache_context):
+        """Returns the set of users who share a room with `user_id`
+        """
+        rooms = yield self.get_rooms_for_user(
+            user_id, on_invalidate=cache_context.invalidate,
+        )
+
+        user_who_share_room = set()
+        for room in rooms:
+            user_ids = yield self.get_users_in_room(
+                room.room_id, on_invalidate=cache_context.invalidate,
+            )
+            user_who_share_room.update(user_ids)
+
+        defer.returnValue(user_who_share_room)
+
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
         def f(txn):
@@ -390,7 +405,8 @@ class RoomMemberStore(SQLBaseStore):
             room_id, state_group, state_ids,
         )
 
-    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
+                           max_entries=100000)
     def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
                                        cache_context, event=None):
         # We don't use `state_group`, it's there so that we can cache based
diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/schema/delta/40/current_state_idx.sql
new file mode 100644
index 0000000000..7ffa189f39
--- /dev/null
+++ b/synapse/storage/schema/delta/40/current_state_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('current_state_members_idx', '{}');
diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql
new file mode 100644
index 0000000000..b9fe1f0480
--- /dev/null
+++ b/synapse/storage/schema/delta/40/device_inbox.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- turn the pre-fill startup query into a index-only scan on postgresql.
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('device_inbox_stream_index', '{}');
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+    VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index');
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql
new file mode 100644
index 0000000000..54841b3843
--- /dev/null
+++ b/synapse/storage/schema/delta/40/device_list_streams.sql
@@ -0,0 +1,59 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Cache of remote devices.
+CREATE TABLE device_lists_remote_cache (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    content TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
+
+
+-- The last update we got for a user. Empty if we're not receiving updates for
+-- that user.
+CREATE TABLE device_lists_remote_extremeties (
+    user_id TEXT NOT NULL,
+    stream_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
+
+
+-- Stream of device lists updates. Includes both local and remotes
+CREATE TABLE device_lists_stream (
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
+
+
+-- The stream of updates to send to other servers. We keep at least one row
+-- per user that was sent so that the prev_id for any new updates can be
+-- calculated
+CREATE TABLE device_lists_outbound_pokes (
+    destination TEXT NOT NULL,
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    sent BOOLEAN NOT NULL,
+    ts BIGINT NOT NULL  -- So that in future we can clear out pokes to dead servers
+);
+
+CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
+CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7f466c40ac..1b3800eb6a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -49,6 +49,7 @@ class StateStore(SQLBaseStore):
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
     def __init__(self, hs):
         super(StateStore, self).__init__(hs)
@@ -60,6 +61,13 @@ class StateStore(SQLBaseStore):
             self.STATE_GROUP_INDEX_UPDATE_NAME,
             self._background_index_state,
         )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
 
     @defer.inlineCallbacks
     def get_state_groups_ids(self, room_id, event_ids):
@@ -232,59 +240,7 @@ class StateStore(SQLBaseStore):
 
             return count
 
-    @defer.inlineCallbacks
-    def get_current_state(self, room_id, event_type=None, state_key=""):
-        if event_type and state_key is not None:
-            result = yield self.get_current_state_for_key(
-                room_id, event_type, state_key
-            )
-            defer.returnValue(result)
-
-        def f(txn):
-            sql = (
-                "SELECT event_id FROM current_state_events"
-                " WHERE room_id = ? "
-            )
-
-            if event_type and state_key is not None:
-                sql += " AND type = ? AND state_key = ? "
-                args = (room_id, event_type, state_key)
-            elif event_type:
-                sql += " AND type = ?"
-                args = (room_id, event_type)
-            else:
-                args = (room_id, )
-
-            txn.execute(sql, args)
-            results = txn.fetchall()
-
-            return [r[0] for r in results]
-
-        event_ids = yield self.runInteraction("get_current_state", f)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
-
-    @defer.inlineCallbacks
-    def get_current_state_for_key(self, room_id, event_type, state_key):
-        event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
-
-    @cached(num_args=3)
-    def _get_current_state_for_key(self, room_id, event_type, state_key):
-        def f(txn):
-            sql = (
-                "SELECT event_id FROM current_state_events"
-                " WHERE room_id = ? AND type = ? AND state_key = ?"
-            )
-
-            args = (room_id, event_type, state_key)
-            txn.execute(sql, args)
-            results = txn.fetchall()
-            return [r[0] for r in results]
-        return self.runInteraction("get_current_state_for_key", f)
-
-    @cached(num_args=2, max_entries=1000)
+    @cached(num_args=2, max_entries=100000, iterable=True)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()
 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 2dc24951c4..200d124632 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def get_rooms_that_changed(self, room_ids, from_key):
+        """Given a list of rooms and a token, return rooms where there may have
+        been changes.
+
+        Args:
+            room_ids (list)
+            from_key (str): The room_key portion of a StreamToken
+        """
+        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        return set(
+            room_id for room_id in room_ids
+            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+        )
+
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
                                         order='DESC'):