summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/__init__.py5
-rw-r--r--synapse/storage/data_stores/main/cache.py44
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py88
-rw-r--r--synapse/storage/data_stores/main/devices.py211
-rw-r--r--synapse/storage/data_stores/main/directory.py26
-rw-r--r--synapse/storage/data_stores/main/e2e_room_keys.py3
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py14
-rw-r--r--synapse/storage/data_stores/main/events.py114
-rw-r--r--synapse/storage/data_stores/main/events_worker.py118
-rw-r--r--synapse/storage/data_stores/main/media_repository.py4
-rw-r--r--synapse/storage/data_stores/main/presence.py23
-rw-r--r--synapse/storage/data_stores/main/push_rule.py1
-rw-r--r--synapse/storage/data_stores/main/room.py40
-rw-r--r--synapse/storage/database.py11
14 files changed, 396 insertions, 306 deletions
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index acca079f23..649e835303 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -144,7 +144,10 @@ class DataStore(
             db_conn,
             "device_lists_stream",
             "stream_id",
-            extra_tables=[("user_signature_stream", "stream_id")],
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         self._cross_signing_id_gen = StreamIdGenerator(
             db_conn, "e2e_cross_signing_keys", "stream_id"
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d4c44dcc75..4dc5da3fe8 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
 CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+    def get_all_updated_caches(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_caches_txn(txn):
+            # We purposefully don't bound by the current token, as we want to
+            # send across cache invalidations as quickly as possible. Cache
+            # invalidations are idempotent, so duplicates are fine.
+            sql = (
+                "SELECT stream_id, cache_func, keys, invalidation_ts"
+                " FROM cache_invalidation_stream"
+                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_updated_caches", get_all_updated_caches_txn
+        )
+
+
+class CacheInvalidationStore(CacheInvalidationWorkerStore):
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
                 },
             )
 
-    def get_all_updated_caches(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_caches_txn(txn):
-            # We purposefully don't bound by the current token, as we want to
-            # send across cache invalidations as quickly as possible. Cache
-            # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_updated_caches", get_all_updated_caches_txn
-        )
-
     def get_cache_stream_token(self):
         if self._cache_id_gen:
             return self._cache_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 0613b49f4a..9a1178fb39 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
+    def get_all_new_device_messages(self, last_pos, current_pos, limit):
+        """
+        Args:
+            last_pos(int):
+            current_pos(int):
+            limit(int):
+        Returns:
+            A deferred list of rows from the device inbox
+        """
+        if last_pos == current_pos:
+            return defer.succeed([])
+
+        def get_all_new_device_messages_txn(txn):
+            # We limit like this as we might have multiple rows per stream_id, and
+            # we want to make sure we always get all entries for any stream_id
+            # we return.
+            upper_pos = min(current_pos, last_pos + limit)
+            sql = (
+                "SELECT max(stream_id), user_id"
+                " FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY user_id"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows = txn.fetchall()
+
+            sql = (
+                "SELECT max(stream_id), destination"
+                " FROM device_federation_outbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY destination"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows.extend(txn)
+
+            # Order by ascending stream ordering
+            rows.sort()
+
+            return rows
+
+        return self.db.runInteraction(
+            "get_all_new_device_messages", get_all_new_device_messages_txn
+        )
+
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((user_id, device_id, stream_id, message_json))
 
         txn.executemany(sql, rows)
-
-    def get_all_new_device_messages(self, last_pos, current_pos, limit):
-        """
-        Args:
-            last_pos(int):
-            current_pos(int):
-            limit(int):
-        Returns:
-            A deferred list of rows from the device inbox
-        """
-        if last_pos == current_pos:
-            return defer.succeed([])
-
-        def get_all_new_device_messages_txn(txn):
-            # We limit like this as we might have multiple rows per stream_id, and
-            # we want to make sure we always get all entries for any stream_id
-            # we return.
-            upper_pos = min(current_pos, last_pos + limit)
-            sql = (
-                "SELECT max(stream_id), user_id"
-                " FROM device_inbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY user_id"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows = txn.fetchall()
-
-            sql = (
-                "SELECT max(stream_id), destination"
-                " FROM device_federation_outbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY destination"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn)
-
-            # Order by ascending stream ordering
-            rows.sort()
-
-            return rows
-
-        return self.db.runInteraction(
-            "get_all_new_device_messages", get_all_new_device_messages_txn
-        )
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 8af5f7de54..20995e1b78 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Tuple
 
 from six import iteritems
 
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
 from synapse.types import Collection, get_verify_key_from_cross_signing_key
 from synapse.util.caches.descriptors import (
     Cache,
@@ -40,6 +41,7 @@ from synapse.util.caches.descriptors import (
     cachedList,
 )
 from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
 
 logger = logging.getLogger(__name__)
 
@@ -112,23 +114,13 @@ class DeviceWorkerStore(SQLBaseStore):
         if not has_changed:
             return now_stream_id, []
 
-        # We retrieve n+1 devices from the list of outbound pokes where n is
-        # our outbound device update limit. We then check if the very last
-        # device has the same stream_id as the second-to-last device. If so,
-        # then we ignore all devices with that stream_id and only send the
-        # devices with a lower stream_id.
-        #
-        # If when culling the list we end up with no devices afterwards, we
-        # consider the device update to be too large, and simply skip the
-        # stream_id; the rationale being that such a large device list update
-        # is likely an error.
         updates = yield self.db.runInteraction(
             "get_device_updates_by_remote",
             self._get_device_updates_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
-            limit + 1,
+            limit,
         )
 
         # Return an empty list if there are no updates
@@ -166,14 +158,6 @@ class DeviceWorkerStore(SQLBaseStore):
                     "device_id": verify_key.version,
                 }
 
-        # if we have exceeded the limit, we need to exclude any results with the
-        # same stream_id as the last row.
-        if len(updates) > limit:
-            stream_id_cutoff = updates[-1][2]
-            now_stream_id = stream_id_cutoff - 1
-        else:
-            stream_id_cutoff = None
-
         # Perform the equivalent of a GROUP BY
         #
         # Iterate through the updates list and copy non-duplicate
@@ -192,10 +176,6 @@ class DeviceWorkerStore(SQLBaseStore):
         query_map = {}
         cross_signing_keys_by_user = {}
         for user_id, device_id, update_stream_id, update_context in updates:
-            if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
-                # Stop processing updates
-                break
-
             if (
                 user_id in master_key_by_user
                 and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +198,6 @@ class DeviceWorkerStore(SQLBaseStore):
                 if update_stream_id > previous_update_stream_id:
                     query_map[key] = (update_stream_id, update_context)
 
-        # If we didn't find any updates with a stream_id lower than the cutoff, it
-        # means that there are more than limit updates all of which have the same
-        # steam_id.
-
-        # That should only happen if a client is spamming the server with new
-        # devices, in which case E2E isn't going to work well anyway. We'll just
-        # skip that stream_id and return an empty list, and continue with the next
-        # stream_id next time.
-        if not query_map and not cross_signing_keys_by_user:
-            return stream_id_cutoff, []
-
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
@@ -611,22 +580,33 @@ class DeviceWorkerStore(SQLBaseStore):
         else:
             return set()
 
-    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
-        """Return a list of `(stream_id, user_id, destination)` which is the
-        combined list of changes to devices, and which destinations need to be
-        poked. `destination` may be None if no destinations need to be poked.
+    async def get_all_device_list_changes_for_remotes(
+        self, from_key: int, to_key: int, limit: int,
+    ) -> List[Tuple[int, str]]:
+        """Return a list of `(stream_id, entity)` which is the combined list of
+        changes to devices and which destinations need to be poked. Entity is
+        either a user ID (starting with '@') or a remote destination.
         """
-        # We do a group by here as there can be a large number of duplicate
-        # entries, since we throw away device IDs.
+
+        # This query Does The Right Thing where it'll correctly apply the
+        # bounds to the inner queries.
         sql = """
-            SELECT MAX(stream_id) AS stream_id, user_id, destination
-            FROM device_lists_stream
-            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            SELECT stream_id, entity FROM (
+                SELECT stream_id, user_id AS entity FROM device_lists_stream
+                UNION ALL
+                SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+            ) AS e
             WHERE ? < stream_id AND stream_id <= ?
-            GROUP BY user_id, destination
+            LIMIT ?
         """
-        return self.db.execute(
-            "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
+
+        return await self.db.execute(
+            "get_all_device_list_changes_for_remotes",
+            None,
+            sql,
+            from_key,
+            to_key,
+            limit,
         )
 
     @cached(max_entries=10000)
@@ -1021,29 +1001,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
-        with self._device_list_id_gen.get_next() as stream_id:
+        if not device_ids:
+            return
+
+        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+            yield self.db.runInteraction(
+                "add_device_change_to_stream",
+                self._add_device_change_to_stream_txn,
+                user_id,
+                device_ids,
+                stream_ids,
+            )
+
+        if not hosts:
+            return stream_ids[-1]
+
+        context = get_active_span_text_map()
+        with self._device_list_id_gen.get_next_mult(
+            len(hosts) * len(device_ids)
+        ) as stream_ids:
             yield self.db.runInteraction(
-                "add_device_change_to_streams",
-                self._add_device_change_txn,
+                "add_device_outbound_poke_to_stream",
+                self._add_device_outbound_poke_to_stream_txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_id,
+                stream_ids,
+                context,
             )
-        return stream_id
 
-    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
-        now = self._clock.time_msec()
+        return stream_ids[-1]
 
+    def _add_device_change_to_stream_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Collection[str],
+        stream_ids: List[str],
+    ):
         txn.call_after(
-            self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+            self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
         )
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_id,
-            )
+
+        min_stream_id = stream_ids[0]
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
@@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
             """,
-            [(user_id, device_id, stream_id) for device_id in device_ids],
+            [(user_id, device_id, min_stream_id) for device_id in device_ids],
         )
 
         self.db.simple_insert_many_txn(
@@ -1060,11 +1060,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             table="device_lists_stream",
             values=[
                 {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
-                for device_id in device_ids
+                for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )
 
-        context = get_active_span_text_map()
+    def _add_device_outbound_poke_to_stream_txn(
+        self, txn, user_id, device_ids, hosts, stream_ids, context,
+    ):
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host,
+                stream_ids[-1],
+            )
+
+        now = self._clock.time_msec()
+        next_stream_id = iter(stream_ids)
 
         self.db.simple_insert_many_txn(
             txn,
@@ -1072,7 +1083,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 {
                     "destination": destination,
-                    "stream_id": stream_id,
+                    "stream_id": next(next_stream_id),
                     "user_id": user_id,
                     "device_id": device_id,
                     "sent": False,
@@ -1086,18 +1097,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             ],
         )
 
-    def _prune_old_outbound_device_pokes(self):
+    def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
         """Delete old entries out of the device_lists_outbound_pokes to ensure
-        that we don't fill up due to dead servers. We keep one entry per
-        (destination, user_id) tuple to ensure that the prev_ids remain correct
-        if the server does come back.
+        that we don't fill up due to dead servers.
+
+        Normally, we try to send device updates as a delta since a previous known point:
+        this is done by setting the prev_id in the m.device_list_update EDU. However,
+        for that to work, we have to have a complete record of each change to
+        each device, which can add up to quite a lot of data.
+
+        An alternative mechanism is that, if the remote server sees that it has missed
+        an entry in the stream_id sequence for a given user, it will request a full
+        list of that user's devices. Hence, we can reduce the amount of data we have to
+        store (and transmit in some future transaction), by clearing almost everything
+        for a given destination out of the database, and having the remote server
+        resync.
+
+        All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send a m.device_list_update EDU for
+        that user when the destination comes back. It doesn't matter which device
+        we keep.
         """
-        yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+        yesterday = self._clock.time_msec() - prune_age
 
         def _prune_txn(txn):
+            # look for (user, destination) pairs which have an update older than
+            # the cutoff.
+            #
+            # For each pair, we also need to know the most recent stream_id, and
+            # an arbitrary device_id at that stream_id.
             select_sql = """
-                SELECT destination, user_id, max(stream_id) as stream_id
-                FROM device_lists_outbound_pokes
+            SELECT
+                dlop1.destination,
+                dlop1.user_id,
+                MAX(dlop1.stream_id) AS stream_id,
+                (SELECT MIN(dlop2.device_id) AS device_id FROM
+                    device_lists_outbound_pokes dlop2
+                    WHERE dlop2.destination = dlop1.destination AND
+                      dlop2.user_id=dlop1.user_id AND
+                      dlop2.stream_id=MAX(dlop1.stream_id)
+                )
+            FROM device_lists_outbound_pokes dlop1
                 GROUP BY destination, user_id
                 HAVING min(ts) < ? AND count(*) > 1
             """
@@ -1108,14 +1148,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             if not rows:
                 return
 
+            logger.info(
+                "Pruning old outbound device list updates for %i users/destinations: %s",
+                len(rows),
+                shortstr((row[0], row[1]) for row in rows),
+            )
+
+            # we want to keep the update with the highest stream_id for each user.
+            #
+            # there might be more than one update (with different device_ids) with the
+            # same stream_id, so we also delete all but one rows with the max stream id.
             delete_sql = """
                 DELETE FROM device_lists_outbound_pokes
-                WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+                WHERE destination = ? AND user_id = ? AND (
+                    stream_id < ? OR
+                    (stream_id = ? AND device_id != ?)
+                )
             """
-
-            txn.executemany(
-                delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
-            )
+            count = 0
+            for (destination, user_id, stream_id, device_id) in rows:
+                txn.execute(
+                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+                )
+                count += txn.rowcount
 
             # Since we've deleted unsent deltas, we need to remove the entry
             # of last successful sent so that the prev_ids are correctly set.
@@ -1125,7 +1180,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             """
             txn.executemany(sql, ((row[0], row[1]) for row in rows))
 
-            logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+            logger.info("Pruned %d device list outbound pokes", count)
 
         return run_as_background_process(
             "prune_old_outbound_device_pokes",
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py
index c9e7de7d12..e1d1bc3e05 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from collections import namedtuple
+from typing import Optional
 
 from twisted.internet import defer
 
@@ -159,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore):
 
         return room_id
 
-    def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+    def update_aliases_for_room(
+        self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+    ):
+        """Repoint all of the aliases for a given room, to a different room.
+
+        Args:
+            old_room_id:
+            new_room_id:
+            creator: The user to record as the creator of the new mapping.
+                If None, the creator will be left unchanged.
+        """
+
         def _update_aliases_for_room_txn(txn):
-            sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
-            txn.execute(sql, (new_room_id, creator, old_room_id))
+            update_creator_sql = ""
+            sql_params = (new_room_id, old_room_id)
+            if creator:
+                update_creator_sql = ", creator = ?"
+                sql_params = (new_room_id, creator, old_room_id)
+
+            sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (
+                update_creator_sql,
+            )
+            txn.execute(sql, sql_params)
             self._invalidate_cache_and_stream(
                 txn, self.get_aliases_for_room, (old_room_id,)
             )
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 84594cf0a9..23f4570c4b 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -146,7 +146,8 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             room_entry["sessions"][row["session_id"]] = {
                 "first_message_index": row["first_message_index"],
                 "forwarded_count": row["forwarded_count"],
-                "is_verified": row["is_verified"],
+                # is_verified must be returned to the client as a boolean
+                "is_verified": bool(row["is_verified"]),
                 "session_data": json.loads(row["session_data"]),
             }
 
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 001a53f9b4..bcf746b7ef 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         return result
 
-    def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+    def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
         """Return a list of changes from the user signature stream to notify remotes.
         Note that the user signature stream represents when a user signs their
         device with their user-signing key, which is not published to other
@@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
         """
         sql = """
-            SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+            SELECT stream_id, from_user_id AS user_id
             FROM user_signature_stream
             WHERE ? < stream_id AND stream_id <= ?
-            GROUP BY user_id
+            ORDER BY stream_id ASC
+            LIMIT ?
         """
         return self.db.execute(
-            "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+            "get_all_user_signature_changes_for_remotes",
+            None,
+            sql,
+            from_key,
+            to_key,
+            limit,
         )
 
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d593ef47b8..e71c23541d 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1267,104 +1267,6 @@ class EventsStore(
         ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
         return ret
 
-    def get_current_backfill_token(self):
-        """The current minimum token that backfilled events have reached"""
-        return -self._backfill_id_gen.get_current_token()
-
-    def get_current_events_token(self):
-        """The current maximum token that events have reached"""
-        return self._stream_id_gen.get_current_token()
-
-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_forward_event_rows(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < event_stream_ordering"
-                " AND event_stream_ordering <= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (last_id, upper_bound))
-            new_event_updates.extend(txn)
-
-            return new_event_updates
-
-        return self.db.runInteraction(
-            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
-        )
-
-    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_backfill_event_rows(txn):
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (-last_id, -current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > event_stream_ordering"
-                " AND event_stream_ordering >= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (-last_id, -upper_bound))
-            new_event_updates.extend(txn.fetchall())
-
-            return new_event_updates
-
-        return self.db.runInteraction(
-            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
-        )
-
     @cached(num_args=5, max_entries=10)
     def get_all_new_events(
         self,
@@ -1850,22 +1752,6 @@ class EventsStore(
 
         return (int(res["topological_ordering"]), int(res["stream_ordering"]))
 
-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
-        def get_all_updated_current_state_deltas_txn(txn):
-            sql = """
-                SELECT stream_id, room_id, type, state_key, event_id
-                FROM current_state_delta_stream
-                WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC LIMIT ?
-            """
-            txn.execute(sql, (from_token, to_token, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_updated_current_state_deltas",
-            get_all_updated_current_state_deltas_txn,
-        )
-
     def insert_labels_for_event_txn(
         self, txn, event_id, labels, room_id, topological_ordering
     ):
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ca237c6f12..16ea8948b1 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -35,7 +35,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import make_event_from_dict
 from synapse.events.utils import prune_event
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
@@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore):
         missing_events_ids = [e for e in event_ids if e not in event_entry_map]
 
         if missing_events_ids:
-            log_ctx = LoggingContext.current_context()
+            log_ctx = current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
             # Note that _get_events_from_db is also responsible for turning db rows
@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
         complexity_v1 = round(state_events / 500, 2)
 
         return {"v1": complexity_v1}
+
+    def get_current_backfill_token(self):
+        """The current minimum token that backfilled events have reached"""
+        return -self._backfill_id_gen.get_current_token()
+
+    def get_current_events_token(self):
+        """The current maximum token that events have reached"""
+        return self._stream_id_gen.get_current_token()
+
+    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_forward_event_rows(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < event_stream_ordering"
+                " AND event_stream_ordering <= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (last_id, upper_bound))
+            new_event_updates.extend(txn)
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+        )
+
+    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_backfill_event_rows(txn):
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (-last_id, -current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > event_stream_ordering"
+                " AND event_stream_ordering >= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (-last_id, -upper_bound))
+            new_event_updates.extend(txn.fetchall())
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+        )
+
+    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+        def get_all_updated_current_state_deltas_txn(txn):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_updated_current_state_deltas",
+            get_all_updated_current_state_deltas_txn,
+        )
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 80ca36dedf..cf195f8aa6 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -340,7 +340,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_expired_url_cache", _get_expired_url_cache_txn
         )
 
-    def delete_url_cache(self, media_ids):
+    async def delete_url_cache(self, media_ids):
         if len(media_ids) == 0:
             return
 
@@ -349,7 +349,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_txn(txn):
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-        return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
+        return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
 
     def get_url_cache_media_before(self, before_ts):
         sql = (
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index 604c8b7ddd..dab31e0c2d 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
                     "status_msg": state.status_msg,
                     "currently_active": state.currently_active,
                 }
-                for state in presence_states
+                for stream_id, state in zip(stream_orderings, presence_states)
             ],
         )
 
@@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
             )
             txn.execute(sql + clause, [stream_id] + list(args))
 
-    def get_all_presence_updates(self, last_id, current_id):
+    def get_all_presence_updates(self, last_id, current_id, limit):
         if last_id == current_id:
             return defer.succeed([])
 
         def get_all_presence_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, user_id, state, last_active_ts,"
-                " last_federation_update_ts, last_user_sync_ts, status_msg,"
-                " currently_active"
-                " FROM presence_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-            )
-            txn.execute(sql, (last_id, current_id))
+            sql = """
+                SELECT stream_id, user_id, state, last_active_ts,
+                    last_federation_update_ts, last_user_sync_ts,
+                    status_msg,
+                currently_active
+                FROM presence_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
         return self.db.runInteraction(
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 62ac88d9f2..46f9bda773 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -41,6 +41,7 @@ def _load_rules(rawrules, enabled_map):
         rule = dict(rawrule)
         rule["conditions"] = json.loads(rawrule["conditions"])
         rule["actions"] = json.loads(rawrule["actions"])
+        rule["default"] = False
         ruleslist.append(rule)
 
     # We're going to be mutating this a lot, so do a deep copy
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index e6c10c6316..aaebe427d3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
 
         return total_media_quarantined
 
+    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+        def get_all_new_public_rooms(txn):
+            sql = """
+                SELECT stream_id, room_id, visibility, appservice_id, network_id
+                FROM public_room_list_stream
+                WHERE stream_id > ? AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (prev_id, current_id, limit))
+            return txn.fetchall()
+
+        if prev_id == current_id:
+            return defer.succeed([])
+
+        return self.db.runInteraction(
+            "get_all_new_public_rooms", get_all_new_public_rooms
+        )
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def get_all_new_public_rooms(self, prev_id, current_id, limit):
-        def get_all_new_public_rooms(txn):
-            sql = """
-                SELECT stream_id, room_id, visibility, appservice_id, network_id
-                FROM public_room_list_stream
-                WHERE stream_id > ? AND stream_id <= ?
-                ORDER BY stream_id ASC
-                LIMIT ?
-            """
-
-            txn.execute(sql, (prev_id, current_id, limit))
-            return txn.fetchall()
-
-        if prev_id == current_id:
-            return defer.succeed([])
-
-        return self.db.runInteraction(
-            "get_all_new_public_rooms", get_all_new_public_rooms
-        )
-
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
         """Marks the room as blocked. Can be called multiple times.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e61595336c..715c0346dd 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import (
     LoggingContext,
     LoggingContextOrSentinel,
+    current_context,
     make_deferred_yieldable,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -483,7 +484,7 @@ class Database(object):
             end = monotonic_time()
             duration = end - start
 
-            LoggingContext.current_context().add_database_transaction(duration)
+            current_context().add_database_transaction(duration)
 
             transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
 
@@ -510,7 +511,7 @@ class Database(object):
         after_callbacks = []  # type: List[_CallbackListEntry]
         exception_callbacks = []  # type: List[_CallbackListEntry]
 
-        if LoggingContext.current_context() == LoggingContext.sentinel:
+        if not current_context():
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
@@ -547,10 +548,8 @@ class Database(object):
         Returns:
             Deferred: The result of func
         """
-        parent_context = (
-            LoggingContext.current_context()
-        )  # type: Optional[LoggingContextOrSentinel]
-        if parent_context == LoggingContext.sentinel:
+        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
+        if not parent_context:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )