summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py8
-rw-r--r--synapse/storage/background_updates.py98
-rw-r--r--synapse/storage/client_ips.py4
-rw-r--r--synapse/storage/deviceinbox.py11
-rw-r--r--synapse/storage/devices.py31
-rw-r--r--synapse/storage/end_to_end_keys.py56
-rw-r--r--synapse/storage/events.py249
-rw-r--r--synapse/storage/push_rule.py13
-rw-r--r--synapse/storage/pusher.py42
-rw-r--r--synapse/storage/receipts.py11
-rw-r--r--synapse/storage/room.py36
-rw-r--r--synapse/storage/roommember.py149
-rw-r--r--synapse/storage/schema/delta/37/remove_auth_idx.py4
-rw-r--r--synapse/storage/schema/delta/41/event_search_event_id_idx.sql17
-rw-r--r--synapse/storage/schema/delta/41/ratelimit.sql22
-rw-r--r--synapse/storage/state.py60
16 files changed, 666 insertions, 145 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c659004e8d..58b73af7d2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -60,12 +60,12 @@ class LoggingTransaction(object):
         object.__setattr__(self, "database_engine", database_engine)
         object.__setattr__(self, "after_callbacks", after_callbacks)
 
-    def call_after(self, callback, *args):
+    def call_after(self, callback, *args, **kwargs):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
         """
-        self.after_callbacks.append((callback, args))
+        self.after_callbacks.append((callback, args, kwargs))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -319,8 +319,8 @@ class SQLBaseStore(object):
                     inner_func, *args, **kwargs
                 )
         finally:
-            for after_callback, after_args in after_callbacks:
-                after_callback(*after_args)
+            for after_callback, after_args, after_kwargs in after_callbacks:
+                after_callback(*after_args, **after_kwargs)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 813ad59e56..7157fb1dfb 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -210,7 +210,9 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_handlers[update_name] = update_handler
 
     def register_background_index_update(self, update_name, index_name,
-                                         table, columns, where_clause=None):
+                                         table, columns, where_clause=None,
+                                         unique=False,
+                                         psql_only=False):
         """Helper for store classes to do a background index addition
 
         To use:
@@ -226,48 +228,80 @@ class BackgroundUpdateStore(SQLBaseStore):
             index_name (str): name of index to add
             table (str): table to add index to
             columns (list[str]): columns/expressions to include in index
+            unique (bool): true to make a UNIQUE index
+            psql_only: true to only create this index on psql databases (useful
+                for virtual sqlite tables)
         """
 
-        # if this is postgres, we add the indexes concurrently. Otherwise
-        # we fall back to doing it inline
-        if isinstance(self.database_engine, engines.PostgresEngine):
-            conc = True
-        else:
-            conc = False
-            # We don't use partial indices on SQLite as it wasn't introduced
-            # until 3.8, and wheezy has 3.7
-            where_clause = None
-
-        sql = (
-            "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)"
-            " %(where_clause)s"
-        ) % {
-            "conc": "CONCURRENTLY" if conc else "",
-            "name": index_name,
-            "table": table,
-            "columns": ", ".join(columns),
-            "where_clause": "WHERE " + where_clause if where_clause else ""
-        }
-
-        def create_index_concurrently(conn):
+        def create_index_psql(conn):
             conn.rollback()
             # postgres insists on autocommit for the index
             conn.set_session(autocommit=True)
-            c = conn.cursor()
-            c.execute(sql)
-            conn.set_session(autocommit=False)
 
-        def create_index(conn):
+            try:
+                c = conn.cursor()
+
+                # If a previous attempt to create the index was interrupted,
+                # we may already have a half-built index. Let's just drop it
+                # before trying to create it again.
+
+                sql = "DROP INDEX IF EXISTS %s" % (index_name,)
+                logger.debug("[SQL] %s", sql)
+                c.execute(sql)
+
+                sql = (
+                    "CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
+                    " ON %(table)s"
+                    " (%(columns)s) %(where_clause)s"
+                ) % {
+                    "unique": "UNIQUE" if unique else "",
+                    "name": index_name,
+                    "table": table,
+                    "columns": ", ".join(columns),
+                    "where_clause": "WHERE " + where_clause if where_clause else ""
+                }
+                logger.debug("[SQL] %s", sql)
+                c.execute(sql)
+            finally:
+                conn.set_session(autocommit=False)
+
+        def create_index_sqlite(conn):
+            # Sqlite doesn't support concurrent creation of indexes.
+            #
+            # We don't use partial indices on SQLite as it wasn't introduced
+            # until 3.8, and wheezy has 3.7
+            #
+            # We assume that sqlite doesn't give us invalid indices; however
+            # we may still end up with the index existing but the
+            # background_updates not having been recorded if synapse got shut
+            # down at the wrong moment - hance we use IF NOT EXISTS. (SQLite
+            # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.)
+            sql = (
+                "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s"
+                " (%(columns)s)"
+            ) % {
+                "unique": "UNIQUE" if unique else "",
+                "name": index_name,
+                "table": table,
+                "columns": ", ".join(columns),
+            }
+
             c = conn.cursor()
+            logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
+        if isinstance(self.database_engine, engines.PostgresEngine):
+            runner = create_index_psql
+        elif psql_only:
+            runner = None
+        else:
+            runner = create_index_sqlite
+
         @defer.inlineCallbacks
         def updater(progress, batch_size):
-            logger.info("Adding index %s to %s", index_name, table)
-            if conc:
-                yield self.runWithConnection(create_index_concurrently)
-            else:
-                yield self.runWithConnection(create_index)
+            if runner is not None:
+                logger.info("Adding index %s to %s", index_name, table)
+                yield self.runWithConnection(runner)
             yield self._end_background_update(update_name)
             defer.returnValue(1)
 
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 71e5ea112f..747d2df622 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -33,6 +33,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
+            max_entries=5000,
         )
 
         super(ClientIpStore, self).__init__(hs)
@@ -120,6 +121,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 where_clauses.append("(user_id = ? AND device_id = ?)")
                 bindings.extend((user_id, device_id))
 
+        if not where_clauses:
+            return []
+
         inner_select = (
             "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
             "WHERE %(where)s "
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 2714519d21..0b62b493d5 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -325,23 +325,26 @@ class DeviceInboxStore(BackgroundUpdateStore):
             # we return.
             upper_pos = min(current_pos, last_pos + limit)
             sql = (
-                "SELECT stream_id, user_id"
+                "SELECT max(stream_id), user_id"
                 " FROM device_inbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
+                " GROUP BY user_id"
             )
             txn.execute(sql, (last_pos, upper_pos))
             rows = txn.fetchall()
 
             sql = (
-                "SELECT stream_id, destination"
+                "SELECT max(stream_id), destination"
                 " FROM device_federation_outbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
+                " GROUP BY destination"
             )
             txn.execute(sql, (last_pos, upper_pos))
             rows.extend(txn)
 
+            # Order by ascending stream ordering
+            rows.sort()
+
             return rows
 
         return self.runInteraction(
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 53e36791d5..d9936c88bb 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -18,7 +18,7 @@ import ujson as json
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, Cache
 from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 
 
@@ -29,6 +29,14 @@ class DeviceStore(SQLBaseStore):
     def __init__(self, hs):
         super(DeviceStore, self).__init__(hs)
 
+        # Map of (user_id, device_id) -> bool. If there is an entry that implies
+        # the device exists.
+        self.device_id_exists_cache = Cache(
+            name="device_id_exists",
+            keylen=2,
+            max_entries=10000,
+        )
+
         self._clock.looping_call(
             self._prune_old_outbound_device_pokes, 60 * 60 * 1000
         )
@@ -54,6 +62,10 @@ class DeviceStore(SQLBaseStore):
             defer.Deferred: boolean whether the device was inserted or an
                 existing device existed with that ID.
         """
+        key = (user_id, device_id)
+        if self.device_id_exists_cache.get(key, None):
+            defer.returnValue(False)
+
         try:
             inserted = yield self._simple_insert(
                 "devices",
@@ -65,6 +77,7 @@ class DeviceStore(SQLBaseStore):
                 desc="store_device",
                 or_ignore=True,
             )
+            self.device_id_exists_cache.prefill(key, True)
             defer.returnValue(inserted)
         except Exception as e:
             logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -93,6 +106,7 @@ class DeviceStore(SQLBaseStore):
             desc="get_device",
         )
 
+    @defer.inlineCallbacks
     def delete_device(self, user_id, device_id):
         """Delete a device.
 
@@ -102,12 +116,15 @@ class DeviceStore(SQLBaseStore):
         Returns:
             defer.Deferred
         """
-        return self._simple_delete_one(
+        yield self._simple_delete_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id},
             desc="delete_device",
         )
 
+        self.device_id_exists_cache.invalidate((user_id, device_id))
+
+    @defer.inlineCallbacks
     def delete_devices(self, user_id, device_ids):
         """Deletes several devices.
 
@@ -117,13 +134,15 @@ class DeviceStore(SQLBaseStore):
         Returns:
             defer.Deferred
         """
-        return self._simple_delete_many(
+        yield self._simple_delete_many(
             table="devices",
             column="device_id",
             iterable=device_ids,
             keyvalues={"user_id": user_id},
             desc="delete_devices",
         )
+        for device_id in device_ids:
+            self.device_id_exists_cache.invalidate((user_id, device_id))
 
     def update_device(self, user_id, device_id, new_display_name=None):
         """Update a device.
@@ -533,7 +552,7 @@ class DeviceStore(SQLBaseStore):
         rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
         defer.returnValue(set(row[0] for row in rows))
 
-    def get_all_device_list_changes_for_remotes(self, from_key):
+    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
         """Return a list of `(stream_id, user_id, destination)` which is the
         combined list of changes to devices, and which destinations need to be
         poked. `destination` may be None if no destinations need to be poked.
@@ -541,11 +560,11 @@ class DeviceStore(SQLBaseStore):
         sql = """
             SELECT stream_id, user_id, destination FROM device_lists_stream
             LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
-            WHERE stream_id > ?
+            WHERE ? < stream_id AND stream_id <= ?
         """
         return self._execute(
             "get_all_device_list_changes_for_remotes", None,
-            sql, from_key,
+            sql, from_key, to_key
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 7cbc1470fd..e00f31da2b 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 from twisted.internet import defer
 
-from synapse.api.errors import SynapseError
+from synapse.util.caches.descriptors import cached
 
 from canonicaljson import encode_canonical_json
 import ujson as json
@@ -123,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
         return result
 
     @defer.inlineCallbacks
-    def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
-        """Insert some new one time keys for a device.
+    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+        """Retrieve a number of one-time keys for a user
 
-        Checks if any of the keys are already inserted, if they are then check
-        if they match. If they don't then we raise an error.
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            key_ids(list[str]): list of key ids (excluding algorithm) to
+                retrieve
+
+        Returns:
+            deferred resolving to Dict[(str, str), str]: map from (algorithm,
+            key_id) to json string for key
         """
 
-        # First we check if we have already persisted any of the keys.
         rows = yield self._simple_select_many_batch(
             table="e2e_one_time_keys_json",
             column="key_id",
-            iterable=[key_id for _, key_id, _ in key_list],
+            iterable=key_ids,
             retcols=("algorithm", "key_id", "key_json",),
             keyvalues={
                 "user_id": user_id,
@@ -143,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
             desc="add_e2e_one_time_keys_check",
         )
 
-        existing_key_map = {
+        defer.returnValue({
             (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
-        }
-
-        new_keys = []  # Keys that we need to insert
-        for algorithm, key_id, json_bytes in key_list:
-            ex_bytes = existing_key_map.get((algorithm, key_id), None)
-            if ex_bytes:
-                if json_bytes != ex_bytes:
-                    raise SynapseError(
-                        400, "One time key with key_id %r already exists" % (key_id,)
-                    )
-            else:
-                new_keys.append((algorithm, key_id, json_bytes))
+        })
+
+    @defer.inlineCallbacks
+    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+        """Insert some new one time keys for a device. Errors if any of the
+        keys already exist.
+
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            time_now(long): insertion time to record (ms since epoch)
+            new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
+                (algorithm, key_id, key json)
+        """
 
         def _add_e2e_one_time_keys(txn):
             # We are protected from race between lookup and insertion due to
@@ -177,10 +185,14 @@ class EndToEndKeyStore(SQLBaseStore):
                     for algorithm, key_id, json_bytes in new_keys
                 ],
             )
+            txn.call_after(
+                self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+            )
         yield self.runInteraction(
             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
+    @cached(max_entries=10000)
     def count_e2e_one_time_keys(self, user_id, device_id):
         """ Count the number of one time keys the server has for a device
         Returns:
@@ -225,6 +237,9 @@ class EndToEndKeyStore(SQLBaseStore):
             )
             for user_id, device_id, algorithm, key_id in delete:
                 txn.execute(sql, (user_id, device_id, algorithm, key_id))
+                txn.call_after(
+                    self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+                )
             return result
         return self.runInteraction(
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
@@ -242,3 +257,4 @@ class EndToEndKeyStore(SQLBaseStore):
             keyvalues={"user_id": user_id, "device_id": device_id},
             desc="delete_e2e_one_time_keys_by_device"
         )
+        self.count_e2e_one_time_keys.invalidate((user_id, device_id,))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 3f6833fad2..c4aeb48800 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.state import resolve_events
 from synapse.util.caches.descriptors import cached
+from synapse.types import get_domain_from_id
 
 from canonicaljson import encode_canonical_json
 from collections import deque, namedtuple, OrderedDict
@@ -49,6 +50,9 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 persist_event_counter = metrics.register_counter("persisted_events")
+event_counter = metrics.register_counter(
+    "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
+)
 
 
 def encode_json(json_object):
@@ -203,6 +207,18 @@ class EventsStore(SQLBaseStore):
             where_clause="contains_url = true AND outlier = false",
         )
 
+        # an event_id index on event_search is useful for the purge_history
+        # api. Plus it means we get to enforce some integrity with a UNIQUE
+        # clause
+        self.register_background_index_update(
+            "event_search_event_id_idx",
+            index_name="event_search_event_id_idx",
+            table="event_search",
+            columns=["event_id"],
+            unique=True,
+            psql_only=True,
+        )
+
         self._event_persist_queue = _EventPeristenceQueue()
 
     def persist_events(self, events_and_contexts, backfilled=False):
@@ -370,6 +386,23 @@ class EventsStore(SQLBaseStore):
                     new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))
+                for event, context in chunk:
+                    if context.app_service:
+                        origin_type = "local"
+                        origin_entity = context.app_service.id
+                    elif self.hs.is_mine_id(event.sender):
+                        origin_type = "local"
+                        origin_entity = "*client*"
+                    else:
+                        origin_type = "remote"
+                        origin_entity = get_domain_from_id(event.sender)
+
+                    event_counter.inc(event.type, origin_type, origin_entity)
+
+                for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+                    self.get_current_state_ids.prefill(
+                        (room_id, ), new_state
+                    )
 
     @defer.inlineCallbacks
     def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
@@ -419,10 +452,10 @@ class EventsStore(SQLBaseStore):
         Assumes that we are only persisting events for one room at a time.
 
         Returns:
-            2-tuple (to_delete, to_insert) where both are state dicts, i.e.
-            (type, state_key) -> event_id. `to_delete` are the entries to
+            3-tuple (to_delete, to_insert, new_state) where both are state dicts,
+            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
             first be deleted from current_state_events, `to_insert` are entries
-            to insert.
+            to insert. `new_state` is the full set of state.
             May return None if there are no changes to be applied.
         """
         # Now we need to work out the different state sets for
@@ -529,7 +562,7 @@ class EventsStore(SQLBaseStore):
             if ev_id in events_to_insert
         }
 
-        defer.returnValue((to_delete, to_insert))
+        defer.returnValue((to_delete, to_insert, current_state))
 
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
@@ -682,7 +715,7 @@ class EventsStore(SQLBaseStore):
 
     def _update_current_state_txn(self, txn, state_delta_by_room):
         for room_id, current_state_tuple in state_delta_by_room.iteritems():
-                to_delete, to_insert = current_state_tuple
+                to_delete, to_insert, _ = current_state_tuple
                 txn.executemany(
                     "DELETE FROM current_state_events WHERE event_id = ?",
                     [(ev_id,) for ev_id in to_delete.itervalues()],
@@ -1327,11 +1360,26 @@ class EventsStore(SQLBaseStore):
     def _invalidate_get_event_cache(self, event_id):
             self._get_event_cache.invalidate((event_id,))
 
-    def _get_events_from_cache(self, events, allow_rejected):
+    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+        """Fetch events from the caches
+
+        Args:
+            events (list(str)): list of event_ids to fetch
+            allow_rejected (bool): Whether to teturn events that were rejected
+            update_metrics (bool): Whether to update the cache hit ratio metrics
+
+        Returns:
+            dict of event_id -> _EventCacheEntry for each event_id in cache. If
+            allow_rejected is `False` then there will still be an entry but it
+            will be `None`
+        """
         event_map = {}
 
         for event_id in events:
-            ret = self._get_event_cache.get((event_id,), None)
+            ret = self._get_event_cache.get(
+                (event_id,), None,
+                update_metrics=update_metrics,
+            )
             if not ret:
                 continue
 
@@ -1771,6 +1819,94 @@ class EventsStore(SQLBaseStore):
         """The current minimum token that backfilled events have reached"""
         return -self._backfill_id_gen.get_current_token()
 
+    def get_current_events_token(self):
+        """The current maximum token that events have reached"""
+        return self._stream_id_gen.get_current_token()
+
+    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_forward_event_rows(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < event_stream_ordering"
+                " AND event_stream_ordering <= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (last_id, upper_bound))
+            new_event_updates.extend(txn)
+
+            return new_event_updates
+        return self.runInteraction(
+            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+        )
+
+    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_backfill_event_rows(txn):
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (-last_id, -current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > event_stream_ordering"
+                " AND event_stream_ordering >= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (-last_id, -upper_bound))
+            new_event_updates.extend(txn.fetchall())
+
+            return new_event_updates
+        return self.runInteraction(
+            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+        )
+
     @cached(num_args=5, max_entries=10)
     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):
@@ -1903,6 +2039,8 @@ class EventsStore(SQLBaseStore):
                 400, "topological_ordering is greater than forward extremeties"
             )
 
+        logger.debug("[purge] looking for events to delete")
+
         txn.execute(
             "SELECT event_id, state_key FROM events"
             " LEFT JOIN state_events USING (room_id, event_id)"
@@ -1911,9 +2049,19 @@ class EventsStore(SQLBaseStore):
         )
         event_rows = txn.fetchall()
 
+        to_delete = [
+            (event_id,) for event_id, state_key in event_rows
+            if state_key is None and not self.hs.is_mine_id(event_id)
+        ]
+        logger.info(
+            "[purge] found %i events before cutoff, of which %i are remote"
+            " non-state events to delete", len(event_rows), len(to_delete))
+
         for event_id, state_key in event_rows:
             txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
 
+        logger.debug("[purge] Finding new backward extremities")
+
         # We calculate the new entries for the backward extremeties by finding
         # all events that point to events that are to be purged
         txn.execute(
@@ -1926,6 +2074,8 @@ class EventsStore(SQLBaseStore):
         )
         new_backwards_extrems = txn.fetchall()
 
+        logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
         txn.execute(
             "DELETE FROM event_backward_extremities WHERE room_id = ?",
             (room_id,)
@@ -1940,6 +2090,8 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
+        logger.debug("[purge] finding redundant state groups")
+
         # Get all state groups that are only referenced by events that are
         # to be deleted.
         txn.execute(
@@ -1955,15 +2107,20 @@ class EventsStore(SQLBaseStore):
         )
 
         state_rows = txn.fetchall()
-        state_groups_to_delete = [sg for sg, in state_rows]
+        logger.debug("[purge] found %i redundant state groups", len(state_rows))
+
+        # make a set of the redundant state groups, so that we can look them up
+        # efficiently
+        state_groups_to_delete = set([sg for sg, in state_rows])
 
         # Now we get all the state groups that rely on these state groups
-        new_state_edges = []
-        chunks = [
-            state_groups_to_delete[i:i + 100]
-            for i in xrange(0, len(state_groups_to_delete), 100)
-        ]
-        for chunk in chunks:
+        logger.debug("[purge] finding state groups which depend on redundant"
+                     " state groups")
+        remaining_state_groups = []
+        for i in xrange(0, len(state_rows), 100):
+            chunk = [sg for sg, in state_rows[i:i + 100]]
+            # look for state groups whose prev_state_group is one we are about
+            # to delete
             rows = self._simple_select_many_txn(
                 txn,
                 table="state_group_edges",
@@ -1972,21 +2129,28 @@ class EventsStore(SQLBaseStore):
                 retcols=["state_group"],
                 keyvalues={},
             )
-            new_state_edges.extend(row["state_group"] for row in rows)
+            remaining_state_groups.extend(
+                row["state_group"] for row in rows
+
+                # exclude state groups we are about to delete: no point in
+                # updating them
+                if row["state_group"] not in state_groups_to_delete
+            )
 
-        # Now we turn the state groups that reference to-be-deleted state groups
-        # to non delta versions.
-        for new_state_edge in new_state_edges:
+        # Now we turn the state groups that reference to-be-deleted state
+        # groups to non delta versions.
+        for sg in remaining_state_groups:
+            logger.debug("[purge] de-delta-ing remaining state group %s", sg)
             curr_state = self._get_state_groups_from_groups_txn(
-                txn, [new_state_edge], types=None
+                txn, [sg], types=None
             )
-            curr_state = curr_state[new_state_edge]
+            curr_state = curr_state[sg]
 
             self._simple_delete_txn(
                 txn,
                 table="state_groups_state",
                 keyvalues={
-                    "state_group": new_state_edge,
+                    "state_group": sg,
                 }
             )
 
@@ -1994,7 +2158,7 @@ class EventsStore(SQLBaseStore):
                 txn,
                 table="state_group_edges",
                 keyvalues={
-                    "state_group": new_state_edge,
+                    "state_group": sg,
                 }
             )
 
@@ -2003,7 +2167,7 @@ class EventsStore(SQLBaseStore):
                 table="state_groups_state",
                 values=[
                     {
-                        "state_group": new_state_edge,
+                        "state_group": sg,
                         "room_id": room_id,
                         "type": key[0],
                         "state_key": key[1],
@@ -2013,6 +2177,7 @@ class EventsStore(SQLBaseStore):
                 ],
             )
 
+        logger.debug("[purge] removing redundant state groups")
         txn.executemany(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             state_rows
@@ -2021,22 +2186,21 @@ class EventsStore(SQLBaseStore):
             "DELETE FROM state_groups WHERE id = ?",
             state_rows
         )
+
         # Delete all non-state
+        logger.debug("[purge] removing events from event_to_state_groups")
         txn.executemany(
             "DELETE FROM event_to_state_groups WHERE event_id = ?",
             [(event_id,) for event_id, _ in event_rows]
         )
 
+        logger.debug("[purge] updating room_depth")
         txn.execute(
             "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
             (topological_ordering, room_id,)
         )
 
         # Delete all remote non-state events
-        to_delete = [
-            (event_id,) for event_id, state_key in event_rows
-            if state_key is None and not self.hs.is_mine_id(event_id)
-        ]
         for table in (
             "events",
             "event_json",
@@ -2052,16 +2216,15 @@ class EventsStore(SQLBaseStore):
             "event_signatures",
             "rejections",
         ):
+            logger.debug("[purge] removing remote non-state events from %s", table)
+
             txn.executemany(
                 "DELETE FROM %s WHERE event_id = ?" % (table,),
                 to_delete
             )
 
-        txn.executemany(
-            "DELETE FROM events WHERE event_id = ?",
-            to_delete
-        )
         # Mark all state and own events as outliers
+        logger.debug("[purge] marking remaining events as outliers")
         txn.executemany(
             "UPDATE events SET outlier = ?"
             " WHERE event_id = ?",
@@ -2071,6 +2234,30 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
+        logger.info("[purge] done")
+
+    @defer.inlineCallbacks
+    def is_event_after(self, event_id1, event_id2):
+        """Returns True if event_id1 is after event_id2 in the stream
+        """
+        to_1, so_1 = yield self._get_event_ordering(event_id1)
+        to_2, so_2 = yield self._get_event_ordering(event_id2)
+        defer.returnValue((to_1, so_1) > (to_2, so_2))
+
+    @defer.inlineCallbacks
+    def _get_event_ordering(self, event_id):
+        res = yield self._simple_select_one(
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True
+        )
+
+        if not res:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
+
 
 AllNewEventsResult = namedtuple("AllNewEventsResult", [
     "new_forward_events", "new_backfill_events",
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index cbec255966..0a819d32c5 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -16,6 +16,7 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.push.baserules import list_with_base_rules
+from synapse.api.constants import EventTypes
 from twisted.internet import defer
 
 import logging
@@ -184,6 +185,18 @@ class PushRuleStore(SQLBaseStore):
             if uid in local_users_in_room:
                 user_ids.add(uid)
 
+        forgotten = yield self.who_forgot_in_room(
+            event.room_id, on_invalidate=cache_context.invalidate,
+        )
+
+        for row in forgotten:
+            user_id = row["user_id"]
+            event_id = row["event_id"]
+
+            mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
+            if event_id == mem_id:
+                user_ids.discard(user_id)
+
         rules_by_user = yield self.bulk_get_push_rules(
             user_ids, on_invalidate=cache_context.invalidate,
         )
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8cc9f0353b..34d2f82b7f 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -135,6 +135,48 @@ class PusherStore(SQLBaseStore):
             "get_all_updated_pushers", get_all_updated_pushers_txn
         )
 
+    def get_all_updated_pushers_rows(self, last_id, current_id, limit):
+        """Get all the pushers that have changed between the given tokens.
+
+        Returns:
+            Deferred(list(tuple)): each tuple consists of:
+                stream_id (str)
+                user_id (str)
+                app_id (str)
+                pushkey (str)
+                was_deleted (bool): whether the pusher was added/updated (False)
+                    or deleted (True)
+        """
+
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_pushers_rows_txn(txn):
+            sql = (
+                "SELECT id, user_name, app_id, pushkey"
+                " FROM pushers"
+                " WHERE ? < id AND id <= ?"
+                " ORDER BY id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            results = [list(row) + [False] for row in txn]
+
+            sql = (
+                "SELECT stream_id, user_id, app_id, pushkey"
+                " FROM deleted_pushers"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+
+            results.extend(list(row) + [True] for row in txn)
+            results.sort()  # Sort so that they're ordered by stream id
+
+            return results
+        return self.runInteraction(
+            "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
+        )
+
     @cachedInlineCallbacks(num_args=1, max_entries=15000)
     def get_if_user_has_pusher(self, user_id):
         # This only exists for the cachedList decorator
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 6b0f8c2787..efb90c3c91 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -47,10 +47,13 @@ class ReceiptsStore(SQLBaseStore):
         # Returns an ObservableDeferred
         res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
 
-        if res and res.called and user_id in res.result:
-            # We'd only be adding to the set, so no point invalidating if the
-            # user is already there
-            return
+        if res:
+            if isinstance(res, defer.Deferred) and res.called:
+                res = res.result
+            if user_id in res:
+                # We'd only be adding to the set, so no point invalidating if the
+                # user is already there
+                return
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index e4c56cc175..5d543652bb 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 from ._base import SQLBaseStore
 from .engines import PostgresEngine, Sqlite3Engine
@@ -33,6 +33,11 @@ OpsLevel = collections.namedtuple(
     ("ban_level", "kick_level", "redact_level",)
 )
 
+RatelimitOverride = collections.namedtuple(
+    "RatelimitOverride",
+    ("messages_per_second", "burst_count",)
+)
+
 
 class RoomStore(SQLBaseStore):
 
@@ -473,3 +478,32 @@ class RoomStore(SQLBaseStore):
         return self.runInteraction(
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
+
+    @cachedInlineCallbacks(max_entries=10000)
+    def get_ratelimit_for_user(self, user_id):
+        """Check if there are any overrides for ratelimiting for the given
+        user
+
+        Args:
+            user_id (str)
+
+        Returns:
+            RatelimitOverride if there is an override, else None. If the contents
+            of RatelimitOverride are None or 0 then ratelimitng has been
+            disabled for that user entirely.
+        """
+        row = yield self._simple_select_one(
+            table="ratelimit_override",
+            keyvalues={"user_id": user_id},
+            retcols=("messages_per_second", "burst_count"),
+            allow_none=True,
+            desc="get_ratelimit_for_user",
+        )
+
+        if row:
+            defer.returnValue(RatelimitOverride(
+                messages_per_second=row["messages_per_second"],
+                burst_count=row["burst_count"],
+            ))
+        else:
+            defer.returnValue(None)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 367dbbbcf6..0829ae5bee 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -18,7 +18,9 @@ from twisted.internet import defer
 from collections import namedtuple
 
 from ._base import SQLBaseStore
+from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.stringutils import to_ascii
 
 from synapse.api.constants import Membership, EventTypes
 from synapse.types import get_domain_from_id
@@ -35,6 +37,13 @@ RoomsForUser = namedtuple(
 )
 
 
+# We store this using a namedtuple so that we save about 3x space over using a
+# dict.
+ProfileInfo = namedtuple(
+    "ProfileInfo", ("avatar_url", "display_name")
+)
+
+
 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 
 
@@ -139,7 +148,7 @@ class RoomMemberStore(SQLBaseStore):
         hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
         defer.returnValue(hosts)
 
-    @cached(max_entries=500000, iterable=True)
+    @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
             sql = (
@@ -152,7 +161,7 @@ class RoomMemberStore(SQLBaseStore):
             )
 
             txn.execute(sql, (room_id, Membership.JOIN,))
-            return [r[0] for r in txn]
+            return [to_ascii(r[0]) for r in txn]
         return self.runInteraction("get_users_in_room", f)
 
     @cached()
@@ -378,7 +387,9 @@ class RoomMemberStore(SQLBaseStore):
             state_group = object()
 
         return self._get_joined_users_from_context(
-            event.room_id, state_group, context.current_state_ids, event=event,
+            event.room_id, state_group, context.current_state_ids,
+            event=event,
+            context=context,
         )
 
     def get_joined_users_from_state(self, room_id, state_group, state_ids):
@@ -396,46 +407,95 @@ class RoomMemberStore(SQLBaseStore):
     @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
                            max_entries=100000)
     def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
-                                       cache_context, event=None):
+                                       cache_context, event=None, context=None):
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
         # See bulk_get_push_rules_for_room for how we work around this.
         assert state_group is not None
 
+        users_in_room = {}
         member_event_ids = [
             e_id
             for key, e_id in current_state_ids.iteritems()
             if key[0] == EventTypes.Member
         ]
 
-        rows = yield self._simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=member_event_ids,
-            retcols=['user_id', 'display_name', 'avatar_url'],
-            keyvalues={
-                "membership": Membership.JOIN,
-            },
-            batch_size=500,
-            desc="_get_joined_users_from_context",
+        if context is not None:
+            # If we have a context with a delta from a previous state group,
+            # check if we also have the result from the previous group in cache.
+            # If we do then we can reuse that result and simply update it with
+            # any membership changes in `delta_ids`
+            if context.prev_group and context.delta_ids:
+                prev_res = self._get_joined_users_from_context.cache.get(
+                    (room_id, context.prev_group), None
+                )
+                if prev_res and isinstance(prev_res, dict):
+                    users_in_room = dict(prev_res)
+                    member_event_ids = [
+                        e_id
+                        for key, e_id in context.delta_ids.iteritems()
+                        if key[0] == EventTypes.Member
+                    ]
+                    for etype, state_key in context.delta_ids:
+                        users_in_room.pop(state_key, None)
+
+        # We check if we have any of the member event ids in the event cache
+        # before we ask the DB
+
+        # We don't update the event cache hit ratio as it completely throws off
+        # the hit ratio counts. After all, we don't populate the cache if we
+        # miss it here
+        event_map = self._get_events_from_cache(
+            member_event_ids,
+            allow_rejected=False,
+            update_metrics=False,
         )
 
-        users_in_room = {
-            row["user_id"]: {
-                "display_name": row["display_name"],
-                "avatar_url": row["avatar_url"],
-            }
-            for row in rows
-        }
+        missing_member_event_ids = []
+        for event_id in member_event_ids:
+            ev_entry = event_map.get(event_id)
+            if ev_entry:
+                if ev_entry.event.membership == Membership.JOIN:
+                    users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(
+                            ev_entry.event.content.get("displayname", None)
+                        ),
+                        avatar_url=to_ascii(
+                            ev_entry.event.content.get("avatar_url", None)
+                        ),
+                    )
+            else:
+                missing_member_event_ids.append(event_id)
+
+        if missing_member_event_ids:
+            rows = yield self._simple_select_many_batch(
+                table="room_memberships",
+                column="event_id",
+                iterable=missing_member_event_ids,
+                retcols=('user_id', 'display_name', 'avatar_url',),
+                keyvalues={
+                    "membership": Membership.JOIN,
+                },
+                batch_size=500,
+                desc="_get_joined_users_from_context",
+            )
+
+            users_in_room.update({
+                to_ascii(row["user_id"]): ProfileInfo(
+                    avatar_url=to_ascii(row["avatar_url"]),
+                    display_name=to_ascii(row["display_name"]),
+                )
+                for row in rows
+            })
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = {
-                        "display_name": event.content.get("displayname", None),
-                        "avatar_url": event.content.get("avatar_url", None),
-                    }
+                    users_in_room[to_ascii(event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(event.content.get("displayname", None)),
+                        avatar_url=to_ascii(event.content.get("avatar_url", None)),
+                    )
 
         defer.returnValue(users_in_room)
 
@@ -474,6 +534,45 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(False)
 
+    def get_joined_hosts(self, room_id, state_group, state_ids):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_hosts(
+            room_id, state_group, state_ids
+        )
+
+    @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+    def _get_joined_hosts(self, room_id, state_group, current_state_ids):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        joined_hosts = set()
+        for etype, state_key in current_state_ids:
+            if etype == EventTypes.Member:
+                try:
+                    host = get_domain_from_id(state_key)
+                except:
+                    logger.warn("state_key not user_id: %s", state_key)
+                    continue
+
+                if host in joined_hosts:
+                    continue
+
+                event_id = current_state_ids[(etype, state_key)]
+                event = yield self.get_event(event_id, allow_none=True)
+                if event and event.content["membership"] == Membership.JOIN:
+                    joined_hosts.add(intern_string(host))
+
+        defer.returnValue(joined_hosts)
+
     @defer.inlineCallbacks
     def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py
index 784f3b348f..20ad8bd5a6 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/schema/delta/37/remove_auth_idx.py
@@ -36,6 +36,10 @@ DROP INDEX IF EXISTS transactions_have_ref;
 -- and is used incredibly rarely.
 DROP INDEX IF EXISTS events_order_topo_stream_room;
 
+-- an equivalent index to this actually gets re-created in delta 41, because it
+-- turned out that deleting it wasn't a great plan :/. In any case, let's
+-- delete it here, and delta 41 will create a new one with an added UNIQUE
+-- constraint
 DROP INDEX IF EXISTS event_search_ev_idx;
 """
 
diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
new file mode 100644
index 0000000000..5d9cfecf36
--- /dev/null
+++ b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('event_search_event_id_idx', '{}');
diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/schema/delta/41/ratelimit.sql
new file mode 100644
index 0000000000..a194bf0238
--- /dev/null
+++ b/synapse/storage/schema/delta/41/ratelimit.sql
@@ -0,0 +1,22 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE ratelimit_override (
+    user_id TEXT NOT NULL,
+    messages_per_second BIGINT,
+    burst_count BIGINT
+);
+
+CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index fb23f6f462..85acf2ad1e 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,8 +14,9 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches import intern_string
+from synapse.util.stringutils import to_ascii
 from synapse.storage.engines import PostgresEngine
 
 from twisted.internet import defer
@@ -69,17 +70,33 @@ class StateStore(SQLBaseStore):
             where_clause="type='m.room.member'",
         )
 
-    @cachedInlineCallbacks(max_entries=100000, iterable=True)
+    @cached(max_entries=100000, iterable=True)
     def get_current_state_ids(self, room_id):
-        rows = yield self._simple_select_list(
-            table="current_state_events",
-            keyvalues={"room_id": room_id},
-            retcols=["event_id", "type", "state_key"],
-            desc="_calculate_state_delta",
+        """Get the current state event ids for a room based on the
+        current_state_events table.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            deferred: dict of (type, state_key) -> event_id
+        """
+        def _get_current_state_ids_txn(txn):
+            txn.execute(
+                """SELECT type, state_key, event_id FROM current_state_events
+                WHERE room_id = ?
+                """,
+                (room_id,)
+            )
+
+            return {
+                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
+            }
+
+        return self.runInteraction(
+            "get_current_state_ids",
+            _get_current_state_ids_txn,
         )
-        defer.returnValue({
-            (r["type"], r["state_key"]): r["event_id"] for r in rows
-        })
 
     @defer.inlineCallbacks
     def get_state_groups_ids(self, room_id, event_ids):
@@ -210,6 +227,18 @@ class StateStore(SQLBaseStore):
                     ],
                 )
 
+            # Prefill the state group cache with this group.
+            # It's fine to use the sequence like this as the state group map
+            # is immutable. (If the map wasn't immutable then this prefill could
+            # race with another update)
+            txn.call_after(
+                self._state_group_cache.update,
+                self._state_group_cache.sequence,
+                key=context.state_group,
+                value=dict(context.current_state_ids),
+                full=True,
+            )
+
         self._simple_insert_many_txn(
             txn,
             table="event_to_state_groups",
@@ -263,12 +292,7 @@ class StateStore(SQLBaseStore):
 
             return count
 
-    @cached(num_args=2, max_entries=100000, iterable=True)
-    def _get_state_group_from_group(self, group, types):
-        raise NotImplementedError()
-
-    @cachedList(cached_method_name="_get_state_group_from_group",
-                list_name="groups", num_args=2, inlineCallbacks=True)
+    @defer.inlineCallbacks
     def _get_state_groups_from_groups(self, groups, types):
         """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
         """
@@ -496,7 +520,7 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, max_entries=100000)
+    @cached(num_args=2, max_entries=50000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
@@ -644,7 +668,7 @@ class StateStore(SQLBaseStore):
                     state_dict = results[group]
 
                 state_dict.update(
-                    ((intern_string(k[0]), intern_string(k[1])), v)
+                    ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
                     for k, v in group_state_dict.iteritems()
                 )