summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-11-08 11:34:03 +0000
committerErik Johnston <erik@matrix.org>2021-11-08 11:34:03 +0000
commit58864265d15809e824bb6cb02e4c3a9f533e4aae (patch)
tree7578438369ce85ce6c189acee879578ad77ec007 /synapse/storage/databases/main
parentAdd spans for sync (diff)
parentBlacklist new sytest validation test (#11270) (diff)
downloadsynapse-58864265d15809e824bb6cb02e4c3a9f533e4aae.tar.xz
Merge remote-tracking branch 'origin/develop' into erikj/slow_sync_diag
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/deviceinbox.py89
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/events_worker.py56
-rw-r--r--synapse/storage/databases/main/group_server.py17
-rw-r--r--synapse/storage/databases/main/lock.py31
-rw-r--r--synapse/storage/databases/main/presence.py2
-rw-r--r--synapse/storage/databases/main/room.py29
-rw-r--r--synapse/storage/databases/main/roommember.py8
9 files changed, 190 insertions, 53 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 25e9c1efe1..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -561,6 +561,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
     REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+    REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
 
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
@@ -581,6 +582,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self._remove_deleted_devices_from_device_inbox,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            self.REMOVE_HIDDEN_DEVICES,
+            self._remove_hidden_devices_from_device_inbox,
+        )
+
     async def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
             txn = conn.cursor()
@@ -676,6 +682,89 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 
         return number_deleted
 
+    async def _remove_hidden_devices_from_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """A background update that deletes all device_inboxes for hidden devices.
+
+        This should only need to be run once (when users upgrade to v1.47.0)
+
+        Args:
+            progress: JsonDict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        def _remove_hidden_devices_from_device_inbox_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            """stream_id is not unique
+            we need to use an inclusive `stream_id >= ?` clause,
+            since we might not have deleted all hidden device messages for the stream_id
+            returned from the previous query
+
+            Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+            to avoid problems of deleting a large number of rows all at once
+            due to a single device having lots of device messages.
+            """
+
+            last_stream_id = progress.get("stream_id", 0)
+
+            sql = """
+                SELECT device_id, user_id, stream_id
+                FROM device_inbox
+                WHERE
+                    stream_id >= ?
+                    AND (device_id, user_id) IN (
+                        SELECT device_id, user_id FROM devices WHERE hidden = ?
+                    )
+                ORDER BY stream_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_stream_id, True, batch_size))
+            rows = txn.fetchall()
+
+            num_deleted = 0
+            for row in rows:
+                num_deleted += self.db_pool.simple_delete_txn(
+                    txn,
+                    "device_inbox",
+                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+                )
+
+            if rows:
+                # We don't just save the `stream_id` in progress as
+                # otherwise it can happen in large deployments that
+                # no change of status is visible in the log file, as
+                # it may be that the stream_id does not change in several runs
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    self.REMOVE_HIDDEN_DEVICES,
+                    {
+                        "device_id": rows[-1][0],
+                        "user_id": rows[-1][1],
+                        "stream_id": rows[-1][2],
+                    },
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_hidden_devices_from_device_inbox",
+            _remove_hidden_devices_from_device_inbox_txn,
+        )
+
+        # The task is finished when no more lines are deleted.
+        if not number_deleted:
+            await self.db_pool.updates._end_background_update(
+                self.REMOVE_HIDDEN_DEVICES
+            )
+
+        return number_deleted
+
 
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b15cd030e0..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
             user_ids: the users who were signed
 
         Returns:
-            THe new stream ID.
+            The new stream ID.
         """
 
         async with self._device_list_id_gen.get_next() as stream_id:
@@ -1322,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
     async def add_device_change_to_streams(
         self, user_id: str, device_ids: Collection[str], hosts: List[str]
-    ):
+    ) -> int:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8d9086ecf0..596275c23c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
 )
@@ -494,7 +495,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, List[str]],
+        event_to_auth_chain: Dict[str, Sequence[str]],
     ) -> None:
         """Calculate the chain cover index for the given events.
 
@@ -786,7 +787,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, List[str]],
+        event_to_auth_chain: Dict[str, Sequence[str]],
         events_to_calc_chain_id_for: Set[str],
         chain_map: Dict[str, Tuple[int, int]],
     ) -> Dict[str, Tuple[int, int]]:
@@ -1794,7 +1795,7 @@ class PersistEventsStore:
         )
 
         # Insert an edge for every prev_event connection
-        for prev_event_id in event.prev_events:
+        for prev_event_id in event.prev_event_ids():
             self.db_pool.simple_insert_txn(
                 txn,
                 table="insertion_event_edges",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index ae37901be9..c6bf316d5b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -28,6 +28,7 @@ from typing import (
 
 import attr
 from constantly import NamedConstant, Names
+from prometheus_client import Gauge
 from typing_extensions import Literal
 
 from twisted.internet import defer
@@ -81,6 +82,12 @@ EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
 EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 
 
+event_fetch_ongoing_gauge = Gauge(
+    "synapse_event_fetch_ongoing",
+    "The number of event fetchers that are running",
+)
+
+
 @attr.s(slots=True, auto_attribs=True)
 class _EventCacheEntry:
     event: EventBase
@@ -222,6 +229,7 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
+        event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
@@ -732,28 +740,31 @@ class EventsWorkerStore(SQLBaseStore):
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
-        i = 0
-        while True:
-            with self._event_fetch_lock:
-                event_list = self._event_fetch_list
-                self._event_fetch_list = []
-
-                if not event_list:
-                    single_threaded = self.database_engine.single_threaded
-                    if (
-                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
-                        or single_threaded
-                        or i > EVENT_QUEUE_ITERATIONS
-                    ):
-                        self._event_fetch_ongoing -= 1
-                        return
-                    else:
-                        self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                        i += 1
-                        continue
-                i = 0
-
-            self._fetch_event_list(conn, event_list)
+        try:
+            i = 0
+            while True:
+                with self._event_fetch_lock:
+                    event_list = self._event_fetch_list
+                    self._event_fetch_list = []
+
+                    if not event_list:
+                        single_threaded = self.database_engine.single_threaded
+                        if (
+                            not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                            or single_threaded
+                            or i > EVENT_QUEUE_ITERATIONS
+                        ):
+                            break
+                        else:
+                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                            i += 1
+                            continue
+                    i = 0
+
+                self._fetch_event_list(conn, event_list)
+        finally:
+            self._event_fetch_ongoing -= 1
+            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
     def _fetch_event_list(
         self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
@@ -977,6 +988,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
                 self._event_fetch_ongoing += 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
                 should_start = True
             else:
                 should_start = False
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e70d3649ff..bb621df0dd 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,15 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # The category ID for the "default" category. We don't store as null in the
 # database to avoid the fun of null != null
 _DEFAULT_CATEGORY_ID = ""
@@ -35,6 +40,16 @@ class _RoomInGroup(TypedDict):
 
 
 class GroupServerWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+        database.updates.register_background_index_update(
+            update_name="local_group_updates_index",
+            index_name="local_group_updates_stream_id_index",
+            table="local_group_updates",
+            columns=("stream_id",),
+            unique=True,
+        )
+        super().__init__(database, db_conn, hs)
+
     async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
             table="groups",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 3d1dff660b..3d0df0cbd4 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -14,6 +14,7 @@
 import logging
 from types import TracebackType
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+from weakref import WeakValueDictionary
 
 from twisted.internet.interfaces import IReactorCore
 
@@ -61,7 +62,7 @@ class LockStore(SQLBaseStore):
 
         # A map from `(lock_name, lock_key)` to the token of any locks that we
         # think we currently hold.
-        self._live_tokens: Dict[Tuple[str, str], str] = {}
+        self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
 
         # When we shut down we want to remove the locks. Technically this can
         # lead to a race, as we may drop the lock while we are still processing.
@@ -80,10 +81,10 @@ class LockStore(SQLBaseStore):
 
         # We need to take a copy of the tokens dict as dropping the locks will
         # cause the dictionary to change.
-        tokens = dict(self._live_tokens)
+        locks = dict(self._live_tokens)
 
-        for (lock_name, lock_key), token in tokens.items():
-            await self._drop_lock(lock_name, lock_key, token)
+        for lock in locks.values():
+            await lock.release()
 
         logger.info("Dropped locks due to shutdown")
 
@@ -93,6 +94,11 @@ class LockStore(SQLBaseStore):
         used (otherwise the lock will leak).
         """
 
+        # Check if this process has taken out a lock and if it's still valid.
+        lock = self._live_tokens.get((lock_name, lock_key))
+        if lock and await lock.is_still_valid():
+            return None
+
         now = self._clock.time_msec()
         token = random_string(6)
 
@@ -100,7 +106,9 @@ class LockStore(SQLBaseStore):
 
             def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
                 # We take out the lock if either a) there is no row for the lock
-                # already or b) the existing row has timed out.
+                # already, b) the existing row has timed out, or c) the row is
+                # for this instance (which means the process got killed and
+                # restarted)
                 sql = """
                     INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
                     VALUES (?, ?, ?, ?, ?)
@@ -112,6 +120,7 @@ class LockStore(SQLBaseStore):
                             last_renewed_ts = EXCLUDED.last_renewed_ts
                         WHERE
                             worker_locks.last_renewed_ts < ?
+                            OR worker_locks.instance_name = EXCLUDED.instance_name
                 """
                 txn.execute(
                     sql,
@@ -148,11 +157,11 @@ class LockStore(SQLBaseStore):
                     WHERE
                         lock_name = ?
                         AND lock_key = ?
-                        AND last_renewed_ts < ?
+                        AND (last_renewed_ts < ? OR instance_name = ?)
                 """
                 txn.execute(
                     sql,
-                    (lock_name, lock_key, now - _LOCK_TIMEOUT_MS),
+                    (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name),
                 )
 
                 inserted = self.db_pool.simple_upsert_txn_emulated(
@@ -179,9 +188,7 @@ class LockStore(SQLBaseStore):
             if not did_lock:
                 return None
 
-        self._live_tokens[(lock_name, lock_key)] = token
-
-        return Lock(
+        lock = Lock(
             self._reactor,
             self._clock,
             self,
@@ -190,6 +197,10 @@ class LockStore(SQLBaseStore):
             token=token,
         )
 
+        self._live_tokens[(lock_name, lock_key)] = lock
+
+        return lock
+
     async def _is_lock_still_valid(
         self, lock_name: str, lock_key: str, token: str
     ) -> bool:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             prefilled_cache=presence_cache_prefill,
         )
 
-    async def update_presence(self, presence_states):
+    async def update_presence(self, presence_states) -> Tuple[int, int]:
         assert self._can_persist_presence
 
         stream_ordering_manager = self._presence_id_gen.get_next_mult(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f879bbe7c7..cefc77fa0f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -412,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore):
             limit: maximum amount of rooms to retrieve
             order_by: the sort order of the returned list
             reverse_order: whether to reverse the room list
-            search_term: a string to filter room names by
+            search_term: a string to filter room names,
+                canonical alias and room ids by.
+                Room ID must match exactly. Canonical alias must match a substring of the local part.
         Returns:
             A list of room dicts and an integer representing the total number of
             rooms that exist given this query
         """
         # Filter room names by a string
         where_statement = ""
+        search_pattern = []
         if search_term:
-            where_statement = "WHERE LOWER(state.name) LIKE ?"
+            where_statement = """
+                WHERE LOWER(state.name) LIKE ?
+                OR LOWER(state.canonical_alias) LIKE ?
+                OR state.room_id = ?
+            """
 
             # Our postgres db driver converts ? -> %s in SQL strings as that's the
             # placeholder for postgres.
             # HOWEVER, if you put a % into your SQL then everything goes wibbly.
             # To get around this, we're going to surround search_term with %'s
             # before giving it to the database in python instead
-            search_term = "%" + search_term.lower() + "%"
+            search_pattern = [
+                "%" + search_term.lower() + "%",
+                "#%" + search_term.lower() + "%:%",
+                search_term,
+            ]
 
         # Set ordering
         if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -519,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
         def _get_rooms_paginate_txn(txn):
-            # Execute the data query
-            sql_values = (limit, start)
-            if search_term:
-                # Add the search term into the WHERE clause
-                sql_values = (search_term,) + sql_values
-            txn.execute(info_sql, sql_values)
+            # Add the search term into the WHERE clause
+            # and execute the data query
+            txn.execute(info_sql, search_pattern + [limit, start])
 
             # Refactor room query data into a structured dictionary
             rooms = []
@@ -551,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore):
             # Execute the count query
 
             # Add the search term into the WHERE clause if present
-            sql_values = (search_term,) if search_term else ()
-            txn.execute(count_sql, sql_values)
+            txn.execute(count_sql, search_pattern)
 
             room_count = txn.fetchone()
             return rooms, room_count[0]
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4b288bb2e7..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     async def get_joined_users_from_context(
         self, event: EventBase, context: EventContext
-    ):
+    ) -> Dict[str, ProfileInfo]:
         state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
 
-    async def get_joined_users_from_state(self, room_id, state_entry):
+    async def get_joined_users_from_state(
+        self, room_id, state_entry
+    ) -> Dict[str, ProfileInfo]:
         state_group = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         cache_context,
         event=None,
         context=None,
-    ):
+    ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.