summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2021-11-17 14:19:27 +0000
committerDavid Robertson <davidr@element.io>2021-11-17 14:19:27 +0000
commit077b74929f8f412395d1156e1b97eb16701059fa (patch)
tree25ecb8e93ec3ba275596bfb31efa45018645d47b /synapse/storage
parentCorrect target of link to the modules page from the Password Auth Providers p... (diff)
parent1.47.0 (diff)
downloadsynapse-077b74929f8f412395d1156e1b97eb16701059fa.tar.xz
Merge remote-tracking branch 'origin/release-v1.47'
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/background_updates.py65
-rw-r--r--synapse/storage/database.py11
-rw-r--r--synapse/storage/databases/main/deviceinbox.py189
-rw-r--r--synapse/storage/databases/main/devices.py39
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/events_worker.py56
-rw-r--r--synapse/storage/databases/main/group_server.py17
-rw-r--r--synapse/storage/databases/main/lock.py31
-rw-r--r--synapse/storage/databases/main/presence.py2
-rw-r--r--synapse/storage/databases/main/profile.py2
-rw-r--r--synapse/storage/databases/main/relations.py38
-rw-r--r--synapse/storage/databases/main/room.py29
-rw-r--r--synapse/storage/databases/main/room_batch.py2
-rw-r--r--synapse/storage/databases/main/roommember.py8
-rw-r--r--synapse/storage/prepare_database.py27
-rw-r--r--synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql22
-rw-r--r--synapse/storage/schema/main/delta/65/04_local_group_updates.sql18
-rw-r--r--synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql34
18 files changed, 480 insertions, 117 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 82b31d24f1..b9a8ca997e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -100,29 +100,58 @@ class BackgroundUpdater:
         ] = {}
         self._all_done = False
 
+        # Whether we're currently running updates
+        self._running = False
+
+        # Whether background updates are enabled. This allows us to
+        # enable/disable background updates via the admin API.
+        self.enabled = True
+
+    def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
+        """Returns the current background update, if any."""
+
+        update_name = self._current_background_update
+        if not update_name:
+            return None
+
+        perf = self._background_update_performance.get(update_name)
+        if not perf:
+            perf = BackgroundUpdatePerformance(update_name)
+
+        return perf
+
     def start_doing_background_updates(self) -> None:
-        run_as_background_process("background_updates", self.run_background_updates)
+        if self.enabled:
+            run_as_background_process("background_updates", self.run_background_updates)
 
     async def run_background_updates(self, sleep: bool = True) -> None:
-        logger.info("Starting background schema updates")
-        while True:
-            if sleep:
-                await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+        if self._running or not self.enabled:
+            return
 
-            try:
-                result = await self.do_next_background_update(
-                    self.BACKGROUND_UPDATE_DURATION_MS
-                )
-            except Exception:
-                logger.exception("Error doing update")
-            else:
-                if result:
-                    logger.info(
-                        "No more background updates to do."
-                        " Unscheduling background update task."
+        self._running = True
+
+        try:
+            logger.info("Starting background schema updates")
+            while self.enabled:
+                if sleep:
+                    await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+
+                try:
+                    result = await self.do_next_background_update(
+                        self.BACKGROUND_UPDATE_DURATION_MS
                     )
-                    self._all_done = True
-                    return None
+                except Exception:
+                    logger.exception("Error doing update")
+                else:
+                    if result:
+                        logger.info(
+                            "No more background updates to do."
+                            " Unscheduling background update task."
+                        )
+                        self._all_done = True
+                        return None
+        finally:
+            self._running = False
 
     async def has_completed_background_updates(self) -> bool:
         """Check if all the background updates have completed
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index fa4e89d35c..d4cab69ebf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -48,6 +48,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.metrics import register_threadpool
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
@@ -104,13 +105,17 @@ def make_pool(
                 LoggingDatabaseConnection(conn, engine, "on_new_connection")
             )
 
-    return adbapi.ConnectionPool(
+    connection_pool = adbapi.ConnectionPool(
         db_config.config["name"],
         cp_reactor=reactor,
         cp_openfun=_on_new_connection,
         **db_args,
     )
 
+    register_threadpool(f"database-{db_config.name}", connection_pool.threadpool)
+
+    return connection_pool
+
 
 def make_conn(
     db_config: DatabaseConnectionConfig,
@@ -441,6 +446,10 @@ class DatabasePool:
                 self._check_safe_to_upsert,
             )
 
+    def name(self) -> str:
+        "Return the name of this database"
+        return self._database_config.name
+
     def is_running(self) -> bool:
         """Is the database pool currently running"""
         return self._db_pool.running
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8143168107..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -19,9 +19,10 @@ from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -488,10 +489,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
+                # We exclude hidden devices (such as cross-signing keys) here as they are
+                # not expected to receive to-device messages.
                 devices = self.db_pool.simple_select_onecol_txn(
                     txn,
                     table="devices",
-                    keyvalues={"user_id": user_id},
+                    keyvalues={"user_id": user_id, "hidden": False},
                     retcol="device_id",
                 )
 
@@ -504,10 +507,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 if not devices:
                     continue
 
+                # We exclude hidden devices (such as cross-signing keys) here as they are
+                # not expected to receive to-device messages.
                 rows = self.db_pool.simple_select_many_txn(
                     txn,
                     table="devices",
-                    keyvalues={"user_id": user_id},
+                    keyvalues={"user_id": user_id, "hidden": False},
                     column="device_id",
                     iterable=devices,
                     retcols=("device_id",),
@@ -555,6 +560,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+    REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+    REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
 
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
@@ -570,6 +577,16 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            self.REMOVE_DELETED_DEVICES,
+            self._remove_deleted_devices_from_device_inbox,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            self.REMOVE_HIDDEN_DEVICES,
+            self._remove_hidden_devices_from_device_inbox,
+        )
+
     async def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
             txn = conn.cursor()
@@ -582,6 +599,172 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
+    async def _remove_deleted_devices_from_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """A background update that deletes all device_inboxes for deleted devices.
+
+        This should only need to be run once (when users upgrade to v1.47.0)
+
+        Args:
+            progress: JsonDict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        def _remove_deleted_devices_from_device_inbox_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            """stream_id is not unique
+            we need to use an inclusive `stream_id >= ?` clause,
+            since we might not have deleted all dead device messages for the stream_id
+            returned from the previous query
+
+            Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+            to avoid problems of deleting a large number of rows all at once
+            due to a single device having lots of device messages.
+            """
+
+            last_stream_id = progress.get("stream_id", 0)
+
+            sql = """
+                SELECT device_id, user_id, stream_id
+                FROM device_inbox
+                WHERE
+                    stream_id >= ?
+                    AND (device_id, user_id) NOT IN (
+                        SELECT device_id, user_id FROM devices
+                    )
+                ORDER BY stream_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_stream_id, batch_size))
+            rows = txn.fetchall()
+
+            num_deleted = 0
+            for row in rows:
+                num_deleted += self.db_pool.simple_delete_txn(
+                    txn,
+                    "device_inbox",
+                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+                )
+
+            if rows:
+                # send more than stream_id to progress
+                # otherwise it can happen in large deployments that
+                # no change of status is visible in the log file
+                # it may be that the stream_id does not change in several runs
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    self.REMOVE_DELETED_DEVICES,
+                    {
+                        "device_id": rows[-1][0],
+                        "user_id": rows[-1][1],
+                        "stream_id": rows[-1][2],
+                    },
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_deleted_devices_from_device_inbox",
+            _remove_deleted_devices_from_device_inbox_txn,
+        )
+
+        # The task is finished when no more lines are deleted.
+        if not number_deleted:
+            await self.db_pool.updates._end_background_update(
+                self.REMOVE_DELETED_DEVICES
+            )
+
+        return number_deleted
+
+    async def _remove_hidden_devices_from_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """A background update that deletes all device_inboxes for hidden devices.
+
+        This should only need to be run once (when users upgrade to v1.47.0)
+
+        Args:
+            progress: JsonDict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        def _remove_hidden_devices_from_device_inbox_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            """stream_id is not unique
+            we need to use an inclusive `stream_id >= ?` clause,
+            since we might not have deleted all hidden device messages for the stream_id
+            returned from the previous query
+
+            Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+            to avoid problems of deleting a large number of rows all at once
+            due to a single device having lots of device messages.
+            """
+
+            last_stream_id = progress.get("stream_id", 0)
+
+            sql = """
+                SELECT device_id, user_id, stream_id
+                FROM device_inbox
+                WHERE
+                    stream_id >= ?
+                    AND (device_id, user_id) IN (
+                        SELECT device_id, user_id FROM devices WHERE hidden = ?
+                    )
+                ORDER BY stream_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_stream_id, True, batch_size))
+            rows = txn.fetchall()
+
+            num_deleted = 0
+            for row in rows:
+                num_deleted += self.db_pool.simple_delete_txn(
+                    txn,
+                    "device_inbox",
+                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+                )
+
+            if rows:
+                # We don't just save the `stream_id` in progress as
+                # otherwise it can happen in large deployments that
+                # no change of status is visible in the log file, as
+                # it may be that the stream_id does not change in several runs
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    self.REMOVE_HIDDEN_DEVICES,
+                    {
+                        "device_id": rows[-1][0],
+                        "user_id": rows[-1][1],
+                        "stream_id": rows[-1][2],
+                    },
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_hidden_devices_from_device_inbox",
+            _remove_hidden_devices_from_device_inbox_txn,
+        )
+
+        # The task is finished when no more lines are deleted.
+        if not number_deleted:
+            await self.db_pool.updates._end_background_update(
+                self.REMOVE_HIDDEN_DEVICES
+            )
+
+        return number_deleted
+
 
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a01bf2c5b7..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
             user_ids: the users who were signed
 
         Returns:
-            THe new stream ID.
+            The new stream ID.
         """
 
         async with self._device_list_id_gen.get_next() as stream_id:
@@ -1134,19 +1134,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             raise StoreError(500, "Problem storing device.")
 
     async def delete_device(self, user_id: str, device_id: str) -> None:
-        """Delete a device.
+        """Delete a device and its device_inbox.
 
         Args:
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to delete
         """
-        await self.db_pool.simple_delete_one(
-            table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
-            desc="delete_device",
-        )
 
-        self.device_id_exists_cache.invalidate((user_id, device_id))
+        await self.delete_devices(user_id, [device_id])
 
     async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
         """Deletes several devices.
@@ -1155,13 +1150,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user which owns the devices
             device_ids: The IDs of the devices to delete
         """
-        await self.db_pool.simple_delete_many(
-            table="devices",
-            column="device_id",
-            iterable=device_ids,
-            keyvalues={"user_id": user_id, "hidden": False},
-            desc="delete_devices",
-        )
+
+        def _delete_devices_txn(txn: LoggingTransaction) -> None:
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="devices",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id, "hidden": False},
+            )
+
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_inbox",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
+        await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
 
@@ -1315,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
     async def add_device_change_to_streams(
         self, user_id: str, device_ids: Collection[str], hosts: List[str]
-    ):
+    ) -> int:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8d9086ecf0..596275c23c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
 )
@@ -494,7 +495,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, List[str]],
+        event_to_auth_chain: Dict[str, Sequence[str]],
     ) -> None:
         """Calculate the chain cover index for the given events.
 
@@ -786,7 +787,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, List[str]],
+        event_to_auth_chain: Dict[str, Sequence[str]],
         events_to_calc_chain_id_for: Set[str],
         chain_map: Dict[str, Tuple[int, int]],
     ) -> Dict[str, Tuple[int, int]]:
@@ -1794,7 +1795,7 @@ class PersistEventsStore:
         )
 
         # Insert an edge for every prev_event connection
-        for prev_event_id in event.prev_events:
+        for prev_event_id in event.prev_event_ids():
             self.db_pool.simple_insert_txn(
                 txn,
                 table="insertion_event_edges",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index ae37901be9..c6bf316d5b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -28,6 +28,7 @@ from typing import (
 
 import attr
 from constantly import NamedConstant, Names
+from prometheus_client import Gauge
 from typing_extensions import Literal
 
 from twisted.internet import defer
@@ -81,6 +82,12 @@ EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
 EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 
 
+event_fetch_ongoing_gauge = Gauge(
+    "synapse_event_fetch_ongoing",
+    "The number of event fetchers that are running",
+)
+
+
 @attr.s(slots=True, auto_attribs=True)
 class _EventCacheEntry:
     event: EventBase
@@ -222,6 +229,7 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
+        event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
@@ -732,28 +740,31 @@ class EventsWorkerStore(SQLBaseStore):
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
-        i = 0
-        while True:
-            with self._event_fetch_lock:
-                event_list = self._event_fetch_list
-                self._event_fetch_list = []
-
-                if not event_list:
-                    single_threaded = self.database_engine.single_threaded
-                    if (
-                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
-                        or single_threaded
-                        or i > EVENT_QUEUE_ITERATIONS
-                    ):
-                        self._event_fetch_ongoing -= 1
-                        return
-                    else:
-                        self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                        i += 1
-                        continue
-                i = 0
-
-            self._fetch_event_list(conn, event_list)
+        try:
+            i = 0
+            while True:
+                with self._event_fetch_lock:
+                    event_list = self._event_fetch_list
+                    self._event_fetch_list = []
+
+                    if not event_list:
+                        single_threaded = self.database_engine.single_threaded
+                        if (
+                            not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                            or single_threaded
+                            or i > EVENT_QUEUE_ITERATIONS
+                        ):
+                            break
+                        else:
+                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                            i += 1
+                            continue
+                    i = 0
+
+                self._fetch_event_list(conn, event_list)
+        finally:
+            self._event_fetch_ongoing -= 1
+            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
     def _fetch_event_list(
         self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
@@ -977,6 +988,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
                 self._event_fetch_ongoing += 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
                 should_start = True
             else:
                 should_start = False
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e70d3649ff..bb621df0dd 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,15 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # The category ID for the "default" category. We don't store as null in the
 # database to avoid the fun of null != null
 _DEFAULT_CATEGORY_ID = ""
@@ -35,6 +40,16 @@ class _RoomInGroup(TypedDict):
 
 
 class GroupServerWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+        database.updates.register_background_index_update(
+            update_name="local_group_updates_index",
+            index_name="local_group_updates_stream_id_index",
+            table="local_group_updates",
+            columns=("stream_id",),
+            unique=True,
+        )
+        super().__init__(database, db_conn, hs)
+
     async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
             table="groups",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 3d1dff660b..3d0df0cbd4 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -14,6 +14,7 @@
 import logging
 from types import TracebackType
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+from weakref import WeakValueDictionary
 
 from twisted.internet.interfaces import IReactorCore
 
@@ -61,7 +62,7 @@ class LockStore(SQLBaseStore):
 
         # A map from `(lock_name, lock_key)` to the token of any locks that we
         # think we currently hold.
-        self._live_tokens: Dict[Tuple[str, str], str] = {}
+        self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
 
         # When we shut down we want to remove the locks. Technically this can
         # lead to a race, as we may drop the lock while we are still processing.
@@ -80,10 +81,10 @@ class LockStore(SQLBaseStore):
 
         # We need to take a copy of the tokens dict as dropping the locks will
         # cause the dictionary to change.
-        tokens = dict(self._live_tokens)
+        locks = dict(self._live_tokens)
 
-        for (lock_name, lock_key), token in tokens.items():
-            await self._drop_lock(lock_name, lock_key, token)
+        for lock in locks.values():
+            await lock.release()
 
         logger.info("Dropped locks due to shutdown")
 
@@ -93,6 +94,11 @@ class LockStore(SQLBaseStore):
         used (otherwise the lock will leak).
         """
 
+        # Check if this process has taken out a lock and if it's still valid.
+        lock = self._live_tokens.get((lock_name, lock_key))
+        if lock and await lock.is_still_valid():
+            return None
+
         now = self._clock.time_msec()
         token = random_string(6)
 
@@ -100,7 +106,9 @@ class LockStore(SQLBaseStore):
 
             def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
                 # We take out the lock if either a) there is no row for the lock
-                # already or b) the existing row has timed out.
+                # already, b) the existing row has timed out, or c) the row is
+                # for this instance (which means the process got killed and
+                # restarted)
                 sql = """
                     INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
                     VALUES (?, ?, ?, ?, ?)
@@ -112,6 +120,7 @@ class LockStore(SQLBaseStore):
                             last_renewed_ts = EXCLUDED.last_renewed_ts
                         WHERE
                             worker_locks.last_renewed_ts < ?
+                            OR worker_locks.instance_name = EXCLUDED.instance_name
                 """
                 txn.execute(
                     sql,
@@ -148,11 +157,11 @@ class LockStore(SQLBaseStore):
                     WHERE
                         lock_name = ?
                         AND lock_key = ?
-                        AND last_renewed_ts < ?
+                        AND (last_renewed_ts < ? OR instance_name = ?)
                 """
                 txn.execute(
                     sql,
-                    (lock_name, lock_key, now - _LOCK_TIMEOUT_MS),
+                    (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name),
                 )
 
                 inserted = self.db_pool.simple_upsert_txn_emulated(
@@ -179,9 +188,7 @@ class LockStore(SQLBaseStore):
             if not did_lock:
                 return None
 
-        self._live_tokens[(lock_name, lock_key)] = token
-
-        return Lock(
+        lock = Lock(
             self._reactor,
             self._clock,
             self,
@@ -190,6 +197,10 @@ class LockStore(SQLBaseStore):
             token=token,
         )
 
+        self._live_tokens[(lock_name, lock_key)] = lock
+
+        return lock
+
     async def _is_lock_still_valid(
         self, lock_name: str, lock_key: str, token: str
     ) -> bool:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             prefilled_cache=presence_cache_prefill,
         )
 
-    async def update_presence(self, presence_states):
+    async def update_presence(self, presence_states) -> Tuple[int, int]:
         assert self._can_persist_presence
 
         stream_ordering_manager = self._presence_id_gen.get_next_mult(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index ba7075caa5..dd8e27e226 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def update_remote_profile_cache(
-        self, user_id: str, displayname: str, avatar_url: str
+        self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
     ) -> int:
         return await self.db_pool.simple_update(
             table="remote_profile_cache",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 40760fbd1b..53576ad52f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 import logging
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import attr
 
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
         """
 
         where_clause = ["relates_to_id = ?"]
-        where_args = [event_id]
+        where_args: List[Union[str, int]] = [event_id]
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
+            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
+            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
             engine=self.database_engine,
         )
 
@@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
             order,
         )
 
-        def _get_recent_references_for_event_txn(txn):
+        def _get_recent_references_for_event_txn(
+            txn: LoggingTransaction,
+        ) -> PaginationChunk:
             txn.execute(sql, where_args + [limit + 1])
 
             last_topo_id = None
@@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
         """
 
         where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args = [event_id, RelationTypes.ANNOTATION]
+        where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
 
         if event_type:
             where_clause.append("type = ?")
@@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
         having_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("COUNT(*)", "MAX(stream_ordering)"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
+            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
+            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
             engine=self.database_engine,
         )
 
@@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
             having_clause=having_clause,
         )
 
-        def _get_aggregation_groups_for_event_txn(txn):
+        def _get_aggregation_groups_for_event_txn(
+            txn: LoggingTransaction,
+        ) -> PaginationChunk:
             txn.execute(sql, where_args + [limit + 1])
 
             next_batch = None
@@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
             LIMIT 1
         """
 
-        def _get_applicable_edit_txn(txn):
+        def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
             txn.execute(sql, (event_id, RelationTypes.REPLACE))
             row = txn.fetchone()
             if row:
                 return row[0]
+            return None
 
         edit_id = await self.db_pool.runInteraction(
             "get_applicable_edit", _get_applicable_edit_txn
@@ -267,7 +273,7 @@ class RelationsWorkerStore(SQLBaseStore):
         if not edit_id:
             return None
 
-        return await self.get_event(edit_id, allow_none=True)
+        return await self.get_event(edit_id, allow_none=True)  # type: ignore[attr-defined]
 
     @cached()
     async def get_thread_summary(
@@ -283,7 +289,9 @@ class RelationsWorkerStore(SQLBaseStore):
             The number of items in the thread and the most recent response, if any.
         """
 
-        def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
+        def _get_thread_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, Optional[str]]:
             # Fetch the count of threaded events and the latest event ID.
             # TODO Should this only allow m.room.message events.
             sql = """
@@ -312,7 +320,7 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND relation_type = ?
             """
             txn.execute(sql, (event_id, RelationTypes.THREAD))
-            count = txn.fetchone()[0]
+            count = txn.fetchone()[0]  # type: ignore[index]
 
             return count, latest_event_id
 
@@ -322,7 +330,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         latest_event = None
         if latest_event_id:
-            latest_event = await self.get_event(latest_event_id, allow_none=True)
+            latest_event = await self.get_event(latest_event_id, allow_none=True)  # type: ignore[attr-defined]
 
         return count, latest_event
 
@@ -354,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
             LIMIT 1;
         """
 
-        def _get_if_user_has_annotated_event(txn):
+        def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
             txn.execute(
                 sql,
                 (
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f879bbe7c7..cefc77fa0f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -412,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore):
             limit: maximum amount of rooms to retrieve
             order_by: the sort order of the returned list
             reverse_order: whether to reverse the room list
-            search_term: a string to filter room names by
+            search_term: a string to filter room names,
+                canonical alias and room ids by.
+                Room ID must match exactly. Canonical alias must match a substring of the local part.
         Returns:
             A list of room dicts and an integer representing the total number of
             rooms that exist given this query
         """
         # Filter room names by a string
         where_statement = ""
+        search_pattern = []
         if search_term:
-            where_statement = "WHERE LOWER(state.name) LIKE ?"
+            where_statement = """
+                WHERE LOWER(state.name) LIKE ?
+                OR LOWER(state.canonical_alias) LIKE ?
+                OR state.room_id = ?
+            """
 
             # Our postgres db driver converts ? -> %s in SQL strings as that's the
             # placeholder for postgres.
             # HOWEVER, if you put a % into your SQL then everything goes wibbly.
             # To get around this, we're going to surround search_term with %'s
             # before giving it to the database in python instead
-            search_term = "%" + search_term.lower() + "%"
+            search_pattern = [
+                "%" + search_term.lower() + "%",
+                "#%" + search_term.lower() + "%:%",
+                search_term,
+            ]
 
         # Set ordering
         if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -519,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
         def _get_rooms_paginate_txn(txn):
-            # Execute the data query
-            sql_values = (limit, start)
-            if search_term:
-                # Add the search term into the WHERE clause
-                sql_values = (search_term,) + sql_values
-            txn.execute(info_sql, sql_values)
+            # Add the search term into the WHERE clause
+            # and execute the data query
+            txn.execute(info_sql, search_pattern + [limit, start])
 
             # Refactor room query data into a structured dictionary
             rooms = []
@@ -551,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore):
             # Execute the count query
 
             # Add the search term into the WHERE clause if present
-            sql_values = (search_term,) if search_term else ()
-            txn.execute(count_sql, sql_values)
+            txn.execute(count_sql, search_pattern)
 
             room_count = txn.fetchone()
             return rooms, room_count[0]
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index dcbce8fdcf..97b2618437 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -18,7 +18,7 @@ from synapse.storage._base import SQLBaseStore
 
 
 class RoomBatchStore(SQLBaseStore):
-    async def get_insertion_event_by_batch_id(
+    async def get_insertion_event_id_by_batch_id(
         self, room_id: str, batch_id: str
     ) -> Optional[str]:
         """Retrieve a insertion event ID.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4b288bb2e7..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     async def get_joined_users_from_context(
         self, event: EventBase, context: EventContext
-    ):
+    ) -> Dict[str, ProfileInfo]:
         state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
 
-    async def get_joined_users_from_state(self, room_id, state_entry):
+    async def get_joined_users_from_state(
+        self, room_id, state_entry
+    ) -> Dict[str, ProfileInfo]:
         state_group = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         cache_context,
         event=None,
         context=None,
-    ):
+    ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1629d2a53c..e45adfcb55 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -131,17 +131,9 @@ def prepare_database(
                     "config==None in prepare_database, but database is not empty"
                 )
 
-            # if it's a worker app, refuse to upgrade the database, to avoid multiple
-            # workers doing it at once.
-            if (
-                config.worker.worker_app is not None
-                and version_info.current_version != SCHEMA_VERSION
-            ):
-                raise UpgradeDatabaseException(
-                    OUTDATED_SCHEMA_ON_WORKER_ERROR
-                    % (SCHEMA_VERSION, version_info.current_version)
-                )
-
+            # This should be run on all processes, master or worker. The master will
+            # apply the deltas, while workers will check if any outstanding deltas
+            # exist and raise an PrepareDatabaseException if they do.
             _upgrade_existing_database(
                 cur,
                 version_info,
@@ -149,6 +141,7 @@ def prepare_database(
                 config,
                 databases=databases,
             )
+
         else:
             logger.info("%r: Initialising new database", databases)
 
@@ -357,6 +350,18 @@ def _upgrade_existing_database(
 
     is_worker = config and config.worker.worker_app is not None
 
+    # If the schema version needs to be updated, and we are on a worker, we immediately
+    # know to bail out as workers cannot update the database schema. Only one process
+    # must update the database at the time, therefore we delegate this task to the master.
+    if is_worker and current_schema_state.current_version < SCHEMA_VERSION:
+        # If the DB is on an older version than we expect then we refuse
+        # to start the worker (as the main process needs to run first to
+        # update the schema).
+        raise UpgradeDatabaseException(
+            OUTDATED_SCHEMA_ON_WORKER_ERROR
+            % (SCHEMA_VERSION, current_schema_state.current_version)
+        )
+
     if (
         current_schema_state.compat_version is not None
         and current_schema_state.compat_version > SCHEMA_VERSION
diff --git a/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql
new file mode 100644
index 0000000000..7b3592dcf0
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql
@@ -0,0 +1,22 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Remove messages from the device_inbox table which were orphaned
+-- because a device was hidden using Synapse earlier than 1.47.0.
+-- This runs as background task, but may take a bit to finish.
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6503, 'remove_hidden_devices_from_device_inbox', '{}');
diff --git a/synapse/storage/schema/main/delta/65/04_local_group_updates.sql b/synapse/storage/schema/main/delta/65/04_local_group_updates.sql
new file mode 100644
index 0000000000..a178abfe12
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/04_local_group_updates.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Check index on `local_group_updates.stream_id`.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6504, 'local_group_updates_index', '{}');
diff --git a/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
new file mode 100644
index 0000000000..82f6408b36
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Remove messages from the device_inbox table which were orphaned
+-- when a device was deleted using Synapse earlier than 1.47.0.
+-- This runs as background task, but may take a bit to finish.
+
+-- Remove any existing instances of this job running. It's OK to stop and restart this job,
+-- as it's just deleting entries from a table - no progress will be lost.
+--
+-- This is necessary due a similar migration running the job accidentally
+-- being included in schema version 64 during v1.47.0rc1,rc2. If a
+-- homeserver had updated from Synapse <=v1.45.0 (schema version <=64),
+-- then they would have started running this background update already.
+-- If that update was still running, then simply inserting it again would
+-- cause an SQL failure. So we effectively do an "upsert" here instead.
+
+DELETE FROM background_updates WHERE update_name = 'remove_deleted_devices_from_device_inbox';
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6506, 'remove_deleted_devices_from_device_inbox', '{}');