Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/__init__.py                                  |   2
-rw-r--r--  synapse/storage/_base.py                                     | 194
-rw-r--r--  synapse/storage/background_updates.py                        |   2
-rw-r--r--  synapse/storage/client_ips.py                                |  92
-rw-r--r--  synapse/storage/e2e_room_keys.py                             |  21
-rw-r--r--  synapse/storage/engines/__init__.py                          |   2
-rw-r--r--  synapse/storage/engines/postgres.py                          |  14
-rw-r--r--  synapse/storage/engines/sqlite.py (renamed from synapse/storage/engines/sqlite3.py) | 8
-rw-r--r--  synapse/storage/event_federation.py                          |  23
-rw-r--r--  synapse/storage/events.py                                    | 169
-rw-r--r--  synapse/storage/events_worker.py                             |  86
-rw-r--r--  synapse/storage/monthly_active_users.py                      |  12
-rw-r--r--  synapse/storage/pusher.py                                    |   9
-rw-r--r--  synapse/storage/roommember.py                                |   8
-rw-r--r--  synapse/storage/schema/delta/53/event_format_version.sql     |  16
-rw-r--r--  synapse/storage/schema/delta/53/user_ips_index.sql           |  10
-rw-r--r--  synapse/storage/state.py                                     |  42
-rw-r--r--  synapse/storage/user_directory.py                            | 212
18 files changed, 698 insertions(+), 224 deletions(-)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 24329879e5..42cd3c83ad 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -317,7 +317,7 @@ class DataStore(RoomMemberStore, RoomStore,
                               thirty_days_ago_in_secs))
 
             for row in txn:
-                if row[0] is 'unknown':
+                if row[0] == 'unknown':
                     pass
                 results[row[0]] = row[1]
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 865b5e915a..e124161845 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -26,7 +26,8 @@ from prometheus_client import Histogram
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.storage.engines import PostgresEngine
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.util.caches.descriptors import Cache
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.stringutils import exception_to_unicode
@@ -49,6 +50,21 @@ sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
 sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
 
 
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+    "user_ips": "user_ips_device_unique_index",
+    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+    "event_search": "event_search_event_id_idx",
+}
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -192,6 +208,57 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
+        # A set of tables that are not safe to use native upserts in.
+        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+        # We add the user_directory_search table to the blacklist on SQLite
+        # because the existing search table does not have an index, making it
+        # unsafe to use native upserts.
+        if isinstance(self.database_engine, Sqlite3Engine):
+            self._unsafe_to_upsert_tables.add("user_directory_search")
+
+        if self.database_engine.can_native_upsert:
+            # Check ASAP (and then later, every 15s) to see if we have finished
+            # background updates of tables that aren't safe to update.
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert
+            )
+
+    @defer.inlineCallbacks
+    def _check_safe_to_upsert(self):
+        """
+        Is it safe to use native UPSERT?
+
+        If there are background updates still running, we need to wait for them,
+        as they may add the indexes that provide the UNIQUE constraints we require.
+
+        If the background updates have not completed, wait 15 sec and check again.
+        """
+        updates = yield self._simple_select_list(
+            "background_updates",
+            keyvalues=None,
+            retcols=["update_name"],
+            desc="check_background_updates",
+        )
+        updates = [x["update_name"] for x in updates]
+
+        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+            if update_name not in updates:
+                logger.debug("Now safe to upsert in %s", table)
+                self._unsafe_to_upsert_tables.discard(table)
+
+        # If there are any updates still running, reschedule this check to run again.
+        if updates:
+            self._clock.call_later(
+                15.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert
+            )
+
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
 
@@ -494,8 +561,15 @@ class SQLBaseStore(object):
         txn.executemany(sql, vals)
 
     @defer.inlineCallbacks
-    def _simple_upsert(self, table, keyvalues, values,
-                       insertion_values={}, desc="_simple_upsert", lock=True):
+    def _simple_upsert(
+        self,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        desc="_simple_upsert",
+        lock=True
+    ):
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -516,16 +590,21 @@ class SQLBaseStore(object):
                 inserting
             lock (bool): True to lock the table when doing the upsert.
         Returns:
-            Deferred(bool): True if a new entry was created, False if an
-                existing one was updated.
+            Deferred(None or bool): Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
         """
         attempts = 0
         while True:
             try:
                 result = yield self.runInteraction(
                     desc,
-                    self._simple_upsert_txn, table, keyvalues, values, insertion_values,
-                    lock=lock
+                    self._simple_upsert_txn,
+                    table,
+                    keyvalues,
+                    values,
+                    insertion_values,
+                    lock=lock,
                 )
                 defer.returnValue(result)
             except self.database_engine.module.IntegrityError as e:
@@ -537,12 +616,71 @@ class SQLBaseStore(object):
 
                 # presumably we raced with another transaction: let's retry.
                 logger.warn(
-                    "IntegrityError when upserting into %s; retrying: %s",
-                    table, e
+                    "%s when upserting into %s; retrying: %s", type(e).__name__, table, e
                 )
 
-    def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
-                           lock=True):
+    def _simple_upsert_txn(
+        self,
+        txn,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        lock=True,
+    ):
+        """
+        Pick the UPSERT method which works best on the platform. Either the
+        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+        Args:
+            txn: The transaction to use.
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            None or bool: Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        if (
+            self.database_engine.can_native_upsert
+            and table not in self._unsafe_to_upsert_tables
+        ):
+            return self._simple_upsert_txn_native_upsert(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+            )
+        else:
+            return self._simple_upsert_txn_emulated(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+                lock=lock,
+            )
+
+    def _simple_upsert_txn_emulated(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
+        """
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            bool: Return True if a new entry was created, False if an existing
+            one was updated.
+        """
         # We need to lock the table :(, unless we're *really* careful
         if lock:
             self.database_engine.lock_table(txn, table)
@@ -577,12 +715,44 @@ class SQLBaseStore(object):
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
             table,
             ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues)
+            ", ".join("?" for _ in allvalues),
         )
         txn.execute(sql, list(allvalues.values()))
         # successfully inserted
         return True
 
+    def _simple_upsert_txn_native_upsert(
+        self, txn, table, keyvalues, values, insertion_values={}
+    ):
+        """
+        Use the native UPSERT functionality in recent PostgreSQL versions.
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+        Returns:
+            None
+        """
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
+
+        sql = (
+            "INSERT INTO %s (%s) VALUES (%s) "
+            "ON CONFLICT (%s) DO UPDATE SET %s"
+        ) % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+            ", ".join(k for k in keyvalues),
+            ", ".join(k + "=EXCLUDED." + k for k in values),
+        )
+        txn.execute(sql, list(allvalues.values()))
+
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 5fe1ca2de7..60cdc884e6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -240,7 +240,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         * An integer count of the number of items to update in this batch.
 
         The handler should return a deferred integer count of items updated.
-        The hander is responsible for updating the progress of the update.
+        The handler is responsible for updating the progress of the update.
 
         Args:
             update_name(str): The name of the update that this code handles.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index b228a20ac2..9c21362226 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -66,6 +66,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         )
 
         self.register_background_update_handler(
+            "user_ips_analyze",
+            self._analyze_user_ip,
+        )
+
+        self.register_background_update_handler(
             "user_ips_remove_dupes",
             self._remove_user_ip_dupes,
         )
@@ -109,6 +114,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         defer.returnValue(1)
 
     @defer.inlineCallbacks
+    def _analyze_user_ip(self, progress, batch_size):
+        # Background update to analyze user_ips table before we run the
+        # deduplication background update. The table may not have been analyzed
+        # for ages due to the table locks.
+        #
+        # This will lock out the naive upserts to user_ips while it happens, but
+        # the analyze should be quick (28GB table takes ~10s)
+        def user_ips_analyze(txn):
+            txn.execute("ANALYZE user_ips")
+
+        yield self.runInteraction(
+            "user_ips_analyze", user_ips_analyze
+        )
+
+        yield self._end_background_update("user_ips_analyze")
+
+        defer.returnValue(1)
+
+    @defer.inlineCallbacks
     def _remove_user_ip_dupes(self, progress, batch_size):
         # This function works by scanning the user_ips table in batches
         # based on `last_seen`. For each row in a batch it searches the rest of
@@ -167,12 +191,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 clause = "? <= last_seen AND last_seen < ?"
                 args = (begin_last_seen, end_last_seen)
 
+            # (Note: The DISTINCT in the inner query is important to ensure that
+            # the COUNT(*) is accurate, otherwise double counting may happen due
+            # to the join effectively being a cross product)
             txn.execute(
                 """
                 SELECT user_id, access_token, ip,
-                       MAX(device_id), MAX(user_agent), MAX(last_seen)
+                       MAX(device_id), MAX(user_agent), MAX(last_seen),
+                       COUNT(*)
                 FROM (
-                    SELECT user_id, access_token, ip
+                    SELECT DISTINCT user_id, access_token, ip
                     FROM user_ips
                     WHERE {}
                 ) c
@@ -186,7 +214,60 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
             # We've got some duplicates
             for i in res:
-                user_id, access_token, ip, device_id, user_agent, last_seen = i
+                user_id, access_token, ip, device_id, user_agent, last_seen, count = i
+
+                # We want to delete the duplicates so we end up with only a
+                # single row.
+                #
+                # The naive way of doing this would be just to delete all rows
+                # and reinsert a constructed row. However, if there are a lot of
+                # duplicate rows this can cause the table to grow a lot, which
+                # can be problematic in two ways:
+                #   1. If user_ips is already large then this can cause the
+                #      table to rapidly grow, potentially filling the disk.
+                #   2. Reinserting a lot of rows can confuse the table
+                #      statistics for postgres, causing it to not use the
+                #      correct indices for the query above, resulting in a full
+                #      table scan. This is incredibly slow for large tables and
+                #      can kill database performance. (This seems to mainly
+                #      happen for the last query where the clause is simply `? <
+                #      last_seen`)
+                #
+                # So instead we want to delete all but *one* of the duplicate
+                # rows. That is hard to do reliably, so we cheat and do a two
+                # step process:
+                #   1. Delete all rows with a last_seen strictly less than the
+                #      max last_seen. This hopefully results in deleting all but
+                #      one row the majority of the time, but there may be
+                #      duplicate last_seen
+                #   2. If multiple rows remain, we fall back to the naive method
+                #      and simply delete all rows and reinsert.
+                #
+                # Note that this relies on no new duplicate rows being inserted,
+                # but if that is happening then this entire process is futile
+                # anyway.
+
+                # Do step 1:
+
+                txn.execute(
+                    """
+                    DELETE FROM user_ips
+                    WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ?
+                    """,
+                    (user_id, access_token, ip, last_seen)
+                )
+                if txn.rowcount == count - 1:
+                    # We deleted all but one of the duplicate rows, i.e. there
+                    # is exactly one remaining and so there is nothing left to
+                    # do.
+                    continue
+                elif txn.rowcount >= count:
+                    raise Exception(
+                        "We deleted more duplicate rows from 'user_ips' than expected",
+                    )
+
+                # The previous step didn't delete enough rows, so we fall back to
+                # step 2:
 
                 # Drop all the duplicates
                 txn.execute(
@@ -257,7 +338,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         )
 
     def _update_client_ips_batch_txn(self, txn, to_update):
-        self.database_engine.lock_table(txn, "user_ips")
+        if "user_ips" in self._unsafe_to_upsert_tables or (
+            not self.database_engine.can_native_upsert
+        ):
+            self.database_engine.lock_table(txn, "user_ips")
 
         for entry in iteritems(to_update):
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
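A self-contained sqlite3 sketch of the two-step dedup described in the comments above, with made-up rows: step 1 deletes everything strictly older than the newest last_seen, and the delete-and-reinsert fallback only runs when ties on last_seen remain.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE user_ips (user_id TEXT, access_token TEXT, ip TEXT,"
    " device_id TEXT, user_agent TEXT, last_seen BIGINT)"
)
# Three duplicates of the same (user_id, access_token, ip); two tie on last_seen.
conn.executemany(
    "INSERT INTO user_ips VALUES (?, ?, ?, ?, ?, ?)",
    [
        ("@u:hs", "tok", "1.2.3.4", "DEV1", "ua", 100),
        ("@u:hs", "tok", "1.2.3.4", "DEV2", "ua", 200),
        ("@u:hs", "tok", "1.2.3.4", "DEV3", "ua", 200),
    ],
)

user_id, access_token, ip, last_seen, count = "@u:hs", "tok", "1.2.3.4", 200, 3

# Step 1: delete rows strictly older than the max last_seen.
cur = conn.execute(
    "DELETE FROM user_ips"
    " WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ?",
    (user_id, access_token, ip, last_seen),
)
if cur.rowcount != count - 1:
    # Step 2 (fallback): drop all remaining duplicates and reinsert one row.
    conn.execute(
        "DELETE FROM user_ips WHERE user_id = ? AND access_token = ? AND ip = ?",
        (user_id, access_token, ip),
    )
    conn.execute(
        "INSERT INTO user_ips VALUES (?, ?, ?, ?, ?, ?)",
        (user_id, access_token, ip, "DEV3", "ua", last_seen),
    )

print(conn.execute("SELECT COUNT(*) FROM user_ips").fetchone()[0])  # 1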
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index 45cebe61d1..9a3aec759e 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -298,6 +298,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
         )
 
+    def update_e2e_room_keys_version(self, user_id, version, info):
+        """Update a given backup version
+
+        Args:
+            user_id(str): the user whose backup version we're updating
+            version(str): the version ID of the backup version we're updating
+            info(dict): the new backup version info to store
+        """
+
+        return self._simple_update(
+            table="e2e_room_keys_versions",
+            keyvalues={
+                "user_id": user_id,
+                "version": version,
+            },
+            updatevalues={
+                "auth_data": json.dumps(info["auth_data"]),
+            },
+            desc="update_e2e_room_keys_version"
+        )
+
     def delete_e2e_room_keys_version(self, user_id, version=None):
         """Delete a given backup version of the user's room keys.
         Doesn't delete their actual key data.
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index e2f9de8451..ff5ef97ca8 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -18,7 +18,7 @@ import platform
 
 from ._base import IncorrectDatabaseSetup
 from .postgres import PostgresEngine
-from .sqlite3 import Sqlite3Engine
+from .sqlite import Sqlite3Engine
 
 SUPPORTED_MODULE = {
     "sqlite3": Sqlite3Engine,
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 42225f8a2a..4004427c7b 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -38,6 +38,13 @@ class PostgresEngine(object):
         return sql.replace("?", "%s")
 
     def on_new_connection(self, db_conn):
+
+        # Get the version of PostgreSQL that we're using. As per the psycopg2
+        # docs: The number is formed by converting the major, minor, and
+        # revision numbers into two-decimal-digit numbers and appending them
+        # together. For example, version 8.1.5 will be returned as 80105
+        self._version = db_conn.server_version
+
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
@@ -54,6 +61,13 @@ class PostgresEngine(object):
 
         cursor.close()
 
+    @property
+    def can_native_upsert(self):
+        """
+        Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+        """
+        return self._version >= 90500
+
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
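For reference, the arithmetic behind the 9.5 cutoff used by can_native_upsert, with a hypothetical helper (psycopg2 exposes the already-encoded integer as connection.server_version):

def pg_version_to_int(major, minor, revision):
    # Two decimal digits per component, as described in the psycopg2 docs:
    # 8.1.5 -> 80105, 9.5.14 -> 90514.
    return major * 10000 + minor * 100 + revision

assert pg_version_to_int(8, 1, 5) == 80105
assert pg_version_to_int(9, 4, 20) < 90500   # no native UPSERT
assert pg_version_to_int(9, 5, 0) >= 90500   # ON CONFLICT is available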
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite.py
index 19949fc474..059ab81055 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite.py
@@ -30,6 +30,14 @@ class Sqlite3Engine(object):
         self._current_state_group_id = None
         self._current_state_group_id_lock = threading.Lock()
 
+    @property
+    def can_native_upsert(self):
+        """
+        Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
+        more work we haven't done yet to tell what was inserted vs updated.
+        """
+        return self.module.sqlite_version_info >= (3, 24, 0)
+
     def check_database(self, txn):
         pass
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index d3b9dea1d6..38809ed0fc 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
 
         return dict(txn)
 
+    @defer.inlineCallbacks
+    def get_max_depth_of(self, event_ids):
+        """Returns the max depth of a set of event IDs
+
+        Args:
+            event_ids (list[str])
+
+        Returns:
+            Deferred[int]
+        """
+        rows = yield self._simple_select_many_batch(
+            table="events",
+            column="event_id",
+            iterable=event_ids,
+            retcols=("depth",),
+            desc="get_max_depth_of",
+        )
+
+        if not rows:
+            defer.returnValue(0)
+        else:
+            defer.returnValue(max(row["depth"] for row in rows))
+
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
             txn,
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 79e0276de6..81b250480d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -904,106 +904,106 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
     def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
         for room_id, current_state_tuple in iteritems(state_delta_by_room):
-                to_delete, to_insert = current_state_tuple
-
-                # First we add entries to the current_state_delta_stream. We
-                # do this before updating the current_state_events table so
-                # that we can use it to calculate the `prev_event_id`. (This
-                # allows us to not have to pull out the existing state
-                # unnecessarily).
-                sql = """
-                    INSERT INTO current_state_delta_stream
-                    (stream_id, room_id, type, state_key, event_id, prev_event_id)
-                    SELECT ?, ?, ?, ?, ?, (
-                        SELECT event_id FROM current_state_events
-                        WHERE room_id = ? AND type = ? AND state_key = ?
-                    )
-                """
-                txn.executemany(sql, (
-                    (
-                        max_stream_order, room_id, etype, state_key, None,
-                        room_id, etype, state_key,
-                    )
-                    for etype, state_key in to_delete
-                    # We sanity check that we're deleting rather than updating
-                    if (etype, state_key) not in to_insert
-                ))
-                txn.executemany(sql, (
-                    (
-                        max_stream_order, room_id, etype, state_key, ev_id,
-                        room_id, etype, state_key,
-                    )
-                    for (etype, state_key), ev_id in iteritems(to_insert)
-                ))
+            to_delete, to_insert = current_state_tuple
 
-                # Now we actually update the current_state_events table
-
-                txn.executemany(
-                    "DELETE FROM current_state_events"
-                    " WHERE room_id = ? AND type = ? AND state_key = ?",
-                    (
-                        (room_id, etype, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
+            # First we add entries to the current_state_delta_stream. We
+            # do this before updating the current_state_events table so
+            # that we can use it to calculate the `prev_event_id`. (This
+            # allows us to not have to pull out the existing state
+            # unnecessarily).
+            sql = """
+                INSERT INTO current_state_delta_stream
+                (stream_id, room_id, type, state_key, event_id, prev_event_id)
+                SELECT ?, ?, ?, ?, ?, (
+                    SELECT event_id FROM current_state_events
+                    WHERE room_id = ? AND type = ? AND state_key = ?
                 )
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="current_state_events",
-                    values=[
-                        {
-                            "event_id": ev_id,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                        }
-                        for key, ev_id in iteritems(to_insert)
-                    ],
+            """
+            txn.executemany(sql, (
+                (
+                    max_stream_order, room_id, etype, state_key, None,
+                    room_id, etype, state_key,
                 )
-
-                txn.call_after(
-                    self._curr_state_delta_stream_cache.entity_has_changed,
-                    room_id, max_stream_order,
+                for etype, state_key in to_delete
+                # We sanity check that we're deleting rather than updating
+                if (etype, state_key) not in to_insert
+            ))
+            txn.executemany(sql, (
+                (
+                    max_stream_order, room_id, etype, state_key, ev_id,
+                    room_id, etype, state_key,
                 )
+                for (etype, state_key), ev_id in iteritems(to_insert)
+            ))
 
-                # Invalidate the various caches
-
-                # Figure out the changes of membership to invalidate the
-                # `get_rooms_for_user` cache.
-                # We find out which membership events we may have deleted
-                # and which we have added, then we invlidate the caches for all
-                # those users.
-                members_changed = set(
-                    state_key
-                    for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                    if ev_type == EventTypes.Member
-                )
+            # Now we actually update the current_state_events table
 
-                for member in members_changed:
-                    self._invalidate_cache_and_stream(
-                        txn, self.get_rooms_for_user_with_stream_ordering, (member,)
-                    )
+            txn.executemany(
+                "DELETE FROM current_state_events"
+                " WHERE room_id = ? AND type = ? AND state_key = ?",
+                (
+                    (room_id, etype, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
 
-                for host in set(get_domain_from_id(u) for u in members_changed):
-                    self._invalidate_cache_and_stream(
-                        txn, self.is_host_joined, (room_id, host)
-                    )
-                    self._invalidate_cache_and_stream(
-                        txn, self.was_host_joined, (room_id, host)
-                    )
+            self._simple_insert_many_txn(
+                txn,
+                table="current_state_events",
+                values=[
+                    {
+                        "event_id": ev_id,
+                        "room_id": room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                    }
+                    for key, ev_id in iteritems(to_insert)
+                ],
+            )
+
+            txn.call_after(
+                self._curr_state_delta_stream_cache.entity_has_changed,
+                room_id, max_stream_order,
+            )
+
+            # Invalidate the various caches
+
+            # Figure out the changes of membership to invalidate the
+            # `get_rooms_for_user` cache.
+            # We find out which membership events we may have deleted
+            # and which we have added, then we invalidate the caches for all
+            # those users.
+            members_changed = set(
+                state_key
+                for ev_type, state_key in itertools.chain(to_delete, to_insert)
+                if ev_type == EventTypes.Member
+            )
 
+            for member in members_changed:
                 self._invalidate_cache_and_stream(
-                    txn, self.get_users_in_room, (room_id,)
+                    txn, self.get_rooms_for_user_with_stream_ordering, (member,)
                 )
 
+            for host in set(get_domain_from_id(u) for u in members_changed):
                 self._invalidate_cache_and_stream(
-                    txn, self.get_room_summary, (room_id,)
+                    txn, self.is_host_joined, (room_id, host)
                 )
-
                 self._invalidate_cache_and_stream(
-                    txn, self.get_current_state_ids, (room_id,)
+                    txn, self.was_host_joined, (room_id, host)
                 )
 
+            self._invalidate_cache_and_stream(
+                txn, self.get_users_in_room, (room_id,)
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_room_summary, (room_id,)
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_current_state_ids, (room_id,)
+            )
+
     def _update_forward_extremities_txn(self, txn, new_forward_extremities,
                                         max_stream_order):
         for room_id, new_extrem in iteritems(new_forward_extremities):
@@ -1268,6 +1268,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                         event.internal_metadata.get_dict()
                     ),
                     "json": encode_json(event_dict(event)),
+                    "format_version": event.format_version,
                 }
                 for event, _ in events_and_contexts
             ],
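The INSERT ... SELECT above records each delta's prev_event_id via a correlated subquery before current_state_events is touched; a minimal sqlite3 sketch of the same trick, with made-up IDs:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE current_state_events (room_id, type, state_key, event_id)")
conn.execute(
    "CREATE TABLE current_state_delta_stream"
    " (stream_id, room_id, type, state_key, event_id, prev_event_id)"
)
conn.execute("INSERT INTO current_state_events VALUES ('!r:hs', 'm.room.name', '', '$old')")

# Record the delta first: the subselect pulls out the event we are about to
# replace, so we never have to read the existing state separately.
conn.execute(
    """
    INSERT INTO current_state_delta_stream
    (stream_id, room_id, type, state_key, event_id, prev_event_id)
    SELECT ?, ?, ?, ?, ?, (
        SELECT event_id FROM current_state_events
        WHERE room_id = ? AND type = ? AND state_key = ?
    )
    """,
    (1, "!r:hs", "m.room.name", "", "$new", "!r:hs", "m.room.name", ""),
)

print(conn.execute(
    "SELECT event_id, prev_event_id FROM current_state_delta_stream"
).fetchall())
# [('$new', '$old')]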
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index a8326f5296..1716be529a 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -21,13 +21,14 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
+from synapse.api.constants import EventFormatVersions, EventTypes
 from synapse.api.errors import NotFoundError
+from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
 # these are only included to make the type annotations work
-from synapse.events import EventBase  # noqa: F401
-from synapse.events import FrozenEvent
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.events.utils import prune_event
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import get_domain_from_id
 from synapse.util.logcontext import (
     LoggingContext,
     PreserveLoggingContext,
@@ -160,9 +161,14 @@ class EventsWorkerStore(SQLBaseStore):
             log_ctx = LoggingContext.current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
+            # Note that _enqueue_events is also responsible for turning db rows
+            # into FrozenEvents (via _get_event_from_row), which involves seeing if
+            # the events have been redacted, and if so pulling the redaction event out
+            # of the database to check it.
+            #
+            # _enqueue_events is a bit of a rubbish name but naming is hard.
             missing_events = yield self._enqueue_events(
                 missing_events_ids,
-                check_redacted=check_redacted,
                 allow_rejected=allow_rejected,
             )
 
@@ -174,6 +180,50 @@ class EventsWorkerStore(SQLBaseStore):
             if not entry:
                 continue
 
+            # Starting in room version v3, some redactions need to be rechecked if we
+            # didn't have the redacted event at the time, so we recheck on read
+            # instead.
+            if not allow_rejected and entry.event.type == EventTypes.Redaction:
+                if entry.event.internal_metadata.need_to_check_redaction():
+                    # XXX: we need to avoid calling get_event here.
+                    #
+                    # The problem is that we end up at this point when an event
+                    # which has been redacted is pulled out of the database by
+                    # _enqueue_events, because _enqueue_events needs to check the
+                    # redaction before it can cache the redacted event. So obviously,
+                    # calling get_event to get the redacted event out of the database
+                    # gives us an infinite loop.
+                    #
+                    # For now (quick hack to fix during 0.99 release cycle), we just
+                    # go and fetch the relevant row from the db, but it would be nice
+                    # to think about how we can cache this rather than hit the db
+                    # every time we access a redaction event.
+                    #
+                    # One thought on how to do this:
+                    #  1. split _get_events up so that it is divided into (a) get the
+                    #     rawish event from the db/cache, (b) do the redaction/rejection
+                    #     filtering
+                    #  2. have _get_event_from_row just call the first half of that
+
+                    orig_sender = yield self._simple_select_one_onecol(
+                        table="events",
+                        keyvalues={"event_id": entry.event.redacts},
+                        retcol="sender",
+                        allow_none=True,
+                    )
+
+                    expected_domain = get_domain_from_id(entry.event.sender)
+                    if orig_sender and get_domain_from_id(orig_sender) == expected_domain:
+                        # This redaction event is allowed. Mark as not needing a
+                        # recheck.
+                        entry.event.internal_metadata.recheck_redaction = False
+                    else:
+                        # We don't have the event that is being redacted, so we
+                        # assume that the event isn't authorized for now. (If we
+                        # later receive the event, then we will always redact
+                        # it anyway, since we have this redaction)
+                        continue
+
             if allow_rejected or not entry.event.rejected_reason:
                 if check_redacted and entry.redacted_event:
                     event = entry.redacted_event
@@ -197,7 +247,7 @@ class EventsWorkerStore(SQLBaseStore):
         defer.returnValue(events)
 
     def _invalidate_get_event_cache(self, event_id):
-            self._get_event_cache.invalidate((event_id,))
+        self._get_event_cache.invalidate((event_id,))
 
     def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
         """Fetch events from the caches
@@ -310,7 +360,7 @@ class EventsWorkerStore(SQLBaseStore):
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
     @defer.inlineCallbacks
-    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+    def _enqueue_events(self, events, allow_rejected=False):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -353,6 +403,7 @@ class EventsWorkerStore(SQLBaseStore):
                     self._get_event_from_row,
                     row["internal_metadata"], row["json"], row["redacts"],
                     rejected_reason=row["rejects"],
+                    format_version=row["format_version"],
                 )
                 for row in rows
             ],
@@ -377,6 +428,7 @@ class EventsWorkerStore(SQLBaseStore):
                 " e.event_id as event_id, "
                 " e.internal_metadata,"
                 " e.json,"
+                " e.format_version, "
                 " r.redacts as redacts,"
                 " rej.event_id as rejects "
                 " FROM event_json as e"
@@ -392,7 +444,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def _get_event_from_row(self, internal_metadata, js, redacted,
-                            rejected_reason=None):
+                            format_version, rejected_reason=None):
         with Measure(self._clock, "_get_event_from_row"):
             d = json.loads(js)
             internal_metadata = json.loads(internal_metadata)
@@ -405,8 +457,13 @@ class EventsWorkerStore(SQLBaseStore):
                     desc="_get_event_from_row_rejected_reason",
                 )
 
-            original_ev = FrozenEvent(
-                d,
+            if format_version is None:
+                # This means that we stored the event before we had the concept
+                # of an event format version, so it must be a V1 event.
+                format_version = EventFormatVersions.V1
+
+            original_ev = event_type_from_format_version(format_version)(
+                event_dict=d,
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
@@ -436,6 +493,19 @@ class EventsWorkerStore(SQLBaseStore):
                     # will serialise this field correctly
                     redacted_event.unsigned["redacted_because"] = because
 
+                    # Starting in room version v3, some redactions need to be
+                    # rechecked if we didn't have the redacted event at the
+                    # time, so we recheck on read instead.
+                    if because.internal_metadata.need_to_check_redaction():
+                        expected_domain = get_domain_from_id(original_ev.sender)
+                        if get_domain_from_id(because.sender) == expected_domain:
+                            # This redaction event is allowed. Mark as not needing a
+                            # recheck.
+                            because.internal_metadata.recheck_redaction = False
+                        else:
+                            # Senders don't match, so the event isn't actually redacted
+                            redacted_event = None
+
             cache_entry = _EventCacheEntry(
                 event=original_ev,
                 redacted_event=redacted_event,
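The redaction recheck above reduces to a sender-domain comparison; a simplified standalone sketch (get_domain_from_id is reimplemented here purely for illustration):

def get_domain_from_id(user_id):
    # "@alice:example.com" -> "example.com"
    return user_id.split(":", 1)[1]

def redaction_allowed(redaction_sender, redacted_event_sender):
    # Room-version-3 style recheck: only honour the redaction if it came from
    # the same server as the event it claims to redact.
    if redacted_event_sender is None:
        # We don't have the redacted event (yet), so assume not authorized.
        return False
    return get_domain_from_id(redaction_sender) == get_domain_from_id(redacted_event_sender)

print(redaction_allowed("@mod:example.com", "@alice:example.com"))   # True
print(redaction_allowed("@mod:evil.example", "@alice:example.com"))  # False
print(redaction_allowed("@mod:example.com", None))                   # False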
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index d6fc8edd4c..9e7e09b8c1 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -197,15 +197,21 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         if is_support:
             return
 
-        is_insert = yield self.runInteraction(
+        yield self.runInteraction(
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
             user_id
         )
 
-        if is_insert:
-            self.user_last_seen_monthly_active.invalidate((user_id,))
+        user_in_mau = self.user_last_seen_monthly_active.cache.get(
+            (user_id,),
+            None,
+            update_metrics=False
+        )
+        if user_in_mau is None:
             self.get_monthly_active_count.invalidate(())
 
+        self.user_last_seen_monthly_active.invalidate((user_id,))
+
     def upsert_monthly_active_user_txn(self, txn, user_id):
         """Updates or inserts monthly active user member
 
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 2743b52bad..134297e284 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
         with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so _simple_upsert will retry
-            newly_inserted = yield self._simple_upsert(
+            yield self._simple_upsert(
                 table="pushers",
                 keyvalues={
                     "app_id": app_id,
@@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
                 lock=False,
             )
 
-            if newly_inserted:
+            user_has_pusher = self.get_if_user_has_pusher.cache.get(
+                (user_id,), None, update_metrics=False
+            )
+
+            if user_has_pusher is not True:
+                # invalidate, since the user might not have had a pusher before
                 yield self.runInteraction(
                     "add_pusher",
                     self._invalidate_cache_and_stream,
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0707f9a86a..592c1bcd33 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -588,12 +588,12 @@ class RoomMemberStore(RoomMemberWorkerStore):
             )
 
             # We update the local_invites table only if the event is "current",
-            # i.e., its something that has just happened.
-            # The only current event that can also be an outlier is if its an
-            # invite that has come in across federation.
+            # i.e., it's something that has just happened. If the event is an
+            # outlier it is only current if it's an "out of band membership",
+            # like a remote invite or a rejection of a remote invite.
             is_new_state = not backfilled and (
                 not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_invite_from_remote()
+                or event.internal_metadata.is_out_of_band_membership()
             )
             is_mine = self.hs.is_mine_id(event.state_key)
             if is_new_state and is_mine:
diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/schema/delta/53/event_format_version.sql
new file mode 100644
index 0000000000..1d977c2834
--- /dev/null
+++ b/synapse/storage/schema/delta/53/event_format_version.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_json ADD COLUMN format_version INTEGER;
diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/schema/delta/53/user_ips_index.sql
index 4ca346c111..b812c5794f 100644
--- a/synapse/storage/schema/delta/53/user_ips_index.sql
+++ b/synapse/storage/schema/delta/53/user_ips_index.sql
@@ -13,9 +13,13 @@
  * limitations under the License.
  */
 
--- delete duplicates
+ -- analyze user_ips, to help ensure the correct indices are used
 INSERT INTO background_updates (update_name, progress_json) VALUES
-  ('user_ips_remove_dupes', '{}');
+  ('user_ips_analyze', '{}');
+
+-- delete duplicates
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+  ('user_ips_remove_dupes', '{}', 'user_ips_analyze');
 
 -- add a new unique index to user_ips table
 INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
@@ -23,4 +27,4 @@ INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
 
 -- drop the old original index
 INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
-  ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index');
\ No newline at end of file
+  ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index a134e9b3e8..d14a7b2538 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -428,14 +428,54 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         # for now we do this by looking at the create event. We may want to cache this
         # more intelligently in future.
+
+        # Retrieve the room's create event
+        create_event = yield self.get_create_event_for_room(room_id)
+        defer.returnValue(create_event.content.get("room_version", "1"))
+
+    @defer.inlineCallbacks
+    def get_room_predecessor(self, room_id):
+        """Get the predecessor room of an upgraded room if one exists.
+        Otherwise return None.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[dict|None]: the "predecessor" field from the room's create
+                event content (a dict of room ID and event ID), or None
+
+        Raises:
+            NotFoundError if the room is unknown
+        """
+        # Retrieve the room's create event
+        create_event = yield self.get_create_event_for_room(room_id)
+
+        # Return predecessor if present
+        defer.returnValue(create_event.content.get("predecessor", None))
+
+    @defer.inlineCallbacks
+    def get_create_event_for_room(self, room_id):
+        """Get the create state event for a room.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[EventBase]: The room creation event.
+
+        Raises:
+            NotFoundError if the room is unknown
+        """
         state_ids = yield self.get_current_state_ids(room_id)
         create_id = state_ids.get((EventTypes.Create, ""))
 
+        # If we can't find the create event, assume we've hit a dead end
         if not create_id:
             raise NotFoundError("Unknown room %s" % (room_id))
 
+        # Retrieve the room's create event and return
         create_event = yield self.get_event(create_id)
-        defer.returnValue(create_event.content.get("room_version", "1"))
+        defer.returnValue(create_event)
 
     @cached(max_entries=100000, iterable=True)
     def get_current_state_ids(self, room_id):
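Both new helpers read fields straight out of the m.room.create event's content; a quick sketch with an illustrative content dict (field values are made up):

# Illustrative m.room.create content for a room that was upgraded from another.
create_content = {
    "creator": "@alice:example.com",
    "room_version": "3",
    "predecessor": {
        "room_id": "!old_room:example.com",
        "event_id": "$last_known_event_in_old_room",
    },
}

room_version = create_content.get("room_version", "1")  # "1" if the field is absent
predecessor = create_content.get("predecessor", None)   # None for non-upgraded rooms

print(room_version)
print(predecessor["room_id"] if predecessor else None)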
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index a8781b0e5d..fea866c043 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.state import StateFilter
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -31,12 +32,19 @@ logger = logging.getLogger(__name__)
 
 
 class UserDirectoryStore(SQLBaseStore):
-    @cachedInlineCallbacks(cache_context=True)
-    def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+    @defer.inlineCallbacks
+    def is_room_world_readable_or_publicly_joinable(self, room_id):
         """Check if the room is either world_readable or publically joinable
         """
-        current_state_ids = yield self.get_current_state_ids(
-            room_id, on_invalidate=cache_context.invalidate
+
+        # Create a state filter that only queries join and history state events
+        types_to_filter = (
+            (EventTypes.JoinRules, ""),
+            (EventTypes.RoomHistoryVisibility, ""),
+        )
+
+        current_state_ids = yield self.get_filtered_current_state_ids(
+            room_id, StateFilter.from_types(types_to_filter)
         )
 
         join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
@@ -66,14 +74,8 @@ class UserDirectoryStore(SQLBaseStore):
         """
         yield self._simple_insert_many(
             table="users_in_public_rooms",
-            values=[
-                {
-                    "user_id": user_id,
-                    "room_id": room_id,
-                }
-                for user_id in user_ids
-            ],
-            desc="add_users_to_public_room"
+            values=[{"user_id": user_id, "room_id": room_id} for user_id in user_ids],
+            desc="add_users_to_public_room",
         )
         for user_id in user_ids:
             self.get_user_in_public_room.invalidate((user_id,))
@@ -99,7 +101,9 @@ class UserDirectoryStore(SQLBaseStore):
             """
             args = (
                 (
-                    user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+                    user_id,
+                    get_localpart_from_id(user_id),
+                    get_domain_from_id(user_id),
                     profile.display_name,
                 )
                 for user_id, profile in iteritems(users_with_profile)
@@ -112,7 +116,7 @@ class UserDirectoryStore(SQLBaseStore):
             args = (
                 (
                     user_id,
-                    "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+                    "%s %s" % (user_id, p.display_name) if p.display_name else user_id,
                 )
                 for user_id, p in iteritems(users_with_profile)
             )
@@ -133,12 +137,10 @@ class UserDirectoryStore(SQLBaseStore):
                         "avatar_url": profile.avatar_url,
                     }
                     for user_id, profile in iteritems(users_with_profile)
-                ]
+                ],
             )
             for user_id in users_with_profile:
-                txn.call_after(
-                    self.get_user_in_directory.invalidate, (user_id,)
-                )
+                txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
         return self.runInteraction(
             "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
@@ -168,39 +170,69 @@ class UserDirectoryStore(SQLBaseStore):
             if isinstance(self.database_engine, PostgresEngine):
                 # We weight the localpart most highly, then display name and finally
                 # server name
-                if new_entry:
+                if self.database_engine.can_native_upsert:
                     sql = """
                         INSERT INTO user_directory_search(user_id, vector)
                         VALUES (?,
                             setweight(to_tsvector('english', ?), 'A')
                             || setweight(to_tsvector('english', ?), 'D')
                             || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
-                        )
+                        ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
                     """
                     txn.execute(
                         sql,
                         (
-                            user_id, get_localpart_from_id(user_id),
-                            get_domain_from_id(user_id), display_name,
-                        )
+                            user_id,
+                            get_localpart_from_id(user_id),
+                            get_domain_from_id(user_id),
+                            display_name,
+                        ),
                     )
                 else:
-                    sql = """
-                        UPDATE user_directory_search
-                        SET vector = setweight(to_tsvector('english', ?), 'A')
-                            || setweight(to_tsvector('english', ?), 'D')
-                            || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
-                        WHERE user_id = ?
-                    """
-                    txn.execute(
-                        sql,
-                        (
-                            get_localpart_from_id(user_id), get_domain_from_id(user_id),
-                            display_name, user_id,
+                    # TODO: Remove this code after we've bumped the minimum version
+                    # of postgres to always support upserts, so we can get rid of
+                    # `new_entry` usage
+                    if new_entry is True:
+                        sql = """
+                            INSERT INTO user_directory_search(user_id, vector)
+                            VALUES (?,
+                                setweight(to_tsvector('english', ?), 'A')
+                                || setweight(to_tsvector('english', ?), 'D')
+                                || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                            )
+                        """
+                        txn.execute(
+                            sql,
+                            (
+                                user_id,
+                                get_localpart_from_id(user_id),
+                                get_domain_from_id(user_id),
+                                display_name,
+                            ),
+                        )
+                    elif new_entry is False:
+                        sql = """
+                            UPDATE user_directory_search
+                            SET vector = setweight(to_tsvector('english', ?), 'A')
+                                || setweight(to_tsvector('english', ?), 'D')
+                                || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                            WHERE user_id = ?
+                        """
+                        txn.execute(
+                            sql,
+                            (
+                                get_localpart_from_id(user_id),
+                                get_domain_from_id(user_id),
+                                display_name,
+                                user_id,
+                            ),
+                        )
+                    else:
+                        raise RuntimeError(
+                            "upsert returned None when 'can_native_upsert' is False"
                         )
-                    )
             elif isinstance(self.database_engine, Sqlite3Engine):
-                value = "%s %s" % (user_id, display_name,) if display_name else user_id
+                value = "%s %s" % (user_id, display_name) if display_name else user_id
                 self._simple_upsert_txn(
                     txn,
                     table="user_directory_search",
@@ -231,29 +263,18 @@ class UserDirectoryStore(SQLBaseStore):
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
             self._simple_delete_txn(
-                txn,
-                table="user_directory",
-                keyvalues={"user_id": user_id},
+                txn, table="user_directory", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn,
-                table="user_directory_search",
-                keyvalues={"user_id": user_id},
+                txn, table="user_directory_search", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn,
-                table="users_in_public_rooms",
-                keyvalues={"user_id": user_id},
-            )
-            txn.call_after(
-                self.get_user_in_directory.invalidate, (user_id,)
+                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
             )
-            txn.call_after(
-                self.get_user_in_public_room.invalidate, (user_id,)
-            )
-        return self.runInteraction(
-            "remove_from_user_dir", _remove_from_user_dir_txn,
-        )
+            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+            txn.call_after(self.get_user_in_public_room.invalidate, (user_id,))
+
+        return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
 
     @defer.inlineCallbacks
     def remove_from_user_in_public_room(self, user_id):
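
In `_remove_from_user_dir_txn` above, the `txn.call_after(...invalidate, ...)` calls queue cache invalidations that only run once the transaction commits, so a rolled-back delete leaves the `get_user_in_directory` and `get_user_in_public_room` caches intact. A hypothetical call site, purely for illustration:

    # Hypothetical handler-side usage (not part of this change): once the
    # deferred resolves, the caches have been invalidated and a fresh lookup
    # goes back to the database.
    yield self.store.remove_from_user_dir(user_id)
    entry = yield self.store.get_user_in_directory(user_id)
    assert entry is None  # the directory row is gone
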
@@ -338,6 +359,7 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private?
             user_id_tuples([(str, str)]): iterable of 2-tuples of user IDs.
         """
+
         def _add_users_who_share_room_txn(txn):
             self._simple_insert_many_txn(
                 txn,
@@ -354,13 +376,12 @@ class UserDirectoryStore(SQLBaseStore):
             )
             for user_id, other_user_id in user_id_tuples:
                 txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate,
-                    (user_id,),
+                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
                 )
                 txn.call_after(
-                    self.get_if_users_share_a_room.invalidate,
-                    (user_id, other_user_id),
+                    self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
                 )
+
         return self.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
@@ -374,6 +395,7 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private?
             user_id_tuples([(str, str)]): iterable of 2-tuples of user IDs.
         """
+
         def _update_users_who_share_room_txn(txn):
             sql = """
                 UPDATE users_who_share_rooms
@@ -381,21 +403,16 @@ class UserDirectoryStore(SQLBaseStore):
                 WHERE user_id = ? AND other_user_id = ?
             """
             txn.executemany(
-                sql,
-                (
-                    (room_id, share_private, uid, oid)
-                    for uid, oid in user_id_sets
-                )
+                sql, ((room_id, share_private, uid, oid) for uid, oid in user_id_sets)
             )
             for user_id, other_user_id in user_id_sets:
                 txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate,
-                    (user_id,),
+                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
                 )
                 txn.call_after(
-                    self.get_if_users_share_a_room.invalidate,
-                    (user_id, other_user_id),
+                    self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
                 )
+
         return self.runInteraction(
             "update_users_who_share_room", _update_users_who_share_room_txn
         )
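
`txn.executemany` in `_update_users_who_share_room_txn` runs the UPDATE once per element of the generator expression, so each `(user_id, other_user_id)` pair fills the trailing two placeholders while `room_id` and `share_private` repeat for every row. A hypothetical call, with the argument names inferred from the transaction body above:

    # Illustrative only: mark two existing sharing pairs as private.
    yield self.store.update_users_who_share_room(
        room_id="!room:example.com",
        share_private=True,
        user_id_sets=[
            ("@alice:example.com", "@bob:example.com"),
            ("@alice:example.com", "@carol:example.com"),
        ],
    )
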
@@ -409,22 +426,18 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private?
             user_id_tuples([(str, str)]): iterable of 2-tuples of user IDs.
         """
+
         def _remove_user_who_share_room_txn(txn):
             self._simple_delete_txn(
                 txn,
                 table="users_who_share_rooms",
-                keyvalues={
-                    "user_id": user_id,
-                    "other_user_id": other_user_id,
-                },
+                keyvalues={"user_id": user_id, "other_user_id": other_user_id},
             )
             txn.call_after(
-                self.get_users_who_share_room_from_dir.invalidate,
-                (user_id,),
+                self.get_users_who_share_room_from_dir.invalidate, (user_id,)
             )
             txn.call_after(
-                self.get_if_users_share_a_room.invalidate,
-                (user_id, other_user_id),
+                self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
             )
 
         return self.runInteraction(
@@ -445,10 +458,7 @@ class UserDirectoryStore(SQLBaseStore):
         """
         return self._simple_select_one_onecol(
             table="users_who_share_rooms",
-            keyvalues={
-                "user_id": user_id,
-                "other_user_id": other_user_id,
-            },
+            keyvalues={"user_id": user_id, "other_user_id": other_user_id},
             retcol="share_private",
             allow_none=True,
             desc="get_if_users_share_a_room",
@@ -466,17 +476,12 @@ class UserDirectoryStore(SQLBaseStore):
         """
         rows = yield self._simple_select_list(
             table="users_who_share_rooms",
-            keyvalues={
-                "user_id": user_id,
-            },
-            retcols=("other_user_id", "share_private",),
+            keyvalues={"user_id": user_id},
+            retcols=("other_user_id", "share_private"),
             desc="get_users_who_share_room_with_user",
         )
 
-        defer.returnValue({
-            row["other_user_id"]: row["share_private"]
-            for row in rows
-        })
+        defer.returnValue({row["other_user_id"]: row["share_private"] for row in rows})
 
     def get_users_in_share_dir_with_room_id(self, user_id, room_id):
         """Get all user tuples that are in the users_who_share_rooms due to the
@@ -523,6 +528,7 @@ class UserDirectoryStore(SQLBaseStore):
     def delete_all_from_user_dir(self):
         """Delete the entire user directory
         """
+
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
@@ -532,6 +538,7 @@ class UserDirectoryStore(SQLBaseStore):
             txn.call_after(self.get_user_in_public_room.invalidate_all)
             txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
             txn.call_after(self.get_if_users_share_a_room.invalidate_all)
+
         return self.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
@@ -541,7 +548,7 @@ class UserDirectoryStore(SQLBaseStore):
         return self._simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
-            retcols=("room_id", "display_name", "avatar_url",),
+            retcols=("room_id", "display_name", "avatar_url"),
             allow_none=True,
             desc="get_user_in_directory",
         )
@@ -574,7 +581,9 @@ class UserDirectoryStore(SQLBaseStore):
 
     def get_current_state_deltas(self, prev_stream_id):
         prev_stream_id = int(prev_stream_id)
-        if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+        if not self._curr_state_delta_stream_cache.has_any_entity_changed(
+            prev_stream_id
+        ):
             return []
 
         def get_current_state_deltas_txn(txn):
@@ -608,7 +617,7 @@ class UserDirectoryStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id,))
+            txn.execute(sql, (prev_stream_id, max_stream_id))
             return self.cursor_to_dict(txn)
 
         return self.runInteraction(
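
`get_current_state_deltas` first consults the in-memory `_curr_state_delta_stream_cache`; if nothing can have changed since `prev_stream_id`, it returns `[]` without touching the database, and only otherwise runs the bounded, ordered delta query. A hypothetical polling loop on the caller's side (assuming, as the ORDER BY suggests, that `stream_id` is among the returned columns):

    # Illustrative only: advance a high-water mark from the returned deltas.
    deltas = yield self.store.get_current_state_deltas(last_stream_id)
    for delta in deltas:
        process(delta)                      # hypothetical per-delta handler
        last_stream_id = delta["stream_id"]
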
@@ -698,8 +707,11 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (join_clause, where_clause)
-            args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
+            """ % (
+                join_clause,
+                where_clause,
+            )
+            args = join_args + (full_query, exact_query, prefix_query, limit + 1)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)
 
@@ -716,7 +728,10 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (join_clause, where_clause)
+            """ % (
+                join_clause,
+                where_clause,
+            )
             args = join_args + (search_query, limit + 1)
         else:
             # This should be unreachable.
@@ -728,10 +743,7 @@ class UserDirectoryStore(SQLBaseStore):
 
         limited = len(results) > limit
 
-        defer.returnValue({
-            "limited": limited,
-            "results": results,
-        })
+        defer.returnValue({"limited": limited, "results": results})
 
 
 def _parse_query_sqlite(search_term):
@@ -746,7 +758,7 @@ def _parse_query_sqlite(search_term):
 
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
-    return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
+    return " & ".join("(%s* OR %s)" % (result, result) for result in results)
 
 
 def _parse_query_postgres(search_term):
@@ -759,7 +771,7 @@ def _parse_query_postgres(search_term):
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
 
-    both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+    both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
     exact = " & ".join("%s" % (result,) for result in results)
     prefix = " & ".join("%s:*" % (result,) for result in results)