Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py   189
-rw-r--r--  synapse/storage/databases/main/devices.py        39
-rw-r--r--  synapse/storage/databases/main/events.py          7
-rw-r--r--  synapse/storage/databases/main/presence.py        2
-rw-r--r--  synapse/storage/databases/main/profile.py         2
-rw-r--r--  synapse/storage/databases/main/relations.py      38
-rw-r--r--  synapse/storage/databases/main/room.py           29
-rw-r--r--  synapse/storage/databases/main/roommember.py      8
8 files changed, 261 insertions, 53 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py

index 8143168107..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -19,9 +19,10 @@ from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -488,10 +489,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         devices = list(messages_by_device.keys())
         if len(devices) == 1 and devices[0] == "*":
             # Handle wildcard device_ids.
+            # We exclude hidden devices (such as cross-signing keys) here as they are
+            # not expected to receive to-device messages.
             devices = self.db_pool.simple_select_onecol_txn(
                 txn,
                 table="devices",
-                keyvalues={"user_id": user_id},
+                keyvalues={"user_id": user_id, "hidden": False},
                 retcol="device_id",
             )
@@ -504,10 +507,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             if not devices:
                 continue
 
+            # We exclude hidden devices (such as cross-signing keys) here as they are
+            # not expected to receive to-device messages.
             rows = self.db_pool.simple_select_many_txn(
                 txn,
                 table="devices",
-                keyvalues={"user_id": user_id},
+                keyvalues={"user_id": user_id, "hidden": False},
                 column="device_id",
                 iterable=devices,
                 retcols=("device_id",),
@@ -555,6 +560,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+    REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+    REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
 
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
@@ -570,6 +577,16 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            self.REMOVE_DELETED_DEVICES,
+            self._remove_deleted_devices_from_device_inbox,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            self.REMOVE_HIDDEN_DEVICES,
+            self._remove_hidden_devices_from_device_inbox,
+        )
+
     async def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
             txn = conn.cursor()
@@ -582,6 +599,172 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
+    async def _remove_deleted_devices_from_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """A background update that deletes all device_inboxes for deleted devices.
+
+        This should only need to be run once (when users upgrade to v1.47.0)
+
+        Args:
+            progress: JsonDict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        def _remove_deleted_devices_from_device_inbox_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            """stream_id is not unique, so we need to use an inclusive
+            `stream_id >= ?` clause, since we might not have deleted all dead
+            device messages for the stream_id returned from the previous query.
+
+            Then delete only rows matching the `(user_id, device_id, stream_id)`
+            tuple, to avoid problems of deleting a large number of rows all at
+            once due to a single device having lots of device messages.
+            """
+
+            last_stream_id = progress.get("stream_id", 0)
+
+            sql = """
+                SELECT device_id, user_id, stream_id
+                FROM device_inbox
+                WHERE
+                    stream_id >= ?
+                    AND (device_id, user_id) NOT IN (
+                        SELECT device_id, user_id FROM devices
+                    )
+                ORDER BY stream_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_stream_id, batch_size))
+            rows = txn.fetchall()
+
+            num_deleted = 0
+            for row in rows:
+                num_deleted += self.db_pool.simple_delete_txn(
+                    txn,
+                    "device_inbox",
+                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+                )
+
+            if rows:
+                # We don't just save the `stream_id` in progress as
+                # otherwise it can happen in large deployments that
+                # no change of status is visible in the log file, as
+                # it may be that the stream_id does not change in several runs
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    self.REMOVE_DELETED_DEVICES,
+                    {
+                        "device_id": rows[-1][0],
+                        "user_id": rows[-1][1],
+                        "stream_id": rows[-1][2],
+                    },
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_deleted_devices_from_device_inbox",
+            _remove_deleted_devices_from_device_inbox_txn,
+        )
+
+        # The task is finished when no more rows are deleted.
+        if not number_deleted:
+            await self.db_pool.updates._end_background_update(
+                self.REMOVE_DELETED_DEVICES
+            )
+
+        return number_deleted
+
+    async def _remove_hidden_devices_from_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """A background update that deletes all device_inboxes for hidden devices.
+
+        This should only need to be run once (when users upgrade to v1.47.0)
+
+        Args:
+            progress: JsonDict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        def _remove_hidden_devices_from_device_inbox_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            """stream_id is not unique, so we need to use an inclusive
+            `stream_id >= ?` clause, since we might not have deleted all hidden
+            device messages for the stream_id returned from the previous query.
+
+            Then delete only rows matching the `(user_id, device_id, stream_id)`
+            tuple, to avoid problems of deleting a large number of rows all at
+            once due to a single device having lots of device messages.
+            """
+
+            last_stream_id = progress.get("stream_id", 0)
+
+            sql = """
+                SELECT device_id, user_id, stream_id
+                FROM device_inbox
+                WHERE
+                    stream_id >= ?
+                    AND (device_id, user_id) IN (
+                        SELECT device_id, user_id FROM devices WHERE hidden = ?
+                    )
+                ORDER BY stream_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_stream_id, True, batch_size))
+            rows = txn.fetchall()
+
+            num_deleted = 0
+            for row in rows:
+                num_deleted += self.db_pool.simple_delete_txn(
+                    txn,
+                    "device_inbox",
+                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+                )
+
+            if rows:
+                # We don't just save the `stream_id` in progress as
+                # otherwise it can happen in large deployments that
+                # no change of status is visible in the log file, as
+                # it may be that the stream_id does not change in several runs
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    self.REMOVE_HIDDEN_DEVICES,
+                    {
+                        "device_id": rows[-1][0],
+                        "user_id": rows[-1][1],
+                        "stream_id": rows[-1][2],
+                    },
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_hidden_devices_from_device_inbox",
+            _remove_hidden_devices_from_device_inbox_txn,
+        )
+
+        # The task is finished when no more rows are deleted.
+        if not number_deleted:
+            await self.db_pool.updates._end_background_update(
+                self.REMOVE_HIDDEN_DEVICES
+            )
+
+        return number_deleted
+
 
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a01bf2c5b7..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
             user_ids: the users who were signed
 
         Returns:
-            THe new stream ID.
+            The new stream ID.
         """
 
         async with self._device_list_id_gen.get_next() as stream_id:
@@ -1134,19 +1134,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             raise StoreError(500, "Problem storing device.")
 
     async def delete_device(self, user_id: str, device_id: str) -> None:
-        """Delete a device.
+        """Delete a device and its device_inbox.
 
         Args:
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to delete
         """
-        await self.db_pool.simple_delete_one(
-            table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
-            desc="delete_device",
-        )
-        self.device_id_exists_cache.invalidate((user_id, device_id))
+        await self.delete_devices(user_id, [device_id])
 
     async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
         """Deletes several devices.
@@ -1155,13 +1150,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user which owns the devices
             device_ids: The IDs of the devices to delete
         """
-        await self.db_pool.simple_delete_many(
-            table="devices",
-            column="device_id",
-            iterable=device_ids,
-            keyvalues={"user_id": user_id, "hidden": False},
-            desc="delete_devices",
-        )
+
+        def _delete_devices_txn(txn: LoggingTransaction) -> None:
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="devices",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id, "hidden": False},
+            )
+
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_inbox",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
+        await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
 
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
@@ -1315,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
     async def add_device_change_to_streams(
         self, user_id: str, device_ids: Collection[str], hosts: List[str]
-    ):
+    ) -> int:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
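The delete_devices change above moves the devices and device_inbox deletes into a single transaction, so a queued to-device message cannot outlive its device. A rough sketch of that atomic two-table delete, with sqlite3 standing in for Synapse's runInteraction and an assumed column layout:

```python
import sqlite3
from typing import List


def delete_devices(conn: sqlite3.Connection, user_id: str, device_ids: List[str]) -> None:
    placeholders = ", ".join("?" for _ in device_ids)
    with conn:  # one transaction: both deletes commit together or roll back together
        # Drop the devices themselves (hidden devices are never deleted this way).
        conn.execute(
            f"DELETE FROM devices WHERE user_id = ? AND hidden = 0"
            f" AND device_id IN ({placeholders})",
            [user_id, *device_ids],
        )
        # Drop any to-device messages still queued for those devices.
        conn.execute(
            f"DELETE FROM device_inbox WHERE user_id = ?"
            f" AND device_id IN ({placeholders})",
            [user_id, *device_ids],
        )
```

Note that the cache invalidation happens after the transaction commits, mirroring how the diff invalidates device_id_exists_cache only once runInteraction returns.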
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8d9086ecf0..596275c23c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
 )
@@ -494,7 +495,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, List[str]],
+        event_to_auth_chain: Dict[str, Sequence[str]],
     ) -> None:
         """Calculate the chain cover index for the given events.
@@ -786,7 +787,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, List[str]],
+        event_to_auth_chain: Dict[str, Sequence[str]],
         events_to_calc_chain_id_for: Set[str],
         chain_map: Dict[str, Tuple[int, int]],
     ) -> Dict[str, Tuple[int, int]]:
@@ -1794,7 +1795,7 @@ class PersistEventsStore:
         )
 
         # Insert an edge for every prev_event connection
-        for prev_event_id in event.prev_events:
+        for prev_event_id in event.prev_event_ids():
             self.db_pool.simple_insert_txn(
                 txn,
                 table="insertion_event_edges",
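The List[str] to Sequence[str] loosening in the two signatures above lets callers supply either lists or tuples of auth event IDs: a Sequence parameter promises read-only access, which fits accessors like prev_event_ids() that may hand back immutable sequences. A small illustration, with made-up names rather than Synapse's:

```python
from typing import Dict, Sequence


def count_auth_links(event_to_auth_chain: Dict[str, Sequence[str]]) -> int:
    # Sequence only permits reads (len, iteration, indexing) - no .append(),
    # which is the only guarantee the chain-cover calculation needs.
    return sum(len(chain) for chain in event_to_auth_chain.values())


chains: Dict[str, Sequence[str]] = {
    "$event_a": ["$auth_1", "$auth_2"],  # a list is a Sequence[str]
    "$event_b": ("$auth_1",),            # and so is a tuple
}
print(count_auth_links(chains))  # 3
```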
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             prefilled_cache=presence_cache_prefill,
         )
 
-    async def update_presence(self, presence_states):
+    async def update_presence(self, presence_states) -> Tuple[int, int]:
         assert self._can_persist_presence
 
         stream_ordering_manager = self._presence_id_gen.get_next_mult(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index ba7075caa5..dd8e27e226 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def update_remote_profile_cache(
-        self, user_id: str, displayname: str, avatar_url: str
+        self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
    ) -> int:
         return await self.db_pool.simple_update(
             table="remote_profile_cache",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 40760fbd1b..53576ad52f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,13 +13,14 @@
 #  limitations under the License.
 
 import logging
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import attr
 
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
         """
 
         where_clause = ["relates_to_id = ?"]
-        where_args = [event_id]
+        where_args: List[Union[str, int]] = [event_id]
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
+            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
+            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
             engine=self.database_engine,
         )
@@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
             order,
         )
 
-        def _get_recent_references_for_event_txn(txn):
+        def _get_recent_references_for_event_txn(
+            txn: LoggingTransaction,
+        ) -> PaginationChunk:
             txn.execute(sql, where_args + [limit + 1])
 
             last_topo_id = None
@@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
         """
 
         where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args = [event_id, RelationTypes.ANNOTATION]
+        where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
 
         if event_type:
             where_clause.append("type = ?")
@@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
         having_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("COUNT(*)", "MAX(stream_ordering)"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
+            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
+            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
             engine=self.database_engine,
         )
@@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
             having_clause=having_clause,
         )
 
-        def _get_aggregation_groups_for_event_txn(txn):
+        def _get_aggregation_groups_for_event_txn(
+            txn: LoggingTransaction,
+        ) -> PaginationChunk:
             txn.execute(sql, where_args + [limit + 1])
 
             next_batch = None
@@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
             LIMIT 1
         """
 
-        def _get_applicable_edit_txn(txn):
+        def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
             txn.execute(sql, (event_id, RelationTypes.REPLACE))
             row = txn.fetchone()
             if row:
                 return row[0]
+            return None
 
         edit_id = await self.db_pool.runInteraction(
             "get_applicable_edit", _get_applicable_edit_txn
@@ -267,7 +273,7 @@ class RelationsWorkerStore(SQLBaseStore):
         if not edit_id:
             return None
 
-        return await self.get_event(edit_id, allow_none=True)
+        return await self.get_event(edit_id, allow_none=True)  # type: ignore[attr-defined]
 
     @cached()
     async def get_thread_summary(
@@ -283,7 +289,9 @@ class RelationsWorkerStore(SQLBaseStore):
             The number of items in the thread and the most recent response, if any.
         """
 
-        def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
+        def _get_thread_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, Optional[str]]:
             # Fetch the count of threaded events and the latest event ID.
             # TODO Should this only allow m.room.message events.
             sql = """
@@ -312,7 +320,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 AND relation_type = ?
             """
             txn.execute(sql, (event_id, RelationTypes.THREAD))
-            count = txn.fetchone()[0]
+            count = txn.fetchone()[0]  # type: ignore[index]
 
             return count, latest_event_id
@@ -322,7 +330,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         latest_event = None
         if latest_event_id:
-            latest_event = await self.get_event(latest_event_id, allow_none=True)
+            latest_event = await self.get_event(latest_event_id, allow_none=True)  # type: ignore[attr-defined]
 
         return count, latest_event
@@ -354,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
             LIMIT 1;
         """
 
-        def _get_if_user_has_annotated_event(txn):
+        def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
             txn.execute(
                 sql,
                 (
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f879bbe7c7..cefc77fa0f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -412,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore):
             limit: maximum amount of rooms to retrieve
             order_by: the sort order of the returned list
             reverse_order: whether to reverse the room list
-            search_term: a string to filter room names by
+            search_term: a string to filter room names,
+                canonical alias and room ids by.
+                Room ID must match exactly. Canonical alias must match a substring of the local part.
         Returns:
             A list of room dicts and an integer representing the total number of
             rooms that exist given this query
         """
         # Filter room names by a string
         where_statement = ""
+        search_pattern = []
         if search_term:
-            where_statement = "WHERE LOWER(state.name) LIKE ?"
+            where_statement = """
+                WHERE LOWER(state.name) LIKE ?
+                OR LOWER(state.canonical_alias) LIKE ?
+                OR state.room_id = ?
+            """
 
             # Our postgres db driver converts ? -> %s in SQL strings as that's the
             # placeholder for postgres.
             # HOWEVER, if you put a % into your SQL then everything goes wibbly.
             # To get around this, we're going to surround search_term with %'s
             # before giving it to the database in python instead
-            search_term = "%" + search_term.lower() + "%"
+            search_pattern = [
+                "%" + search_term.lower() + "%",
+                "#%" + search_term.lower() + "%:%",
+                search_term,
+            ]
 
         # Set ordering
         if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -519,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
         def _get_rooms_paginate_txn(txn):
-            # Execute the data query
-            sql_values = (limit, start)
-            if search_term:
-                # Add the search term into the WHERE clause
-                sql_values = (search_term,) + sql_values
-            txn.execute(info_sql, sql_values)
+            # Add the search term into the WHERE clause
+            # and execute the data query
+            txn.execute(info_sql, search_pattern + [limit, start])
 
             # Refactor room query data into a structured dictionary
             rooms = []
@@ -551,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore):
 
             # Execute the count query
             # Add the search term into the WHERE clause if present
-            sql_values = (search_term,) if search_term else ()
-            txn.execute(count_sql, sql_values)
+            txn.execute(count_sql, search_pattern)
 
             room_count = txn.fetchone()
             return rooms, room_count[0]
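The three placeholders in the new WHERE clause pair up positionally with the three entries of search_pattern: a substring match on the room name, a local-part match on the canonical alias (the "#%...%:%" shape requires a leading # and a later :), and an exact room-ID comparison. A quick sketch of how those patterns behave, with fnmatch standing in for SQL LIKE; the helper names below are mine, not Synapse's:

```python
from fnmatch import fnmatchcase


def build_search_pattern(search_term: str) -> list:
    return [
        "%" + search_term.lower() + "%",     # 1st ?  LOWER(state.name) LIKE ?
        "#%" + search_term.lower() + "%:%",  # 2nd ?  LOWER(state.canonical_alias) LIKE ?
        search_term,                         # 3rd ?  state.room_id = ?  (exact)
    ]


def like(value: str, pattern: str) -> bool:
    # SQL LIKE's % is fnmatch's *; these patterns contain no other wildcards.
    return fnmatchcase(value, pattern.replace("%", "*"))


name_pat, alias_pat, room_id_term = build_search_pattern("dev")
assert like("matrix dev chat", name_pat)                # substring of the name
assert like("#synapse-dev:example.org", alias_pat)      # local part contains "dev"
assert not like("#general:dev.example.org", alias_pat)  # domain-only match is rejected
```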
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4b288bb2e7..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     async def get_joined_users_from_context(
         self, event: EventBase, context: EventContext
-    ):
+    ) -> Dict[str, ProfileInfo]:
         state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
 
-    async def get_joined_users_from_state(self, room_id, state_entry):
+    async def get_joined_users_from_state(
+        self, room_id, state_entry
+    ) -> Dict[str, ProfileInfo]:
         state_group = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         cache_context,
         event=None,
         context=None,
-    ):
+    ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
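The comment in the final hunk describes a caching idiom that is easy to miss: state_group is accepted purely so it becomes part of the cache key, forcing a recompute whenever the group changes even though the function body never reads it. A toy version of the idiom, with functools.lru_cache standing in for Synapse's @cached() decorator and illustrative names:

```python
from functools import lru_cache
from typing import Dict, FrozenSet


@lru_cache(maxsize=1024)
def joined_users_for_group(
    room_id: str,
    state_group: int,  # never read below - it exists only to key the cache
    member_event_ids: FrozenSet[str],
) -> Dict[str, str]:
    # Stand-in for the expensive profile lookups the real method performs.
    return {event_id: "<profile>" for event_id in member_event_ids}


joined_users_for_group("!room:x", 1, frozenset({"$m1"}))  # computed, then cached
joined_users_for_group("!room:x", 1, frozenset({"$m1"}))  # cache hit
joined_users_for_group("!room:x", 2, frozenset({"$m1"}))  # new state_group -> recomputed
```

This is also why the real code insists state_group is never None: two distinct states that both hashed to a None group would wrongly share a cache entry.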
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): async def get_joined_users_from_context( self, event: EventBase, context: EventContext - ): + ) -> Dict[str, ProfileInfo]: state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): event.room_id, state_group, current_state_ids, event=event, context=context ) - async def get_joined_users_from_state(self, room_id, state_entry): + async def get_joined_users_from_state( + self, room_id, state_entry + ) -> Dict[str, ProfileInfo]: state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cache_context, event=None, context=None, - ): + ) -> Dict[str, ProfileInfo]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different.