summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py87
-rw-r--r--synapse/storage/_base.py246
-rw-r--r--synapse/storage/account_data.py4
-rw-r--r--synapse/storage/background_updates.py15
-rw-r--r--synapse/storage/deviceinbox.py52
-rw-r--r--synapse/storage/devices.py539
-rw-r--r--synapse/storage/end_to_end_keys.py156
-rw-r--r--synapse/storage/event_federation.py56
-rw-r--r--synapse/storage/event_push_actions.py348
-rw-r--r--synapse/storage/events.py634
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/storage/prepare_database.py4
-rw-r--r--synapse/storage/presence.py14
-rw-r--r--synapse/storage/receipts.py6
-rw-r--r--synapse/storage/registration.py2
-rw-r--r--synapse/storage/room.py4
-rw-r--r--synapse/storage/roommember.py83
-rw-r--r--synapse/storage/schema/delta/40/current_state_idx.sql17
-rw-r--r--synapse/storage/schema/delta/40/device_list_streams.sql59
-rw-r--r--synapse/storage/schema/delta/40/event_push_summary.sql37
-rw-r--r--synapse/storage/schema/delta/40/pushers.sql39
-rw-r--r--synapse/storage/schema/delta/41/device_list_stream_idx.sql17
-rw-r--r--synapse/storage/schema/delta/41/device_outbound_index.sql16
-rw-r--r--synapse/storage/signatures.py2
-rw-r--r--synapse/storage/state.py164
-rw-r--r--synapse/storage/stream.py17
-rw-r--r--synapse/storage/tags.py4
-rw-r--r--synapse/storage/util/id_generators.py14
28 files changed, 2120 insertions, 521 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e8495f1eb9..d604e7668f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
+        self._device_list_id_gen = StreamIdGenerator(
+            db_conn, "device_lists_stream", "stream_id",
+        )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
         self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=device_outbox_prefill,
         )
 
+        device_list_max = self._device_list_id_gen.get_current_token()
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache", device_list_max,
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache", device_list_max,
+        )
+
         cur = LoggingTransaction(
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",
@@ -286,6 +297,82 @@ class DataStore(RoomMemberStore, RoomStore,
             desc="get_user_ip_and_agents",
         )
 
+    def get_users(self):
+        """Function to reterive a list of users in users table.
+
+        Args:
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self._simple_select_list(
+            table="users",
+            keyvalues={},
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin"
+            ],
+            desc="get_users",
+        )
+
+    def get_users_paginate(self, order, start, limit):
+        """Function to reterive a paginated list of users from
+        users list. This will return a json object, which contains
+        list of users and the total number of users in users table.
+
+        Args:
+            order (str): column name to order the select by this column
+            start (int): start number to begin the query from
+            limit (int): number of rows to reterive
+        Returns:
+            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+        """
+        is_guest = 0
+        i_start = (int)(start)
+        i_limit = (int)(limit)
+        return self.get_user_list_paginate(
+            table="users",
+            keyvalues={
+                "is_guest": is_guest
+            },
+            pagevalues=[
+                order,
+                i_limit,
+                i_start
+            ],
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin"
+            ],
+            desc="get_users_paginate",
+        )
+
+    def search_users(self, term):
+        """Function to search users list for one or more users with
+        the matched term.
+
+        Args:
+            term (str): search term
+            col (str): column to query term should be matched to
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self._simple_search_list(
+            table="users",
+            term=term,
+            col="name",
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin"
+            ],
+            desc="search_users",
+        )
+
 
 def are_all_users_on_domain(txn, database_engine, domain):
     sql = database_engine.convert_param_style(
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 963ef999d5..c659004e8d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,7 +18,6 @@ from synapse.api.errors import StoreError
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
-from synapse.util.caches import intern_dict
 from synapse.storage.engines import PostgresEngine
 import synapse.metrics
 
@@ -74,13 +73,22 @@ class LoggingTransaction(object):
     def __setattr__(self, name, value):
         setattr(self.txn, name, value)
 
+    def __iter__(self):
+        return self.txn.__iter__()
+
     def execute(self, sql, *args):
         self._do_execute(self.txn.execute, sql, *args)
 
     def executemany(self, sql, *args):
         self._do_execute(self.txn.executemany, sql, *args)
 
+    def _make_sql_one_line(self, sql):
+        "Strip newlines out of SQL so that the loggers in the DB are on one line"
+        return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
     def _do_execute(self, func, sql, *args):
+        sql = self._make_sql_one_line(sql)
+
         # TODO(paul): Maybe use 'info' and 'debug' for values?
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)
 
@@ -127,7 +135,7 @@ class PerformanceCounters(object):
 
     def interval(self, interval_duration, limit=3):
         counters = []
-        for name, (count, cum_time) in self.current_counters.items():
+        for name, (count, cum_time) in self.current_counters.iteritems():
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
             counters.append((
                 (cum_time - prev_time) / interval_duration,
@@ -350,9 +358,9 @@ class SQLBaseStore(object):
         Returns:
             A list of dicts where the key is the column header.
         """
-        col_headers = list(column[0] for column in cursor.description)
+        col_headers = list(intern(column[0]) for column in cursor.description)
         results = list(
-            intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall()
+            dict(zip(col_headers, row)) for row in cursor
         )
         return results
 
@@ -387,6 +395,10 @@ class SQLBaseStore(object):
         Args:
             table : string giving the table name
             values : dict of new column names and values for them
+
+        Returns:
+            bool: Whether the row was inserted or not. Only useful when
+            `or_ignore` is True
         """
         try:
             yield self.runInteraction(
@@ -398,6 +410,8 @@ class SQLBaseStore(object):
             # a cursor after we receive an error from the db.
             if not or_ignore:
                 raise
+            defer.returnValue(False)
+        defer.returnValue(True)
 
     @staticmethod
     def _simple_insert_txn(txn, table, values):
@@ -477,10 +491,6 @@ class SQLBaseStore(object):
             " AND ".join("%s = ?" % (k,) for k in keyvalues)
         )
         sqlargs = values.values() + keyvalues.values()
-        logger.debug(
-            "[SQL] %s Args=%s",
-            sql, sqlargs,
-        )
 
         txn.execute(sql, sqlargs)
         if txn.rowcount == 0:
@@ -495,10 +505,6 @@ class SQLBaseStore(object):
                 ", ".join(k for k in allvalues),
                 ", ".join("?" for _ in allvalues)
             )
-            logger.debug(
-                "[SQL] %s Args=%s",
-                sql, keyvalues.values(),
-            )
             txn.execute(sql, allvalues.values())
 
             return True
@@ -562,7 +568,7 @@ class SQLBaseStore(object):
     @staticmethod
     def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
         if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
         else:
             where = ""
 
@@ -576,7 +582,7 @@ class SQLBaseStore(object):
 
         txn.execute(sql, keyvalues.values())
 
-        return [r[0] for r in txn.fetchall()]
+        return [r[0] for r in txn]
 
     def _simple_select_onecol(self, table, keyvalues, retcol,
                               desc="_simple_select_onecol"):
@@ -709,7 +715,7 @@ class SQLBaseStore(object):
         )
         values.extend(iterable)
 
-        for key, value in keyvalues.items():
+        for key, value in keyvalues.iteritems():
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -750,7 +756,7 @@ class SQLBaseStore(object):
     @staticmethod
     def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
         if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
         else:
             where = ""
 
@@ -837,6 +843,47 @@ class SQLBaseStore(object):
 
         return txn.execute(sql, keyvalues.values())
 
+    def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
+        return self.runInteraction(
+            desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
+        )
+
+    @staticmethod
+    def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+        """Executes a DELETE query on the named table.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+        """
+        if not iterable:
+            return
+
+        sql = "DELETE FROM %s" % table
+
+        clauses = []
+        values = []
+        clauses.append(
+            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+        )
+        values.extend(iterable)
+
+        for key, value in keyvalues.iteritems():
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (
+                sql,
+                " AND ".join(clauses),
+            )
+        return txn.execute(sql, values)
+
     def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
                         max_value, limit=100000):
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -857,16 +904,16 @@ class SQLBaseStore(object):
 
         txn = db_conn.cursor()
         txn.execute(sql, (int(max_value),))
-        rows = txn.fetchall()
-        txn.close()
 
         cache = {
             row[0]: int(row[1])
-            for row in rows
+            for row in txn
         }
 
+        txn.close()
+
         if cache:
-            min_val = min(cache.values())
+            min_val = min(cache.itervalues())
         else:
             min_val = max_value
 
@@ -928,6 +975,165 @@ class SQLBaseStore(object):
         else:
             return 0
 
+    def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
+                                     desc="_simple_select_list_paginate"):
+        """Executes a SELECT query on the named table with start and limit,
+        of row numbers, which may return zero or number of rows from start to limit,
+        returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            keyvalues (dict[str, Any] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            retcols (iterable[str]): the names of the columns to return
+            order (str): order the select by this column
+            start (int): start number to begin the query from
+            limit (int): number of rows to reterive
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.runInteraction(
+            desc,
+            self._simple_select_list_paginate_txn,
+            table, keyvalues, pagevalues, retcols
+        )
+
+    @classmethod
+    def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
+        """Executes a SELECT query on the named table with start and limit,
+        of row numbers, which may return zero or number of rows from start to limit,
+        returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            pagevalues ([]):
+                order (str): order the select by this column
+                start (int): start number to begin the query from
+                limit (int): number of rows to reterive
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+
+        """
+        if keyvalues:
+            sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
+                ", ".join(retcols),
+                table,
+                " AND ".join("%s = ?" % (k,) for k in keyvalues),
+                " ? ASC LIMIT ? OFFSET ?"
+            )
+            txn.execute(sql, keyvalues.values() + pagevalues)
+        else:
+            sql = "SELECT %s FROM %s ORDER BY %s" % (
+                ", ".join(retcols),
+                table,
+                " ? ASC LIMIT ? OFFSET ?"
+            )
+            txn.execute(sql, pagevalues)
+
+        return cls.cursor_to_dict(txn)
+
+    @defer.inlineCallbacks
+    def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
+                               desc="get_user_list_paginate"):
+        """Get a list of users from start row to a limit number of rows. This will
+        return a json object with users and total number of users in users list.
+
+        Args:
+            table (str): the table name
+            keyvalues (dict[str, Any] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            pagevalues ([]):
+                order (str): order the select by this column
+                start (int): start number to begin the query from
+                limit (int): number of rows to reterive
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+        """
+        users = yield self.runInteraction(
+            desc,
+            self._simple_select_list_paginate_txn,
+            table, keyvalues, pagevalues, retcols
+        )
+        count = yield self.runInteraction(
+            desc,
+            self.get_user_count_txn
+        )
+        retval = {
+            "users": users,
+            "total": count
+        }
+        defer.returnValue(retval)
+
+    def get_user_count_txn(self, txn):
+        """Get a total number of registerd users in the users list.
+
+        Args:
+            txn : Transaction object
+        Returns:
+            defer.Deferred: resolves to int
+        """
+        sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
+        txn.execute(sql_count)
+        count = txn.fetchone()[0]
+        defer.returnValue(count)
+
+    def _simple_search_list(self, table, term, col, retcols,
+                            desc="_simple_search_list"):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            term (str | None):
+                term for searching the table matched to a column.
+            col (str): column to query term should be matched to
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]] or None
+        """
+
+        return self.runInteraction(
+            desc,
+            self._simple_search_list_txn,
+            table, term, col, retcols
+        )
+
+    @classmethod
+    def _simple_search_list_txn(cls, txn, table, term, col, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            term (str | None):
+                term for searching the table matched to a column.
+            col (str): column to query term should be matched to
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]] or None
+        """
+        if term:
+            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
+                ", ".join(retcols),
+                table,
+                col
+            )
+            termvalues = ["%%" + term + "%%"]
+            txn.execute(sql, termvalues)
+        else:
+            return 0
+
+        return cls.cursor_to_dict(txn)
+
 
 class _RollbackButIsFineException(Exception):
     """ This exception is used to rollback a transaction without implying
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 3fa226e92d..aa84ffc2b0 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
             txn.execute(sql, (user_id, stream_id))
 
             global_account_data = {
-                row[0]: json.loads(row[1]) for row in txn.fetchall()
+                row[0]: json.loads(row[1]) for row in txn
             }
 
             sql = (
@@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
             txn.execute(sql, (user_id, stream_id))
 
             account_data_by_room = {}
-            for row in txn.fetchall():
+            for row in txn:
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = json.loads(row[2])
 
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 94b2bcc54a..813ad59e56 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import synapse.util.async
 
 from ._base import SQLBaseStore
 from . import engines
@@ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}
-        self._background_update_timer = None
 
     @defer.inlineCallbacks
     def start_doing_background_updates(self):
-        assert self._background_update_timer is None, \
-            "background updates already running"
-
         logger.info("Starting background schema updates")
 
         while True:
-            sleep = defer.Deferred()
-            self._background_update_timer = self._clock.call_later(
-                self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
-            )
-            try:
-                yield sleep
-            finally:
-                self._background_update_timer = None
+            yield synapse.util.async.sleep(
+                self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
 
             try:
                 result = yield self.do_next_background_update(
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index bde3b5cbbc..2714519d21 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -20,6 +20,8 @@ from twisted.internet import defer
 
 from .background_updates import BackgroundUpdateStore
 
+from synapse.util.caches.expiringcache import ExpiringCache
+
 
 logger = logging.getLogger(__name__)
 
@@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore):
             self._background_drop_index_device_inbox,
         )
 
+        # Map of (user_id, device_id) to the last stream_id that has been
+        # deleted up to. This is so that we can no op deletions.
+        self._last_device_delete_cache = ExpiringCache(
+            cache_name="last_device_delete_cache",
+            clock=self._clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+        )
+
     @defer.inlineCallbacks
     def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
                                      remote_messages_by_destination):
@@ -167,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 )
                 txn.execute(sql, (user_id,))
                 message_json = ujson.dumps(messages_by_device["*"])
-                for row in txn.fetchall():
+                for row in txn:
                     # Add the message for all devices for this user on this
                     # server.
                     device = row[0]
@@ -184,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 # TODO: Maybe this needs to be done in batches if there are
                 # too many local devices for a given user.
                 txn.execute(sql, [user_id] + devices)
-                for row in txn.fetchall():
+                for row in txn:
                     # Only insert into the local inbox if the device exists on
                     # this server
                     device = row[0]
@@ -240,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 user_id, device_id, last_stream_id, current_stream_id, limit
             ))
             messages = []
-            for row in txn.fetchall():
+            for row in txn:
                 stream_pos = row[0]
                 messages.append(ujson.loads(row[1]))
             if len(messages) < limit:
@@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             "get_new_messages_for_device", get_new_messages_for_device_txn,
         )
 
+    @defer.inlineCallbacks
     def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
         """
         Args:
@@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore):
         Returns:
             A deferred that resolves to the number of messages deleted.
         """
+        # If we have cached the last stream id we've deleted up to, we can
+        # check if there is likely to be anything that needs deleting
+        last_deleted_stream_id = self._last_device_delete_cache.get(
+            (user_id, device_id), None
+        )
+        if last_deleted_stream_id:
+            has_changed = self._device_inbox_stream_cache.has_entity_changed(
+                user_id, last_deleted_stream_id
+            )
+            if not has_changed:
+                defer.returnValue(0)
+
         def delete_messages_for_device_txn(txn):
             sql = (
                 "DELETE FROM device_inbox"
@@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
-        return self.runInteraction(
+        count = yield self.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
+        # Update the cache, ensuring that we only ever increase the value
+        last_deleted_stream_id = self._last_device_delete_cache.get(
+            (user_id, device_id), 0
+        )
+        self._last_device_delete_cache[(user_id, device_id)] = max(
+            last_deleted_stream_id, up_to_stream_id
+        )
+
+        defer.returnValue(count)
+
     def get_all_new_device_messages(self, last_pos, current_pos, limit):
         """
         Args:
@@ -306,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 " ORDER BY stream_id ASC"
             )
             txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn.fetchall())
+            rows.extend(txn)
 
             return rows
 
@@ -323,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
         """
         Args:
             destination(str): The name of the remote server.
-            last_stream_id(int): The last position of the device message stream
+            last_stream_id(int|long): The last position of the device message stream
                 that the server sent up to.
-            current_stream_id(int): The current position of the device
+            current_stream_id(int|long): The current position of the device
                 message stream.
         Returns:
-            Deferred ([dict], int): List of messages for the device and where
+            Deferred ([dict], int|long): List of messages for the device and where
                 in the stream the messages got to.
         """
 
@@ -350,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 destination, last_stream_id, current_stream_id, limit
             ))
             messages = []
-            for row in txn.fetchall():
+            for row in txn:
                 stream_pos = row[0]
                 messages.append(ujson.loads(row[1]))
             if len(messages) < limit:
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 17920d4480..53e36791d5 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,37 +13,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import ujson as json
 
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+
 
 logger = logging.getLogger(__name__)
 
 
 class DeviceStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(DeviceStore, self).__init__(hs)
+
+        self._clock.looping_call(
+            self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+        )
+
+        self.register_background_index_update(
+            "device_lists_stream_idx",
+            index_name="device_lists_stream_user_id",
+            table="device_lists_stream",
+            columns=["user_id", "device_id"],
+        )
+
     @defer.inlineCallbacks
     def store_device(self, user_id, device_id,
-                     initial_device_display_name,
-                     ignore_if_known=True):
+                     initial_device_display_name):
         """Ensure the given device is known; add it to the store if not
 
         Args:
             user_id (str): id of user associated with the device
             device_id (str): id of device
             initial_device_display_name (str): initial displayname of the
-               device
-            ignore_if_known (bool): ignore integrity errors which mean the
-               device is already known
+               device. Ignored if device exists.
         Returns:
-            defer.Deferred
-        Raises:
-            StoreError: if ignore_if_known is False and the device was already
-               known
+            defer.Deferred: boolean whether the device was inserted or an
+                existing device existed with that ID.
         """
         try:
-            yield self._simple_insert(
+            inserted = yield self._simple_insert(
                 "devices",
                 values={
                     "user_id": user_id,
@@ -51,8 +63,9 @@ class DeviceStore(SQLBaseStore):
                     "display_name": initial_device_display_name
                 },
                 desc="store_device",
-                or_ignore=ignore_if_known,
+                or_ignore=True,
             )
+            defer.returnValue(inserted)
         except Exception as e:
             logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
                          " display_name=%s(%r) failed: %s",
@@ -95,6 +108,23 @@ class DeviceStore(SQLBaseStore):
             desc="delete_device",
         )
 
+    def delete_devices(self, user_id, device_ids):
+        """Deletes several devices.
+
+        Args:
+            user_id (str): The ID of the user which owns the devices
+            device_ids (list): The IDs of the devices to delete
+        Returns:
+            defer.Deferred
+        """
+        return self._simple_delete_many(
+            table="devices",
+            column="device_id",
+            iterable=device_ids,
+            keyvalues={"user_id": user_id},
+            desc="delete_devices",
+        )
+
     def update_device(self, user_id, device_id, new_display_name=None):
         """Update a device.
 
@@ -139,3 +169,490 @@ class DeviceStore(SQLBaseStore):
         )
 
         defer.returnValue({d["device_id"]: d for d in devices})
+
+    @cached(max_entries=10000)
+    def get_device_list_last_stream_id_for_remote(self, user_id):
+        """Get the last stream_id we got for a user. May be None if we haven't
+        got any information for them.
+        """
+        return self._simple_select_one_onecol(
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+            retcol="stream_id",
+            desc="get_device_list_remote_extremity",
+            allow_none=True,
+        )
+
+    @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+                list_name="user_ids", inlineCallbacks=True)
+    def get_device_list_last_stream_id_for_remotes(self, user_ids):
+        rows = yield self._simple_select_many_batch(
+            table="device_lists_remote_extremeties",
+            column="user_id",
+            iterable=user_ids,
+            retcols=("user_id", "stream_id",),
+            desc="get_user_devices_from_cache",
+        )
+
+        results = {user_id: None for user_id in user_ids}
+        results.update({
+            row["user_id"]: row["stream_id"] for row in rows
+        })
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
+    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+        """Mark that we no longer track device lists for remote user.
+        """
+        yield self._simple_delete(
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            desc="mark_remote_user_device_list_as_unsubscribed",
+        )
+        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+
+    def update_remote_device_list_cache_entry(self, user_id, device_id, content,
+                                              stream_id):
+        """Updates a single user's device in the cache.
+        """
+        return self.runInteraction(
+            "update_remote_device_list_cache_entry",
+            self._update_remote_device_list_cache_entry_txn,
+            user_id, device_id, content, stream_id,
+        )
+
+    def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
+                                                   content, stream_id):
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            values={
+                "content": json.dumps(content),
+            }
+        )
+
+        txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
+        txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+        txn.call_after(
+            self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def update_remote_device_list_cache(self, user_id, devices, stream_id):
+        """Replace the cache of the remote user's devices.
+        """
+        return self.runInteraction(
+            "update_remote_device_list_cache",
+            self._update_remote_device_list_cache_txn,
+            user_id, devices, stream_id,
+        )
+
+    def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
+                                             stream_id):
+        self._simple_delete_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_remote_cache",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": content["device_id"],
+                    "content": json.dumps(content),
+                }
+                for content in devices
+            ]
+        )
+
+        txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+        txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+        txn.call_after(
+            self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def get_devices_by_remote(self, destination, from_stream_id):
+        """Get stream of updates to send to remote servers
+
+        Returns:
+            (int, list[dict]): current stream id and list of updates
+        """
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+            destination, int(from_stream_id)
+        )
+        if not has_changed:
+            return (now_stream_id, [])
+
+        return self.runInteraction(
+            "get_devices_by_remote", self._get_devices_by_remote_txn,
+            destination, from_stream_id, now_stream_id,
+        )
+
+    def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+                                   now_stream_id):
+        sql = """
+            SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+            WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+            GROUP BY user_id, device_id
+            LIMIT 20
+        """
+        txn.execute(
+            sql, (destination, from_stream_id, now_stream_id, False)
+        )
+
+        # maps (user_id, device_id) -> stream_id
+        query_map = {(r[0], r[1]): r[2] for r in txn}
+        if not query_map:
+            return (now_stream_id, [])
+
+        if len(query_map) >= 20:
+            now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, query_map.keys(), include_all_devices=True
+        )
+
+        prev_sent_id_sql = """
+            SELECT coalesce(max(stream_id), 0) as stream_id
+            FROM device_lists_outbound_pokes
+            WHERE destination = ? AND user_id = ? AND stream_id <= ?
+        """
+
+        results = []
+        for user_id, user_devices in devices.iteritems():
+            # The prev_id for the first row is always the last row before
+            # `from_stream_id`
+            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+            rows = txn.fetchall()
+            prev_id = rows[0][0]
+            for device_id, device in user_devices.iteritems():
+                stream_id = query_map[(user_id, device_id)]
+                result = {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "prev_id": [prev_id] if prev_id else [],
+                    "stream_id": stream_id,
+                }
+
+                prev_id = stream_id
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+        return (now_stream_id, results)
+
+    @defer.inlineCallbacks
+    def get_user_devices_from_cache(self, query_list):
+        """Get the devices (and keys if any) for remote users from the cache.
+
+        Args:
+            query_list(list): List of (user_id, device_ids), if device_ids is
+                falsey then return all device ids for that user.
+
+        Returns:
+            (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+            a set of user_ids and results_map is a mapping of
+            user_id -> device_id -> device_info
+        """
+        user_ids = set(user_id for user_id, _ in query_list)
+        user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        user_ids_in_cache = set(
+            user_id for user_id, stream_id in user_map.items() if stream_id
+        )
+        user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+        results = {}
+        for user_id, device_id in query_list:
+            if user_id not in user_ids_in_cache:
+                continue
+
+            if device_id:
+                device = yield self._get_cached_user_device(user_id, device_id)
+                results.setdefault(user_id, {})[device_id] = device
+            else:
+                results[user_id] = yield self._get_cached_devices_for_user(user_id)
+
+        defer.returnValue((user_ids_not_in_cache, results))
+
+    @cachedInlineCallbacks(num_args=2, tree=True)
+    def _get_cached_user_device(self, user_id, device_id):
+        content = yield self._simple_select_one_onecol(
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            retcol="content",
+            desc="_get_cached_user_device",
+        )
+        defer.returnValue(json.loads(content))
+
+    @cachedInlineCallbacks()
+    def _get_cached_devices_for_user(self, user_id):
+        devices = yield self._simple_select_list(
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+            retcols=("device_id", "content"),
+            desc="_get_cached_devices_for_user",
+        )
+        defer.returnValue({
+            device["device_id"]: json.loads(device["content"])
+            for device in devices
+        })
+
+    def get_devices_with_keys_by_user(self, user_id):
+        """Get all devices (with any device keys) for a user
+
+        Returns:
+            (stream_id, devices)
+        """
+        return self.runInteraction(
+            "get_devices_with_keys_by_user",
+            self._get_devices_with_keys_by_user_txn, user_id,
+        )
+
+    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, [(user_id, None)], include_all_devices=True
+        )
+
+        if devices:
+            user_devices = devices[user_id]
+            results = []
+            for device_id, device in user_devices.iteritems():
+                result = {
+                    "device_id": device_id,
+                }
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
+    def mark_as_sent_devices_by_remote(self, destination, stream_id):
+        """Mark that updates have successfully been sent to the destination.
+        """
+        return self.runInteraction(
+            "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+            destination, stream_id,
+        )
+
+    def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+        # First we DELETE all rows such that only the latest row for each
+        # (destination, user_id is left. We do this by selecting first and
+        # deleting.
+        sql = """
+            SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
+            WHERE destination = ? AND stream_id <= ?
+            GROUP BY user_id
+            HAVING count(*) > 1
+        """
+        txn.execute(sql, (destination, stream_id,))
+        rows = txn.fetchall()
+
+        sql = """
+            DELETE FROM device_lists_outbound_pokes
+            WHERE destination = ? AND user_id = ? AND stream_id < ?
+        """
+        txn.executemany(
+            sql, ((destination, row[0], row[1],) for row in rows)
+        )
+
+        # Mark everything that is left as sent
+        sql = """
+            UPDATE device_lists_outbound_pokes SET sent = ?
+            WHERE destination = ? AND stream_id <= ?
+        """
+        txn.execute(sql, (True, destination, stream_id,))
+
+    @defer.inlineCallbacks
+    def get_user_whose_devices_changed(self, from_key):
+        """Get set of users whose devices have changed since `from_key`.
+        """
+        from_key = int(from_key)
+        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+        if changed is not None:
+            defer.returnValue(set(changed))
+
+        sql = """
+            SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
+        """
+        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+        defer.returnValue(set(row[0] for row in rows))
+
+    def get_all_device_list_changes_for_remotes(self, from_key):
+        """Return a list of `(stream_id, user_id, destination)` which is the
+        combined list of changes to devices, and which destinations need to be
+        poked. `destination` may be None if no destinations need to be poked.
+        """
+        sql = """
+            SELECT stream_id, user_id, destination FROM device_lists_stream
+            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            WHERE stream_id > ?
+        """
+        return self._execute(
+            "get_all_device_list_changes_for_remotes", None,
+            sql, from_key,
+        )
+
+    @defer.inlineCallbacks
+    def add_device_change_to_streams(self, user_id, device_ids, hosts):
+        """Persist that a user's devices have been updated, and which hosts
+        (if any) should be poked.
+        """
+        with self._device_list_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_device_change_to_streams", self._add_device_change_txn,
+                user_id, device_ids, hosts, stream_id,
+            )
+        defer.returnValue(stream_id)
+
+    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
+        now = self._clock.time_msec()
+
+        txn.call_after(
+            self._device_list_stream_cache.entity_has_changed,
+            user_id, stream_id,
+        )
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host, stream_id,
+            )
+
+        # Delete older entries in the table, as we really only care about
+        # when the latest change happened.
+        txn.executemany(
+            """
+            DELETE FROM device_lists_stream
+            WHERE user_id = ? AND device_id = ? AND stream_id < ?
+            """,
+            [(user_id, device_id, stream_id) for device_id in device_ids]
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_stream",
+            values=[
+                {
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+                for device_id in device_ids
+            ]
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_outbound_pokes",
+            values=[
+                {
+                    "destination": destination,
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "sent": False,
+                    "ts": now,
+                }
+                for destination in hosts
+                for device_id in device_ids
+            ]
+        )
+
+    def get_device_stream_token(self):
+        return self._device_list_id_gen.get_current_token()
+
+    def _prune_old_outbound_device_pokes(self):
+        """Delete old entries out of the device_lists_outbound_pokes to ensure
+        that we don't fill up due to dead servers. We keep one entry per
+        (destination, user_id) tuple to ensure that the prev_ids remain correct
+        if the server does come back.
+        """
+        yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+
+        def _prune_txn(txn):
+            select_sql = """
+                SELECT destination, user_id, max(stream_id) as stream_id
+                FROM device_lists_outbound_pokes
+                GROUP BY destination, user_id
+                HAVING min(ts) < ? AND count(*) > 1
+            """
+
+            txn.execute(select_sql, (yesterday,))
+            rows = txn.fetchall()
+
+            if not rows:
+                return
+
+            delete_sql = """
+                DELETE FROM device_lists_outbound_pokes
+                WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+            """
+
+            txn.executemany(
+                delete_sql,
+                (
+                    (yesterday, row[0], row[1], row[2])
+                    for row in rows
+                )
+            )
+
+            logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+
+        return self.runInteraction(
+            "_prune_old_outbound_device_pokes", _prune_txn
+        )
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 385d607056..7cbc1470fd 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,95 +12,173 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections
+from twisted.internet import defer
 
-import twisted.internet.defer
+from synapse.api.errors import SynapseError
+
+from canonicaljson import encode_canonical_json
+import ujson as json
 
 from ._base import SQLBaseStore
 
 
 class EndToEndKeyStore(SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
-        return self._simple_upsert(
-            table="e2e_device_keys_json",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            values={
-                "ts_added_ms": time_now,
-                "key_json": json_bytes,
-            }
+    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+        def _set_e2e_device_keys_txn(txn):
+            old_key_json = self._simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            new_key_json = encode_canonical_json(device_keys)
+            if old_key_json == new_key_json:
+                return False
+
+            self._simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "ts_added_ms": time_now,
+                    "key_json": new_key_json,
+                }
+            )
+
+            return True
+
+        return self.runInteraction(
+            "set_e2e_device_keys", _set_e2e_device_keys_txn
         )
 
-    def get_e2e_device_keys(self, query_list):
+    @defer.inlineCallbacks
+    def get_e2e_device_keys(self, query_list, include_all_devices=False):
         """Fetch a list of device keys.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
+            include_all_devices (bool): whether to include entries for devices
+                that don't have device keys
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
         """
         if not query_list:
-            return {}
+            defer.returnValue({})
 
-        return self.runInteraction(
-            "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+        results = yield self.runInteraction(
+            "get_e2e_device_keys", self._get_e2e_device_keys_txn,
+            query_list, include_all_devices,
         )
 
-    def _get_e2e_device_keys_txn(self, txn, query_list):
+        for user_id, device_keys in results.iteritems():
+            for device_id, device_info in device_keys.iteritems():
+                device_info["keys"] = json.loads(device_info.pop("key_json"))
+
+        defer.returnValue(results)
+
+    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
         query_clauses = []
         query_params = []
 
         for (user_id, device_id) in query_list:
-            query_clause = "k.user_id = ?"
+            query_clause = "user_id = ?"
             query_params.append(user_id)
 
-            if device_id:
-                query_clause += " AND k.device_id = ?"
+            if device_id is not None:
+                query_clause += " AND device_id = ?"
                 query_params.append(device_id)
 
             query_clauses.append(query_clause)
 
         sql = (
-            "SELECT k.user_id, k.device_id, "
+            "SELECT user_id, device_id, "
             "    d.display_name AS device_display_name, "
             "    k.key_json"
-            " FROM e2e_device_keys_json k"
-            "    LEFT JOIN devices d ON d.user_id = k.user_id"
-            "      AND d.device_id = k.device_id"
+            " FROM devices d"
+            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
             " WHERE %s"
         ) % (
+            "LEFT" if include_all_devices else "INNER",
             " OR ".join("(" + q + ")" for q in query_clauses)
         )
 
         txn.execute(sql, query_params)
         rows = self.cursor_to_dict(txn)
 
-        result = collections.defaultdict(dict)
+        result = {}
         for row in rows:
-            result[row["user_id"]][row["device_id"]] = row
+            result.setdefault(row["user_id"], {})[row["device_id"]] = row
 
         return result
 
+    @defer.inlineCallbacks
     def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+        """Insert some new one time keys for a device.
+
+        Checks if any of the keys are already inserted, if they are then check
+        if they match. If they don't then we raise an error.
+        """
+
+        # First we check if we have already persisted any of the keys.
+        rows = yield self._simple_select_many_batch(
+            table="e2e_one_time_keys_json",
+            column="key_id",
+            iterable=[key_id for _, key_id, _ in key_list],
+            retcols=("algorithm", "key_id", "key_json",),
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            desc="add_e2e_one_time_keys_check",
+        )
+
+        existing_key_map = {
+            (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
+        }
+
+        new_keys = []  # Keys that we need to insert
+        for algorithm, key_id, json_bytes in key_list:
+            ex_bytes = existing_key_map.get((algorithm, key_id), None)
+            if ex_bytes:
+                if json_bytes != ex_bytes:
+                    raise SynapseError(
+                        400, "One time key with key_id %r already exists" % (key_id,)
+                    )
+            else:
+                new_keys.append((algorithm, key_id, json_bytes))
+
         def _add_e2e_one_time_keys(txn):
-            for (algorithm, key_id, json_bytes) in key_list:
-                self._simple_upsert_txn(
-                    txn, table="e2e_one_time_keys_json",
-                    keyvalues={
+            # We are protected from race between lookup and insertion due to
+            # a unique constraint. If there is a race of two calls to
+            # `add_e2e_one_time_keys` then they'll conflict and we will only
+            # insert one set.
+            self._simple_insert_many_txn(
+                txn, table="e2e_one_time_keys_json",
+                values=[
+                    {
                         "user_id": user_id,
                         "device_id": device_id,
                         "algorithm": algorithm,
                         "key_id": key_id,
-                    },
-                    values={
                         "ts_added_ms": time_now,
                         "key_json": json_bytes,
                     }
-                )
-        return self.runInteraction(
-            "add_e2e_one_time_keys", _add_e2e_one_time_keys
+                    for algorithm, key_id, json_bytes in new_keys
+                ],
+            )
+        yield self.runInteraction(
+            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
     def count_e2e_one_time_keys(self, user_id, device_id):
@@ -116,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore):
             )
             txn.execute(sql, (user_id, device_id))
             result = {}
-            for algorithm, key_count in txn.fetchall():
+            for algorithm, key_count in txn:
                 result[algorithm] = key_count
             return result
         return self.runInteraction(
@@ -137,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore):
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn.fetchall():
+                for key_id, key_json in txn:
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
             sql = (
@@ -152,7 +230,7 @@ class EndToEndKeyStore(SQLBaseStore):
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    @twisted.internet.defer.inlineCallbacks
+    @defer.inlineCallbacks
     def delete_e2e_keys_by_device(self, user_id, device_id):
         yield self._simple_delete(
             table="e2e_device_keys_json",
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index f0aa2193fb..43b5b49986 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
                     base_sql % (",".join(["?"] * len(chunk)),),
                     chunk
                 )
-                new_front.update([r[0] for r in txn.fetchall()])
+                new_front.update([r[0] for r in txn])
 
             new_front -= results
 
@@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
 
         txn.execute(sql, (room_id, False,))
 
-        return dict(txn.fetchall())
+        return dict(txn)
 
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
@@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )
 
-    @cached()
+    @cached(max_entries=5000, iterable=True)
     def get_latest_event_ids_in_room(self, room_id):
         return self._simple_select_onecol(
             table="event_forward_extremities",
@@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
         txn.execute(sql, (room_id, ))
 
         results = []
-        for event_id, depth in txn.fetchall():
+        for event_id, depth in txn:
             hashes = self._get_event_reference_hashes_txn(txn, event_id)
             prev_hashes = {
                 k: encode_base64(v) for k, v in hashes.items()
@@ -201,19 +201,19 @@ class EventFederationStore(SQLBaseStore):
     def _update_min_depth_for_room_txn(self, txn, room_id, depth):
         min_depth = self._get_min_depth_interaction(txn, room_id)
 
-        do_insert = depth < min_depth if min_depth else True
+        if min_depth and depth >= min_depth:
+            return
 
-        if do_insert:
-            self._simple_upsert_txn(
-                txn,
-                table="room_depth",
-                keyvalues={
-                    "room_id": room_id,
-                },
-                values={
-                    "min_depth": depth,
-                },
-            )
+        self._simple_upsert_txn(
+            txn,
+            table="room_depth",
+            keyvalues={
+                "room_id": room_id,
+            },
+            values={
+                "min_depth": depth,
+            },
+        )
 
     def _handle_mult_prev_events(self, txn, events):
         """
@@ -281,15 +281,30 @@ class EventFederationStore(SQLBaseStore):
         )
 
     def get_forward_extremeties_for_room(self, room_id, stream_ordering):
+        """For a given room_id and stream_ordering, return the forward
+        extremeties of the room at that point in "time".
+
+        Throws a StoreError if we have since purged the index for
+        stream_orderings from that point.
+
+        Args:
+            room_id (str):
+            stream_ordering (int):
+
+        Returns:
+            deferred, which resolves to a list of event_ids
+        """
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
         last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
 
         # We don't always have a full stream_to_exterm_id table, e.g. after
         # the upgrade that introduced it, so we make sure we never ask for a
-        # try and pin to a stream_ordering from before a restart
+        # stream_ordering from before a restart
         last_change = max(self._stream_order_on_start, last_change)
 
+        # provided the last_change is recent enough, we now clamp the requested
+        # stream_ordering to it.
         if last_change > self.stream_ordering_month_ago:
             stream_ordering = min(last_change, stream_ordering)
 
@@ -319,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
 
         def get_forward_extremeties_for_room_txn(txn):
             txn.execute(sql, (stream_ordering, room_id))
-            rows = txn.fetchall()
-            return [event_id for event_id, in rows]
+            return [event_id for event_id, in txn]
 
         return self.runInteraction(
             "get_forward_extremeties_for_room",
@@ -421,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
                 (room_id, event_id, False, limit - len(event_results))
             )
 
-            for row in txn.fetchall():
+            for row in txn:
                 if row[1] not in event_results:
                     queue.put((-row[0], row[1]))
 
@@ -467,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
                     (room_id, event_id, False, limit - len(event_results))
                 )
 
-                for e_id, in txn.fetchall():
+                for e_id, in txn:
                     new_front.add(e_id)
 
             new_front -= earliest_events
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 7de3e8c58c..d6d8723b4a 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -15,6 +15,7 @@
 
 from ._base import SQLBaseStore
 from twisted.internet import defer
+from synapse.util.async import sleep
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.types import RoomStreamToken
 from .stream import lower_bound
@@ -25,11 +26,46 @@ import ujson as json
 logger = logging.getLogger(__name__)
 
 
+DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
+DEFAULT_HIGHLIGHT_ACTION = [
+    "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+]
+
+
+def _serialize_action(actions, is_highlight):
+    """Custom serializer for actions. This allows us to "compress" common actions.
+
+    We use the fact that most users have the same actions for notifs (and for
+    highlights).
+    We store these default actions as the empty string rather than the full JSON.
+    Since the empty string isn't valid JSON there is no risk of this clashing with
+    any real JSON actions
+    """
+    if is_highlight:
+        if actions == DEFAULT_HIGHLIGHT_ACTION:
+            return ""  # We use empty string as the column is non-NULL
+    else:
+        if actions == DEFAULT_NOTIF_ACTION:
+            return ""
+    return json.dumps(actions)
+
+
+def _deserialize_action(actions, is_highlight):
+    """Custom deserializer for actions. This allows us to "compress" common actions
+    """
+    if actions:
+        return json.loads(actions)
+
+    if is_highlight:
+        return DEFAULT_HIGHLIGHT_ACTION
+    else:
+        return DEFAULT_NOTIF_ACTION
+
+
 class EventPushActionsStore(SQLBaseStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
     def __init__(self, hs):
-        self.stream_ordering_month_ago = None
         super(EventPushActionsStore, self).__init__(hs)
 
         self.register_background_index_update(
@@ -47,6 +83,11 @@ class EventPushActionsStore(SQLBaseStore):
             where_clause="highlight=1"
         )
 
+        self._doing_notif_rotation = False
+        self._rotate_notif_loop = self._clock.looping_call(
+            self._rotate_notifs, 30 * 60 * 1000
+        )
+
     def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
         """
         Args:
@@ -55,15 +96,17 @@ class EventPushActionsStore(SQLBaseStore):
         """
         values = []
         for uid, actions in tuples:
+            is_highlight = 1 if _action_has_highlight(actions) else 0
+
             values.append({
                 'room_id': event.room_id,
                 'event_id': event.event_id,
                 'user_id': uid,
-                'actions': json.dumps(actions),
+                'actions': _serialize_action(actions, is_highlight),
                 'stream_ordering': event.internal_metadata.stream_ordering,
                 'topological_ordering': event.depth,
                 'notif': 1,
-                'highlight': 1 if _action_has_highlight(actions) else 0,
+                'highlight': is_highlight,
             })
 
         for uid, __ in tuples:
@@ -77,66 +120,83 @@ class EventPushActionsStore(SQLBaseStore):
     def get_unread_event_push_actions_by_room_for_user(
             self, room_id, user_id, last_read_event_id
     ):
-        def _get_unread_event_push_actions_by_room(txn):
-            sql = (
-                "SELECT stream_ordering, topological_ordering"
-                " FROM events"
-                " WHERE room_id = ? AND event_id = ?"
-            )
-            txn.execute(
-                sql, (room_id, last_read_event_id)
-            )
-            results = txn.fetchall()
-            if len(results) == 0:
-                return {"notify_count": 0, "highlight_count": 0}
-
-            stream_ordering = results[0][0]
-            topological_ordering = results[0][1]
-            token = RoomStreamToken(
-                topological_ordering, stream_ordering
-            )
-
-            # First get number of notifications.
-            # We don't need to put a notif=1 clause as all rows always have
-            # notif=1
-            sql = (
-                "SELECT count(*)"
-                " FROM event_push_actions ea"
-                " WHERE"
-                " user_id = ?"
-                " AND room_id = ?"
-                " AND %s"
-            ) % (lower_bound(token, self.database_engine, inclusive=False),)
-
-            txn.execute(sql, (user_id, room_id))
-            row = txn.fetchone()
-            notify_count = row[0] if row else 0
+        ret = yield self.runInteraction(
+            "get_unread_event_push_actions_by_room",
+            self._get_unread_counts_by_receipt_txn,
+            room_id, user_id, last_read_event_id
+        )
+        defer.returnValue(ret)
 
-            # Now get the number of highlights
-            sql = (
-                "SELECT count(*)"
-                " FROM event_push_actions ea"
-                " WHERE"
-                " highlight = 1"
-                " AND user_id = ?"
-                " AND room_id = ?"
-                " AND %s"
-            ) % (lower_bound(token, self.database_engine, inclusive=False),)
+    def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
+                                          last_read_event_id):
+        sql = (
+            "SELECT stream_ordering, topological_ordering"
+            " FROM events"
+            " WHERE room_id = ? AND event_id = ?"
+        )
+        txn.execute(
+            sql, (room_id, last_read_event_id)
+        )
+        results = txn.fetchall()
+        if len(results) == 0:
+            return {"notify_count": 0, "highlight_count": 0}
 
-            txn.execute(sql, (user_id, room_id))
-            row = txn.fetchone()
-            highlight_count = row[0] if row else 0
+        stream_ordering = results[0][0]
+        topological_ordering = results[0][1]
 
-            return {
-                "notify_count": notify_count,
-                "highlight_count": highlight_count,
-            }
+        return self._get_unread_counts_by_pos_txn(
+            txn, room_id, user_id, topological_ordering, stream_ordering
+        )
 
-        ret = yield self.runInteraction(
-            "get_unread_event_push_actions_by_room",
-            _get_unread_event_push_actions_by_room
+    def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
+                                      stream_ordering):
+        token = RoomStreamToken(
+            topological_ordering, stream_ordering
         )
-        defer.returnValue(ret)
+
+        # First get number of notifications.
+        # We don't need to put a notif=1 clause as all rows always have
+        # notif=1
+        sql = (
+            "SELECT count(*)"
+            " FROM event_push_actions ea"
+            " WHERE"
+            " user_id = ?"
+            " AND room_id = ?"
+            " AND %s"
+        ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+        txn.execute(sql, (user_id, room_id))
+        row = txn.fetchone()
+        notify_count = row[0] if row else 0
+
+        txn.execute("""
+            SELECT notif_count FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+        """, (room_id, user_id, stream_ordering,))
+        rows = txn.fetchall()
+        if rows:
+            notify_count += rows[0][0]
+
+        # Now get the number of highlights
+        sql = (
+            "SELECT count(*)"
+            " FROM event_push_actions ea"
+            " WHERE"
+            " highlight = 1"
+            " AND user_id = ?"
+            " AND room_id = ?"
+            " AND %s"
+        ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+        txn.execute(sql, (user_id, room_id))
+        row = txn.fetchone()
+        highlight_count = row[0] if row else 0
+
+        return {
+            "notify_count": notify_count,
+            "highlight_count": highlight_count,
+        }
 
     @defer.inlineCallbacks
     def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
@@ -146,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
                 " stream_ordering >= ? AND stream_ordering <= ?"
             )
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
-            return [r[0] for r in txn.fetchall()]
+            return [r[0] for r in txn]
         ret = yield self.runInteraction("get_push_action_users_in_range", f)
         defer.returnValue(ret)
 
@@ -176,7 +236,8 @@ class EventPushActionsStore(SQLBaseStore):
             # find rooms that have a read receipt in them and return the next
             # push actions
             sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
+                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+                "   ep.highlight "
                 " FROM ("
                 "   SELECT room_id,"
                 "       MAX(topological_ordering) as topological_ordering,"
@@ -217,7 +278,7 @@ class EventPushActionsStore(SQLBaseStore):
         def get_no_receipt(txn):
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                " e.received_ts"
+                "   ep.highlight "
                 " FROM event_push_actions AS ep"
                 " INNER JOIN events AS e USING (room_id, event_id)"
                 " WHERE"
@@ -246,7 +307,7 @@ class EventPushActionsStore(SQLBaseStore):
                 "event_id": row[0],
                 "room_id": row[1],
                 "stream_ordering": row[2],
-                "actions": json.loads(row[3]),
+                "actions": _deserialize_action(row[3], row[4]),
             } for row in after_read_receipt + no_read_receipt
         ]
 
@@ -285,7 +346,7 @@ class EventPushActionsStore(SQLBaseStore):
         def get_after_receipt(txn):
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "  e.received_ts"
+                "  ep.highlight, e.received_ts"
                 " FROM ("
                 "   SELECT room_id,"
                 "       MAX(topological_ordering) as topological_ordering,"
@@ -327,7 +388,7 @@ class EventPushActionsStore(SQLBaseStore):
         def get_no_receipt(txn):
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                " e.received_ts"
+                "   ep.highlight, e.received_ts"
                 " FROM event_push_actions AS ep"
                 " INNER JOIN events AS e USING (room_id, event_id)"
                 " WHERE"
@@ -357,8 +418,8 @@ class EventPushActionsStore(SQLBaseStore):
                 "event_id": row[0],
                 "room_id": row[1],
                 "stream_ordering": row[2],
-                "actions": json.loads(row[3]),
-                "received_ts": row[4],
+                "actions": _deserialize_action(row[3], row[4]),
+                "received_ts": row[5],
             } for row in after_read_receipt + no_read_receipt
         ]
 
@@ -392,7 +453,7 @@ class EventPushActionsStore(SQLBaseStore):
             sql = (
                 "SELECT epa.event_id, epa.room_id,"
                 " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.profile_tag, e.received_ts"
+                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
                 " FROM event_push_actions epa, events e"
                 " WHERE epa.event_id = e.event_id"
                 " AND epa.user_id = ? %s"
@@ -407,7 +468,7 @@ class EventPushActionsStore(SQLBaseStore):
             "get_push_actions_for_user", f
         )
         for pa in push_actions:
-            pa["actions"] = json.loads(pa["actions"])
+            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         defer.returnValue(push_actions)
 
     @defer.inlineCallbacks
@@ -448,10 +509,14 @@ class EventPushActionsStore(SQLBaseStore):
         )
 
     def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
-                                            topological_ordering):
+                                            topological_ordering, stream_ordering):
         """
-        Purges old, stale push actions for a user and room before a given
-        topological_ordering
+        Purges old push actions for a user and room before a given
+        topological_ordering.
+
+        We however keep a months worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
         Args:
             txn: The transcation
             room_id: Room ID to delete from
@@ -475,10 +540,16 @@ class EventPushActionsStore(SQLBaseStore):
         txn.execute(
             "DELETE FROM event_push_actions "
             " WHERE user_id = ? AND room_id = ? AND "
-            " topological_ordering < ? AND stream_ordering < ?",
+            " topological_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
             (user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
         )
 
+        txn.execute("""
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """, (room_id, user_id, stream_ordering))
+
     @defer.inlineCallbacks
     def _find_stream_orderings_for_times(self):
         yield self.runInteraction(
@@ -495,6 +566,14 @@ class EventPushActionsStore(SQLBaseStore):
             "Found stream ordering 1 month ago: it's %d",
             self.stream_ordering_month_ago
         )
+        logger.info("Searching for stream ordering 1 day ago")
+        self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
+            txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
+        )
+        logger.info(
+            "Found stream ordering 1 day ago: it's %d",
+            self.stream_ordering_day_ago
+        )
 
     def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
         """
@@ -534,6 +613,131 @@ class EventPushActionsStore(SQLBaseStore):
 
         return range_end
 
+    @defer.inlineCallbacks
+    def _rotate_notifs(self):
+        if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
+            return
+        self._doing_notif_rotation = True
+
+        try:
+            while True:
+                logger.info("Rotating notifications")
+
+                caught_up = yield self.runInteraction(
+                    "_rotate_notifs",
+                    self._rotate_notifs_txn
+                )
+                if caught_up:
+                    break
+                yield sleep(5)
+        finally:
+            self._doing_notif_rotation = False
+
+    def _rotate_notifs_txn(self, txn):
+        """Archives older notifications into event_push_summary. Returns whether
+        the archiving process has caught up or not.
+        """
+
+        old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
+        # We don't to try and rotate millions of rows at once, so we cap the
+        # maximum stream ordering we'll rotate before.
+        txn.execute("""
+            SELECT stream_ordering FROM event_push_actions
+            WHERE stream_ordering > ?
+            ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
+        """, (old_rotate_stream_ordering,))
+        stream_row = txn.fetchone()
+        if stream_row:
+            offset_stream_ordering, = stream_row
+            rotate_to_stream_ordering = min(
+                self.stream_ordering_day_ago, offset_stream_ordering
+            )
+            caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
+        else:
+            rotate_to_stream_ordering = self.stream_ordering_day_ago
+            caught_up = True
+
+        logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
+
+        self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+
+        # We have caught up iff we were limited by `stream_ordering_day_ago`
+        return caught_up
+
+    def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+        old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
+        # Calculate the new counts that should be upserted into event_push_summary
+        sql = """
+            SELECT user_id, room_id,
+                coalesce(old.notif_count, 0) + upd.notif_count,
+                upd.stream_ordering,
+                old.user_id
+            FROM (
+                SELECT user_id, room_id, count(*) as notif_count,
+                    max(stream_ordering) as stream_ordering
+                FROM event_push_actions
+                WHERE ? <= stream_ordering AND stream_ordering < ?
+                    AND highlight = 0
+                GROUP BY user_id, room_id
+            ) AS upd
+            LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+        """
+
+        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
+        rows = txn.fetchall()
+
+        logger.info("Rotating notifications, handling %d rows", len(rows))
+
+        # If the `old.user_id` above is NULL then we know there isn't already an
+        # entry in the table, so we simply insert it. Otherwise we update the
+        # existing table.
+        self._simple_insert_many_txn(
+            txn,
+            table="event_push_summary",
+            values=[
+                {
+                    "user_id": row[0],
+                    "room_id": row[1],
+                    "notif_count": row[2],
+                    "stream_ordering": row[3],
+                }
+                for row in rows if row[4] is None
+            ]
+        )
+
+        txn.executemany(
+            """
+                UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+                WHERE user_id = ? AND room_id = ?
+            """,
+            ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
+        )
+
+        txn.execute(
+            "DELETE FROM event_push_actions"
+            " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
+            (old_rotate_stream_ordering, rotate_to_stream_ordering,)
+        )
+
+        logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
+
+        txn.execute(
+            "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
+            (rotate_to_stream_ordering,)
+        )
+
 
 def _action_has_highlight(actions):
     for action in actions:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6160949f32..3f6833fad2 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -28,19 +28,22 @@ from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.state import resolve_events
+from synapse.util.caches.descriptors import cached
 
 from canonicaljson import encode_canonical_json
 from collections import deque, namedtuple, OrderedDict
 from functools import wraps
 
-import synapse
 import synapse.metrics
 
-
 import logging
 import math
 import ujson as json
 
+# these are only included to make the type annotations work
+from synapse.events import EventBase    # noqa: F401
+from synapse.events.snapshot import EventContext   # noqa: F401
+
 logger = logging.getLogger(__name__)
 
 
@@ -81,6 +84,11 @@ class _EventPeristenceQueue(object):
 
     def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
+
+        Args:
+            room_id (str):
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
@@ -209,14 +217,14 @@ class EventsStore(SQLBaseStore):
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
         deferreds = []
-        for room_id, evs_ctxs in partitioned.items():
+        for room_id, evs_ctxs in partitioned.iteritems():
             d = preserve_fn(self._event_persist_queue.add_to_queue)(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
             )
             deferreds.append(d)
 
-        for room_id in partitioned.keys():
+        for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
         return preserve_context_over_deferred(
@@ -226,6 +234,17 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, context, backfilled=False):
+        """
+
+        Args:
+            event (EventBase):
+            context (EventContext):
+            backfilled (bool):
+
+        Returns:
+            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            and the stream ordering of the latest persisted event
+        """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)],
             backfilled=backfilled,
@@ -252,6 +271,16 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     def _persist_events(self, events_and_contexts, backfilled=False,
                         delete_existing=False):
+        """Persist events to db
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+            delete_existing (bool):
+
+        Returns:
+            Deferred: resolves when the events have been persisted
+        """
         if not events_and_contexts:
             return
 
@@ -284,71 +313,52 @@ class EventsStore(SQLBaseStore):
                 new_forward_extremeties = {}
                 current_state_for_room = {}
                 if not backfilled:
-                    # Work out the new "current state" for each room.
-                    # We do this by working out what the new extremities are and then
-                    # calculating the state from that.
-                    events_by_room = {}
-                    for event, context in chunk:
-                        events_by_room.setdefault(event.room_id, []).append(
-                            (event, context)
-                        )
-
-                    for room_id, ev_ctx_rm in events_by_room.items():
-                        # Work out new extremities by recursively adding and removing
-                        # the new events.
-                        latest_event_ids = yield self.get_latest_event_ids_in_room(
-                            room_id
-                        )
-                        new_latest_event_ids = yield self._calculate_new_extremeties(
-                            room_id, [ev for ev, _ in ev_ctx_rm]
-                        )
-
-                        if new_latest_event_ids == set(latest_event_ids):
-                            # No change in extremities, so no change in state
-                            continue
+                    with Measure(self._clock, "_calculate_state_and_extrem"):
+                        # Work out the new "current state" for each room.
+                        # We do this by working out what the new extremities are and then
+                        # calculating the state from that.
+                        events_by_room = {}
+                        for event, context in chunk:
+                            events_by_room.setdefault(event.room_id, []).append(
+                                (event, context)
+                            )
 
-                        new_forward_extremeties[room_id] = new_latest_event_ids
-
-                        # Now we need to work out the different state sets for
-                        # each state extremities
-                        state_sets = []
-                        missing_event_ids = []
-                        was_updated = False
-                        for event_id in new_latest_event_ids:
-                            # First search in the list of new events we're adding,
-                            # and then use the current state from that
-                            for ev, ctx in ev_ctx_rm:
-                                if event_id == ev.event_id:
-                                    if ctx.current_state_ids is None:
-                                        raise Exception("Unknown current state")
-                                    state_sets.append(ctx.current_state_ids)
-                                    if ctx.delta_ids or hasattr(ev, "state_key"):
-                                        was_updated = True
-                                    break
-                            else:
-                                # If we couldn't find it, then we'll need to pull
-                                # the state from the database
-                                was_updated = True
-                                missing_event_ids.append(event_id)
-
-                        if missing_event_ids:
-                            # Now pull out the state for any missing events from DB
-                            event_to_groups = yield self._get_state_group_for_events(
-                                missing_event_ids,
+                        for room_id, ev_ctx_rm in events_by_room.iteritems():
+                            # Work out new extremities by recursively adding and removing
+                            # the new events.
+                            latest_event_ids = yield self.get_latest_event_ids_in_room(
+                                room_id
+                            )
+                            new_latest_event_ids = yield self._calculate_new_extremeties(
+                                room_id, ev_ctx_rm, latest_event_ids
                             )
 
-                            groups = set(event_to_groups.values())
-                            group_to_state = yield self._get_state_for_groups(groups)
+                            if new_latest_event_ids == set(latest_event_ids):
+                                # No change in extremities, so no change in state
+                                continue
 
-                            state_sets.extend(group_to_state.values())
+                            new_forward_extremeties[room_id] = new_latest_event_ids
 
-                        if not new_latest_event_ids or was_updated:
-                            current_state_for_room[room_id] = yield resolve_events(
-                                state_sets,
-                                state_map_factory=lambda ev_ids: self.get_events(
-                                    ev_ids, get_prev_content=False, check_redacted=False,
-                                ),
+                            len_1 = (
+                                len(latest_event_ids) == 1
+                                and len(new_latest_event_ids) == 1
                             )
+                            if len_1:
+                                all_single_prev_not_state = all(
+                                    len(event.prev_events) == 1
+                                    and not event.is_state()
+                                    for event, ctx in ev_ctx_rm
+                                )
+                                # Don't bother calculating state if they're just
+                                # a long chain of single ancestor non-state events.
+                                if all_single_prev_not_state:
+                                    continue
+
+                            state = yield self._calculate_state_delta(
+                                room_id, ev_ctx_rm, new_latest_event_ids
+                            )
+                            if state:
+                                current_state_for_room[room_id] = state
 
                 yield self.runInteraction(
                     "persist_events",
@@ -362,27 +372,24 @@ class EventsStore(SQLBaseStore):
                 persist_event_counter.inc_by(len(chunk))
 
     @defer.inlineCallbacks
-    def _calculate_new_extremeties(self, room_id, events):
+    def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
         """Calculates the new forward extremeties for a room given events to
         persist.
 
         Assumes that we are only persisting events for one room at a time.
         """
-        latest_event_ids = yield self.get_latest_event_ids_in_room(
-            room_id
-        )
         new_latest_event_ids = set(latest_event_ids)
         # First, add all the new events to the list
         new_latest_event_ids.update(
-            event.event_id for event in events
-            if not event.internal_metadata.is_outlier()
+            event.event_id for event, ctx in event_contexts
+            if not event.internal_metadata.is_outlier() and not ctx.rejected
         )
         # Now remove all events that are referenced by the to-be-added events
         new_latest_event_ids.difference_update(
             e_id
-            for event in events
+            for event, ctx in event_contexts
             for e_id, _ in event.prev_events
-            if not event.internal_metadata.is_outlier()
+            if not event.internal_metadata.is_outlier() and not ctx.rejected
         )
 
         # And finally remove any events that are referenced by previously added
@@ -406,6 +413,125 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            2-tuple (to_delete, to_insert) where both are state dicts, i.e.
+            (type, state_key) -> event_id. `to_delete` are the entries to
+            first be deleted from current_state_events, `to_insert` are entries
+            to insert.
+            May return None if there are no changes to be applied.
+        """
+        # Now we need to work out the different state sets for
+        # each state extremities
+        state_sets = []
+        state_groups = set()
+        missing_event_ids = []
+        was_updated = False
+        for event_id in new_latest_event_ids:
+            # First search in the list of new events we're adding,
+            # and then use the current state from that
+            for ev, ctx in events_context:
+                if event_id == ev.event_id:
+                    if ctx.current_state_ids is None:
+                        raise Exception("Unknown current state")
+
+                    # If we've already seen the state group don't bother adding
+                    # it to the state sets again
+                    if ctx.state_group not in state_groups:
+                        state_sets.append(ctx.current_state_ids)
+                        if ctx.delta_ids or hasattr(ev, "state_key"):
+                            was_updated = True
+                        if ctx.state_group:
+                            # Add this as a seen state group (if it has a state
+                            # group)
+                            state_groups.add(ctx.state_group)
+                    break
+            else:
+                # If we couldn't find it, then we'll need to pull
+                # the state from the database
+                was_updated = True
+                missing_event_ids.append(event_id)
+
+        if missing_event_ids:
+            # Now pull out the state for any missing events from DB
+            event_to_groups = yield self._get_state_group_for_events(
+                missing_event_ids,
+            )
+
+            groups = set(event_to_groups.itervalues()) - state_groups
+
+            if groups:
+                group_to_state = yield self._get_state_for_groups(groups)
+                state_sets.extend(group_to_state.itervalues())
+
+        if not new_latest_event_ids:
+            current_state = {}
+        elif was_updated:
+            if len(state_sets) == 1:
+                # If there is only one state set, then we know what the current
+                # state is.
+                current_state = state_sets[0]
+            else:
+                # We work out the current state by passing the state sets to the
+                # state resolution algorithm. It may ask for some events, including
+                # the events we have yet to persist, so we need a slightly more
+                # complicated event lookup function than simply looking the events
+                # up in the db.
+                events_map = {ev.event_id: ev for ev, _ in events_context}
+
+                @defer.inlineCallbacks
+                def get_events(ev_ids):
+                    # We get the events by first looking at the list of events we
+                    # are trying to persist, and then fetching the rest from the DB.
+                    db = []
+                    to_return = {}
+                    for ev_id in ev_ids:
+                        ev = events_map.get(ev_id, None)
+                        if ev:
+                            to_return[ev_id] = ev
+                        else:
+                            db.append(ev_id)
+
+                    if db:
+                        evs = yield self.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
+                        )
+                        to_return.update(evs)
+                    defer.returnValue(to_return)
+
+                current_state = yield resolve_events(
+                    state_sets,
+                    state_map_factory=get_events,
+                )
+        else:
+            return
+
+        existing_state = yield self.get_current_state_ids(room_id)
+
+        existing_events = set(existing_state.itervalues())
+        new_events = set(ev_id for ev_id in current_state.itervalues())
+        changed_events = existing_events ^ new_events
+
+        if not changed_events:
+            return
+
+        to_delete = {
+            key: ev_id for key, ev_id in existing_state.iteritems()
+            if ev_id in changed_events
+        }
+        events_to_insert = (new_events - existing_events)
+        to_insert = {
+            key: ev_id for key, ev_id in current_state.iteritems()
+            if ev_id in events_to_insert
+        }
+
+        defer.returnValue((to_delete, to_insert))
+
+    @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
                   get_prev_content=False, allow_rejected=False,
                   allow_none=False):
@@ -471,44 +597,143 @@ class EventsStore(SQLBaseStore):
         and the rejections table. Things reading from those table will need to check
         whether the event was rejected.
 
-        If delete_existing is True then existing events will be purged from the
-        database before insertion. This is useful when retrying due to IntegrityError.
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]):
+                events to persist
+            backfilled (bool): True if the events were backfilled
+            delete_existing (bool): True to purge existing table rows for the
+                events from the database. This is useful when retrying due to
+                IntegrityError.
+            current_state_for_room (dict[str, (list[str], list[str])]):
+                The current-state delta for each room. For each room, a tuple
+                (to_delete, to_insert), being a list of event ids to be removed
+                from the current state, and a list of event ids to be added to
+                the current state.
+            new_forward_extremeties (dict[str, list[str]]):
+                The new forward extremities for each room. For each room, a
+                list of the event ids which are the forward extremities.
+
         """
+        self._update_current_state_txn(txn, current_state_for_room)
+
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
-        for room_id, current_state in current_state_for_room.iteritems():
-            txn.call_after(self._get_current_state_for_key.invalidate_all)
-            txn.call_after(self.get_rooms_for_user.invalidate_all)
-            txn.call_after(self.get_users_in_room.invalidate, (room_id,))
-
-            # Add an entry to the current_state_resets table to record the point
-            # where we clobbered the current state
-            self._simple_insert_txn(
-                txn,
-                table="current_state_resets",
-                values={"event_stream_ordering": max_stream_order}
-            )
+        self._update_forward_extremities_txn(
+            txn,
+            new_forward_extremities=new_forward_extremeties,
+            max_stream_order=max_stream_order,
+        )
 
-            self._simple_delete_txn(
-                txn,
-                table="current_state_events",
-                keyvalues={"room_id": room_id},
-            )
+        # Ensure that we don't have the same event twice.
+        events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+            events_and_contexts,
+        )
 
-            self._simple_insert_many_txn(
+        self._update_room_depths_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # _update_outliers_txn filters out any events which have already been
+        # persisted, and returns the filtered list.
+        events_and_contexts = self._update_outliers_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # From this point onwards the events are only events that we haven't
+        # seen before.
+
+        if delete_existing:
+            # For paranoia reasons, we go and delete all the existing entries
+            # for these events so we can reinsert them.
+            # This gets around any problems with some tables already having
+            # entries.
+            self._delete_existing_rows_txn(
                 txn,
-                table="current_state_events",
-                values=[
-                    {
-                        "event_id": ev_id,
-                        "room_id": room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                    }
-                    for key, ev_id in current_state.iteritems()
-                ],
+                events_and_contexts=events_and_contexts,
             )
 
-        for room_id, new_extrem in new_forward_extremeties.items():
+        self._store_event_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # Insert into the state_groups, state_groups_state, and
+        # event_to_state_groups tables.
+        self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
+
+    def _update_current_state_txn(self, txn, state_delta_by_room):
+        for room_id, current_state_tuple in state_delta_by_room.iteritems():
+                to_delete, to_insert = current_state_tuple
+                txn.executemany(
+                    "DELETE FROM current_state_events WHERE event_id = ?",
+                    [(ev_id,) for ev_id in to_delete.itervalues()],
+                )
+
+                self._simple_insert_many_txn(
+                    txn,
+                    table="current_state_events",
+                    values=[
+                        {
+                            "event_id": ev_id,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                        }
+                        for key, ev_id in to_insert.iteritems()
+                    ],
+                )
+
+                # Invalidate the various caches
+
+                # Figure out the changes of membership to invalidate the
+                # `get_rooms_for_user` cache.
+                # We find out which membership events we may have deleted
+                # and which we have added, then we invlidate the caches for all
+                # those users.
+                members_changed = set(
+                    state_key for ev_type, state_key in to_delete.iterkeys()
+                    if ev_type == EventTypes.Member
+                )
+                members_changed.update(
+                    state_key for ev_type, state_key in to_insert.iterkeys()
+                    if ev_type == EventTypes.Member
+                )
+
+                for member in members_changed:
+                    self._invalidate_cache_and_stream(
+                        txn, self.get_rooms_for_user, (member,)
+                    )
+
+                self._invalidate_cache_and_stream(
+                    txn, self.get_users_in_room, (room_id,)
+                )
+
+                self._invalidate_cache_and_stream(
+                    txn, self.get_current_state_ids, (room_id,)
+                )
+
+    def _update_forward_extremities_txn(self, txn, new_forward_extremities,
+                                        max_stream_order):
+        for room_id, new_extrem in new_forward_extremities.iteritems():
             self._simple_delete_txn(
                 txn,
                 table="event_forward_extremities",
@@ -526,7 +751,7 @@ class EventsStore(SQLBaseStore):
                     "event_id": ev_id,
                     "room_id": room_id,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for ev_id in new_extrem
             ],
         )
@@ -543,13 +768,22 @@ class EventsStore(SQLBaseStore):
                     "event_id": event_id,
                     "stream_ordering": max_stream_order,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for event_id in new_extrem
             ]
         )
 
-        # Ensure that we don't have the same event twice.
-        # Pick the earliest non-outlier if there is one, else the earliest one.
+    @classmethod
+    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+        """Ensure that we don't have the same event twice.
+
+        Pick the earliest non-outlier if there is one, else the earliest one.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+        Returns:
+            list[(EventBase, EventContext)]: filtered list
+        """
         new_events_and_contexts = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
@@ -562,9 +796,17 @@ class EventsStore(SQLBaseStore):
                         new_events_and_contexts[event.event_id] = (event, context)
             else:
                 new_events_and_contexts[event.event_id] = (event, context)
+        return new_events_and_contexts.values()
 
-        events_and_contexts = new_events_and_contexts.values()
+    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+        """Update min_depth for each room
 
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
@@ -580,9 +822,24 @@ class EventsStore(SQLBaseStore):
                     event.depth, depth_updates.get(event.room_id, event.depth)
                 )
 
-        for room_id, depth in depth_updates.items():
+        for room_id, depth in depth_updates.iteritems():
             self._update_min_depth_for_room_txn(txn, room_id, depth)
 
+    def _update_outliers_txn(self, txn, events_and_contexts):
+        """Update any outliers with new event info.
+
+        This turns outliers into ex-outliers (unless the new event was
+        rejected).
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without events which
+            are already in the events table.
+        """
         txn.execute(
             "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
                 ",".join(["?"] * len(events_and_contexts)),
@@ -592,24 +849,21 @@ class EventsStore(SQLBaseStore):
 
         have_persisted = {
             event_id: outlier
-            for event_id, outlier in txn.fetchall()
+            for event_id, outlier in txn
         }
 
         to_remove = set()
         for event, context in events_and_contexts:
-            if context.rejected:
-                # If the event is rejected then we don't care if the event
-                # was an outlier or not.
-                if event.event_id in have_persisted:
-                    # If we have already seen the event then ignore it.
-                    to_remove.add(event)
-                continue
-
             if event.event_id not in have_persisted:
                 continue
 
             to_remove.add(event)
 
+            if context.rejected:
+                # If the event is rejected then we don't care if the event
+                # was an outlier or not.
+                continue
+
             outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
@@ -664,37 +918,19 @@ class EventsStore(SQLBaseStore):
                 # event isn't an outlier any more.
                 self._update_backward_extremeties(txn, [event])
 
-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
+    @classmethod
+    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return
 
-        # From this point onwards the events are only events that we haven't
-        # seen before.
+        logger.info("Deleting existing")
 
-        def event_dict(event):
-            return {
-                k: v
-                for k, v in event.get_dict().items()
-                if k not in [
-                    "redacted",
-                    "redacted_because",
-                ]
-            }
-
-        if delete_existing:
-            # For paranoia reasons, we go and delete all the existing entries
-            # for these events so we can reinsert them.
-            # This gets around any problems with some tables already having
-            # entries.
-
-            logger.info("Deleting existing")
-
-            for table in (
+        for table in (
                 "events",
                 "event_auth",
                 "event_json",
@@ -717,11 +953,30 @@ class EventsStore(SQLBaseStore):
                 "redactions",
                 "room_memberships",
                 "topics"
-            ):
-                txn.executemany(
-                    "DELETE FROM %s WHERE event_id = ?" % (table,),
-                    [(ev.event_id,) for ev, _ in events_and_contexts]
-                )
+        ):
+            txn.executemany(
+                "DELETE FROM %s WHERE event_id = ?" % (table,),
+                [(ev.event_id,) for ev, _ in events_and_contexts]
+            )
+
+    def _store_event_txn(self, txn, events_and_contexts):
+        """Insert new events into the event and event_json tables
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+        """
+
+        if not events_and_contexts:
+            # nothing to do here
+            return
+
+        def event_dict(event):
+            d = event.get_dict()
+            d.pop("redacted", None)
+            d.pop("redacted_because", None)
+            return d
 
         self._simple_insert_many_txn(
             txn,
@@ -765,6 +1020,19 @@ class EventsStore(SQLBaseStore):
             ],
         )
 
+    def _store_rejected_events_txn(self, txn, events_and_contexts):
+        """Add rows to the 'rejections' table for received events which were
+        rejected
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without the rejected
+                events.
+        """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
         to_remove = set()
@@ -776,17 +1044,24 @@ class EventsStore(SQLBaseStore):
                 )
                 to_remove.add(event)
 
-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
+    def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+        """Update all the miscellaneous tables for new events
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
+
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return
 
-        # From this point onwards the events are only ones that weren't rejected.
-
         for event, context in events_and_contexts:
             # Insert all the push actions into the event_push_actions table.
             if context.push_actions:
@@ -815,10 +1090,6 @@ class EventsStore(SQLBaseStore):
             ],
         )
 
-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
-
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
@@ -905,13 +1176,6 @@ class EventsStore(SQLBaseStore):
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-        if backfilled:
-            # Backfilled events come before the current state so we don't need
-            # to update the current state table
-            return
-
-        return
-
     def _add_to_cache(self, txn, events_and_contexts):
         to_prefill = []
 
@@ -1507,6 +1771,7 @@ class EventsStore(SQLBaseStore):
         """The current minimum token that backfilled events have reached"""
         return -self._backfill_id_gen.get_current_token()
 
+    @cached(num_args=5, max_entries=10)
     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):
         """Get all the new events that have arrived at the server either as
@@ -1519,14 +1784,13 @@ class EventsStore(SQLBaseStore):
 
         def get_all_new_events_txn(txn):
             sql = (
-                "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
             if have_forward_events:
@@ -1539,15 +1803,6 @@ class EventsStore(SQLBaseStore):
                     upper_bound = current_forward_id
 
                 sql = (
-                    "SELECT event_stream_ordering FROM current_state_resets"
-                    " WHERE ? < event_stream_ordering"
-                    " AND event_stream_ordering <= ?"
-                    " ORDER BY event_stream_ordering ASC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                state_resets = txn.fetchall()
-
-                sql = (
                     "SELECT event_stream_ordering, event_id, state_group"
                     " FROM ex_outlier_stream"
                     " WHERE ? > event_stream_ordering"
@@ -1558,19 +1813,16 @@ class EventsStore(SQLBaseStore):
                 forward_ex_outliers = txn.fetchall()
             else:
                 new_forward_events = []
-                state_resets = []
                 forward_ex_outliers = []
 
             sql = (
-                "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
-                " eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
-                " ORDER BY e.stream_ordering DESC"
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering DESC"
                 " LIMIT ?"
             )
             if have_backfill_events:
@@ -1598,7 +1850,6 @@ class EventsStore(SQLBaseStore):
             return AllNewEventsResult(
                 new_forward_events, new_backfill_events,
                 forward_ex_outliers, backward_ex_outliers,
-                state_resets,
             )
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
@@ -1758,7 +2009,7 @@ class EventsStore(SQLBaseStore):
                         "state_key": key[1],
                         "event_id": state_id,
                     }
-                    for key, state_id in curr_state.items()
+                    for key, state_id in curr_state.iteritems()
                 ],
             )
 
@@ -1824,5 +2075,4 @@ class EventsStore(SQLBaseStore):
 AllNewEventsResult = namedtuple("AllNewEventsResult", [
     "new_forward_events", "new_backfill_events",
     "forward_ex_outliers", "backward_ex_outliers",
-    "state_resets"
 ])
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 86b37b9ddd..3b5e0a4fb9 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
         key_ids
         Args:
             server_name (str): The name of the server.
-            key_ids (list of str): List of key_ids to try and look up.
+            key_ids (iterable[str]): key_ids to try and look up.
         Returns:
-            (list of VerifyKey): The verification keys.
+            Deferred: resolves to dict[str, VerifyKey]: map from
+               key_id to verification key.
         """
         keys = {}
         for key_id in key_ids:
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b357f22be7..6e623843d5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 40
+SCHEMA_VERSION = 41
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
             ),
             (current_version,)
         )
-        applied_deltas = [d for d, in txn.fetchall()]
+        applied_deltas = [d for d, in txn]
         return current_version, applied_deltas, upgraded
 
     return None
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 7460f98a1f..9e9d3c2591 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -15,7 +15,7 @@
 
 from ._base import SQLBaseStore
 from synapse.api.constants import PresenceState
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
 from collections import namedtuple
 from twisted.internet import defer
@@ -85,6 +85,9 @@ class PresenceStore(SQLBaseStore):
                 self.presence_stream_cache.entity_has_changed,
                 state.user_id, stream_id,
             )
+            txn.call_after(
+                self._get_presence_for_user.invalidate, (state.user_id,)
+            )
 
         # Actually insert new rows
         self._simple_insert_many_txn(
@@ -143,7 +146,12 @@ class PresenceStore(SQLBaseStore):
             "get_all_presence_updates", get_all_presence_updates_txn
         )
 
-    @defer.inlineCallbacks
+    @cached()
+    def _get_presence_for_user(self, user_id):
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids",
+                num_args=1, inlineCallbacks=True)
     def get_presence_for_users(self, user_ids):
         rows = yield self._simple_select_many_batch(
             table="presence_stream",
@@ -165,7 +173,7 @@ class PresenceStore(SQLBaseStore):
         for row in rows:
             row["currently_active"] = bool(row["currently_active"])
 
-        defer.returnValue([UserPresenceState(**row) for row in rows])
+        defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows})
 
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f72d15f5ed..6b0f8c2787 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
         )
 
         txn.execute(sql, (room_id, receipt_type, user_id))
-        results = txn.fetchall()
 
-        if results and topological_ordering:
-            for to, so, _ in results:
+        if topological_ordering:
+            for to, so, _ in txn:
                 if int(to) > topological_ordering:
                     return False
                 elif int(to) == topological_ordering and int(so) >= stream_ordering:
@@ -351,6 +350,7 @@ class ReceiptsStore(SQLBaseStore):
                 room_id=room_id,
                 user_id=user_id,
                 topological_ordering=topological_ordering,
+                stream_ordering=stream_ordering,
             )
 
         return True
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26be6060c3..ec2c52ab93 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 " WHERE lower(name) = lower(?)"
             )
             txn.execute(sql, (user_id,))
-            return dict(txn.fetchall())
+            return dict(txn)
 
         return self.runInteraction("get_users_by_id_case_insensitive", f)
 
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8a2fe2fdf5..e4c56cc175 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
                     sql % ("AND appservice_id IS NULL",),
                     (stream_id,)
                 )
-            return dict(txn.fetchall())
+            return dict(txn)
         else:
             # We want to get from all lists, so we need to aggregate the results
 
@@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
 
             results = {}
             # A room is visible if its visible on any list.
-            for room_id, visibility in txn.fetchall():
+            for room_id, visibility in txn:
                 results[room_id] = bool(visibility) or results.get(room_id, False)
 
             return results
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 768e0a4451..367dbbbcf6 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore):
         )
 
         for event in events:
-            txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
-            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(
                 self._membership_stream_cache.entity_has_changed,
                 event.state_key, event.internal_metadata.stream_ordering
@@ -131,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.runInteraction("locally_reject_invite", f, stream_ordering)
 
-    @cached(max_entries=5000)
+    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+    def get_hosts_in_room(self, room_id, cache_context):
+        """Returns the set of all hosts currently in the room
+        """
+        user_ids = yield self.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate,
+        )
+        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+        defer.returnValue(hosts)
+
+    @cached(max_entries=500000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
-
-            rows = self._get_members_rows_txn(
-                txn,
-                room_id=room_id,
-                membership=Membership.JOIN,
+            sql = (
+                "SELECT m.user_id FROM room_memberships as m"
+                " INNER JOIN current_state_events as c"
+                " ON m.event_id = c.event_id "
+                " AND m.room_id = c.room_id "
+                " AND m.user_id = c.state_key"
+                " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
             )
 
-            return [r["user_id"] for r in rows]
+            txn.execute(sql, (room_id, Membership.JOIN,))
+            return [r[0] for r in txn]
         return self.runInteraction("get_users_in_room", f)
 
     @cached()
@@ -220,7 +231,7 @@ class RoomMemberStore(SQLBaseStore):
                 " ON e.event_id = c.event_id"
                 " AND m.room_id = c.room_id"
                 " AND m.user_id = c.state_key"
-                " WHERE %s"
+                " WHERE c.type = 'm.room.member' AND %s"
             ) % (where_clause,)
 
             txn.execute(sql, args)
@@ -248,39 +259,31 @@ class RoomMemberStore(SQLBaseStore):
 
         return results
 
-    def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
-        where_clause = "c.room_id = ?"
-        where_values = [room_id]
-
-        if membership:
-            where_clause += " AND m.membership = ?"
-            where_values.append(membership)
-
-        if user_id:
-            where_clause += " AND m.user_id = ?"
-            where_values.append(user_id)
-
-        sql = (
-            "SELECT m.* FROM room_memberships as m"
-            " INNER JOIN current_state_events as c"
-            " ON m.event_id = c.event_id "
-            " AND m.room_id = c.room_id "
-            " AND m.user_id = c.state_key"
-            " WHERE %(where)s"
-        ) % {
-            "where": where_clause,
-        }
-
-        txn.execute(sql, where_values)
-        rows = self.cursor_to_dict(txn)
-
-        return rows
-
-    @cached(max_entries=5000)
+    @cachedInlineCallbacks(max_entries=500000, iterable=True)
     def get_rooms_for_user(self, user_id):
-        return self.get_rooms_for_user_where_membership_is(
+        """Returns a set of room_ids the user is currently joined to
+        """
+        rooms = yield self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
+        defer.returnValue(frozenset(r.room_id for r in rooms))
+
+    @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
+    def get_users_who_share_room_with_user(self, user_id, cache_context):
+        """Returns the set of users who share a room with `user_id`
+        """
+        room_ids = yield self.get_rooms_for_user(
+            user_id, on_invalidate=cache_context.invalidate,
+        )
+
+        user_who_share_room = set()
+        for room_id in room_ids:
+            user_ids = yield self.get_users_in_room(
+                room_id, on_invalidate=cache_context.invalidate,
+            )
+            user_who_share_room.update(user_ids)
+
+        defer.returnValue(user_who_share_room)
 
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/schema/delta/40/current_state_idx.sql
new file mode 100644
index 0000000000..7ffa189f39
--- /dev/null
+++ b/synapse/storage/schema/delta/40/current_state_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('current_state_members_idx', '{}');
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql
new file mode 100644
index 0000000000..54841b3843
--- /dev/null
+++ b/synapse/storage/schema/delta/40/device_list_streams.sql
@@ -0,0 +1,59 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Cache of remote devices.
+CREATE TABLE device_lists_remote_cache (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    content TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
+
+
+-- The last update we got for a user. Empty if we're not receiving updates for
+-- that user.
+CREATE TABLE device_lists_remote_extremeties (
+    user_id TEXT NOT NULL,
+    stream_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
+
+
+-- Stream of device lists updates. Includes both local and remotes
+CREATE TABLE device_lists_stream (
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
+
+
+-- The stream of updates to send to other servers. We keep at least one row
+-- per user that was sent so that the prev_id for any new updates can be
+-- calculated
+CREATE TABLE device_lists_outbound_pokes (
+    destination TEXT NOT NULL,
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    sent BOOLEAN NOT NULL,
+    ts BIGINT NOT NULL  -- So that in future we can clear out pokes to dead servers
+);
+
+CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
+CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);
diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql
new file mode 100644
index 0000000000..3918f0b794
--- /dev/null
+++ b/synapse/storage/schema/delta/40/event_push_summary.sql
@@ -0,0 +1,37 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Aggregate of old notification counts that have been deleted out of the
+-- main event_push_actions table. This count does not include those that were
+-- highlights, as they remain in the event_push_actions table.
+CREATE TABLE event_push_summary (
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    notif_count BIGINT NOT NULL,
+    stream_ordering BIGINT NOT NULL
+);
+
+CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id);
+
+
+-- The stream ordering up to which we have aggregated the event_push_actions
+-- table into event_push_summary
+CREATE TABLE event_push_summary_stream_ordering (
+    Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- Makes sure this table only has one row.
+    stream_ordering BIGINT NOT NULL,
+    CHECK (Lock='X')
+);
+
+INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql
new file mode 100644
index 0000000000..054a223f14
--- /dev/null
+++ b/synapse/storage/schema/delta/40/pushers.sql
@@ -0,0 +1,39 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS pushers2 (
+    id BIGINT PRIMARY KEY,
+    user_name TEXT NOT NULL,
+    access_token BIGINT DEFAULT NULL,
+    profile_tag TEXT NOT NULL,
+    kind TEXT NOT NULL,
+    app_id TEXT NOT NULL,
+    app_display_name TEXT NOT NULL,
+    device_display_name TEXT NOT NULL,
+    pushkey TEXT NOT NULL,
+    ts BIGINT NOT NULL,
+    lang TEXT,
+    data TEXT,
+    last_stream_ordering INTEGER,
+    last_success BIGINT,
+    failing_since BIGINT,
+    UNIQUE (app_id, pushkey, user_name)
+);
+
+INSERT INTO pushers2 SELECT * FROM PUSHERS;
+
+DROP TABLE PUSHERS;
+
+ALTER TABLE pushers2 RENAME TO pushers;
diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql
new file mode 100644
index 0000000000..b7bee8b692
--- /dev/null
+++ b/synapse/storage/schema/delta/41/device_list_stream_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('device_lists_stream_idx', '{}');
diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql
new file mode 100644
index 0000000000..62f0b9892b
--- /dev/null
+++ b/synapse/storage/schema/delta/41/device_outbound_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id);
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index e1dca927d7..67d5d9969a 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
             " WHERE event_id = ?"
         )
         txn.execute(query, (event_id, ))
-        return {k: v for k, v in txn.fetchall()}
+        return {k: v for k, v in txn}
 
     def _store_event_reference_hashes_txn(self, txn, events):
         """Store a hash for a PDU
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7d34dd03bf..314216f039 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 from synapse.util.caches import intern_string
 from synapse.storage.engines import PostgresEngine
 
@@ -49,6 +49,7 @@ class StateStore(SQLBaseStore):
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
     def __init__(self, hs):
         super(StateStore, self).__init__(hs)
@@ -60,6 +61,25 @@ class StateStore(SQLBaseStore):
             self.STATE_GROUP_INDEX_UPDATE_NAME,
             self._background_index_state,
         )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
+
+    @cachedInlineCallbacks(max_entries=100000, iterable=True)
+    def get_current_state_ids(self, room_id):
+        rows = yield self._simple_select_list(
+            table="current_state_events",
+            keyvalues={"room_id": room_id},
+            retcols=["event_id", "type", "state_key"],
+            desc="_calculate_state_delta",
+        )
+        defer.returnValue({
+            (r["type"], r["state_key"]): r["event_id"] for r in rows
+        })
 
     @defer.inlineCallbacks
     def get_state_groups_ids(self, room_id, event_ids):
@@ -70,7 +90,7 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups)
 
         defer.returnValue(group_to_state)
@@ -88,17 +108,18 @@ class StateStore(SQLBaseStore):
 
         state_event_map = yield self.get_events(
             [
-                ev_id for group_ids in group_to_ids.values()
-                for ev_id in group_ids.values()
+                ev_id for group_ids in group_to_ids.itervalues()
+                for ev_id in group_ids.itervalues()
             ],
             get_prev_content=False
         )
 
         defer.returnValue({
             group: [
-                state_event_map[v] for v in event_id_map.values() if v in state_event_map
+                state_event_map[v] for v in event_id_map.itervalues()
+                if v in state_event_map
             ]
-            for group, event_id_map in group_to_ids.items()
+            for group, event_id_map in group_to_ids.iteritems()
         })
 
     def _have_persisted_state_group_txn(self, txn, state_group):
@@ -116,6 +137,16 @@ class StateStore(SQLBaseStore):
                 continue
 
             if context.current_state_ids is None:
+                # AFAIK, this can never happen
+                logger.error(
+                    "Non-outlier event %s had current_state_ids==None",
+                    event.event_id)
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
                 continue
 
             state_groups[event.event_id] = context.state_group
@@ -160,7 +191,7 @@ class StateStore(SQLBaseStore):
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.delta_ids.items()
+                        for key, state_id in context.delta_ids.iteritems()
                     ],
                 )
             else:
@@ -175,7 +206,7 @@ class StateStore(SQLBaseStore):
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.current_state_ids.items()
+                        for key, state_id in context.current_state_ids.iteritems()
                     ],
                 )
 
@@ -187,7 +218,7 @@ class StateStore(SQLBaseStore):
                     "state_group": state_group_id,
                     "event_id": event_id,
                 }
-                for event_id, state_group_id in state_groups.items()
+                for event_id, state_group_id in state_groups.iteritems()
             ],
         )
 
@@ -232,58 +263,6 @@ class StateStore(SQLBaseStore):
 
             return count
 
-    @defer.inlineCallbacks
-    def get_current_state(self, room_id, event_type=None, state_key=""):
-        if event_type and state_key is not None:
-            result = yield self.get_current_state_for_key(
-                room_id, event_type, state_key
-            )
-            defer.returnValue(result)
-
-        def f(txn):
-            sql = (
-                "SELECT event_id FROM current_state_events"
-                " WHERE room_id = ? "
-            )
-
-            if event_type and state_key is not None:
-                sql += " AND type = ? AND state_key = ? "
-                args = (room_id, event_type, state_key)
-            elif event_type:
-                sql += " AND type = ?"
-                args = (room_id, event_type)
-            else:
-                args = (room_id, )
-
-            txn.execute(sql, args)
-            results = txn.fetchall()
-
-            return [r[0] for r in results]
-
-        event_ids = yield self.runInteraction("get_current_state", f)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
-
-    @defer.inlineCallbacks
-    def get_current_state_for_key(self, room_id, event_type, state_key):
-        event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
-
-    @cached(num_args=3)
-    def _get_current_state_for_key(self, room_id, event_type, state_key):
-        def f(txn):
-            sql = (
-                "SELECT event_id FROM current_state_events"
-                " WHERE room_id = ? AND type = ? AND state_key = ?"
-            )
-
-            args = (room_id, event_type, state_key)
-            txn.execute(sql, args)
-            results = txn.fetchall()
-            return [r[0] for r in results]
-        return self.runInteraction("get_current_state_for_key", f)
-
     @cached(num_args=2, max_entries=100000, iterable=True)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()
@@ -363,10 +342,10 @@ class StateStore(SQLBaseStore):
                     args.extend(where_args)
 
                     txn.execute(sql % (where_clause,), args)
-                    rows = self.cursor_to_dict(txn)
-                    for row in rows:
-                        key = (row["type"], row["state_key"])
-                        results[group][key] = row["event_id"]
+                    for row in txn:
+                        typ, state_key, event_id = row
+                        key = (typ, state_key)
+                        results[group][key] = event_id
         else:
             if types is not None:
                 where_clause = "AND (%s)" % (
@@ -395,12 +374,11 @@ class StateStore(SQLBaseStore):
                         " WHERE state_group = ? %s" % (where_clause,),
                         args
                     )
-                    rows = txn.fetchall()
-                    results[group].update({
-                        (typ, state_key): event_id
-                        for typ, state_key, event_id in rows
+                    results[group].update(
+                        ((typ, state_key), event_id)
+                        for typ, state_key, event_id in txn
                         if (typ, state_key) not in results[group]
-                    })
+                    )
 
                     # If the lengths match then we must have all the types,
                     # so no need to go walk further down the tree.
@@ -437,37 +415,49 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         state_event_map = yield self.get_events(
-            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+            [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
             get_prev_content=False
         )
 
         event_to_state = {
             event_id: {
                 k: state_event_map[v]
-                for k, v in group_to_state[group].items()
+                for k, v in group_to_state[group].iteritems()
                 if v in state_event_map
             }
-            for event_id, group in event_to_groups.items()
+            for event_id, group in event_to_groups.iteritems()
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, types):
+    def get_state_ids_for_events(self, event_ids, types=None):
+        """
+        Get the state dicts corresponding to a list of events
+
+        Args:
+            event_ids(list(str)): events whose state should be returned
+            types(list[(str, str)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. May be None, which
+                matches any key
+
+        Returns:
+            A deferred dict from event_id -> (type, state_key) -> state_event
+        """
         event_to_groups = yield self._get_state_group_for_events(
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         event_to_state = {
             event_id: group_to_state[group]
-            for event_id, group in event_to_groups.items()
+            for event_id, group in event_to_groups.iteritems()
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -579,7 +569,7 @@ class StateStore(SQLBaseStore):
         got_all = not (missing_types or types is None)
 
         return {
-            k: v for k, v in state_dict_ids.items()
+            k: v for k, v in state_dict_ids.iteritems()
             if include(k[0], k[1])
         }, missing_types, got_all
 
@@ -638,7 +628,7 @@ class StateStore(SQLBaseStore):
 
             # Now we want to update the cache with all the things we fetched
             # from the database.
-            for group, group_state_dict in group_to_state_dict.items():
+            for group, group_state_dict in group_to_state_dict.iteritems():
                 if types:
                     # We delibrately put key -> None mappings into the cache to
                     # cache absence of the key, on the assumption that if we've
@@ -653,10 +643,10 @@ class StateStore(SQLBaseStore):
                 else:
                     state_dict = results[group]
 
-                state_dict.update({
-                    (intern_string(k[0]), intern_string(k[1])): v
-                    for k, v in group_state_dict.items()
-                })
+                state_dict.update(
+                    ((intern_string(k[0]), intern_string(k[1])), v)
+                    for k, v in group_state_dict.iteritems()
+                )
 
                 self._state_group_cache.update(
                     cache_seq_num,
@@ -667,10 +657,10 @@ class StateStore(SQLBaseStore):
 
         # Remove all the entries with None values. The None values were just
         # used for bookkeeping in the cache.
-        for group, state_dict in results.items():
+        for group, state_dict in results.iteritems():
             results[group] = {
                 key: event_id
-                for key, event_id in state_dict.items()
+                for key, event_id in state_dict.iteritems()
                 if event_id
             }
 
@@ -759,7 +749,7 @@ class StateStore(SQLBaseStore):
                         # of keys
 
                         delta_state = {
-                            key: value for key, value in curr_state.items()
+                            key: value for key, value in curr_state.iteritems()
                             if prev_state.get(key, None) != value
                         }
 
@@ -799,7 +789,7 @@ class StateStore(SQLBaseStore):
                                     "state_key": key[1],
                                     "event_id": state_id,
                                 }
-                                for key, state_id in delta_state.items()
+                                for key, state_id in delta_state.iteritems()
                             ],
                         )
 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 2dc24951c4..dddd5fc0e7 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def get_rooms_that_changed(self, room_ids, from_key):
+        """Given a list of rooms and a token, return rooms where there may have
+        been changes.
+
+        Args:
+            room_ids (list)
+            from_key (str): The room_key portion of a StreamToken
+        """
+        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        return set(
+            room_id for room_id in room_ids
+            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+        )
+
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
                                         order='DESC'):
@@ -815,3 +829,6 @@ class StreamStore(SQLBaseStore):
             updatevalues={"stream_id": stream_id},
             desc="update_federation_out_pos",
         )
+
+    def has_room_changed_since(self, room_id, stream_id):
+        return self._events_stream_cache.has_entity_changed(room_id, stream_id)
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 5a2c1aa59b..bff73f3f04 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
             for stream_id, user_id, room_id in tag_ids:
                 txn.execute(sql, (user_id, room_id))
                 tags = []
-                for tag, content in txn.fetchall():
+                for tag, content in txn:
                     tags.append(json.dumps(tag) + ":" + content)
                 tag_json = "{" + ",".join(tags) + "}"
                 results.append((stream_id, user_id, room_id, tag_json))
@@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
                 " WHERE user_id = ? AND stream_id > ?"
             )
             txn.execute(sql, (user_id, stream_id))
-            room_ids = [row[0] for row in txn.fetchall()]
+            room_ids = [row[0] for row in txn]
             return room_ids
 
         changed = self._account_data_stream_cache.has_entity_changed(
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 46cf93ff87..95031dc9ec 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -30,6 +30,17 @@ class IdGenerator(object):
 
 
 def _load_current_id(db_conn, table, column, step=1):
+    """
+
+    Args:
+        db_conn (object):
+        table (str):
+        column (str):
+        step (int):
+
+    Returns:
+        int
+    """
     cur = db_conn.cursor()
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
@@ -131,6 +142,9 @@ class StreamIdGenerator(object):
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
+
+        Returns:
+            int
         """
         with self._lock:
             if self._unfinished_ids: