From bcff01b40673238dca29c0f22dc4fda05f635030 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Wed, 18 Oct 2023 17:42:01 +0200 Subject: Improve performance of delete device messages query (#16492) --- synapse/storage/databases/main/deviceinbox.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1faa6f04b2..3e7425d4a6 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -478,18 +478,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": "No changes in cache since last check"}) return 0 - ROW_ID_NAME = self.database_engine.row_id_name - def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: limit_statement = "" if limit is None else f"LIMIT {limit}" sql = f""" - DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN ( - SELECT {ROW_ID_NAME} FROM device_inbox - WHERE user_id = ? AND device_id = ? AND stream_id <= ? - {limit_statement} + DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= ( + SELECT MAX(stream_id) FROM ( + SELECT stream_id FROM device_inbox + WHERE user_id = ? AND device_id = ? AND stream_id <= ? + ORDER BY stream_id + {limit_statement} + ) AS q1 ) """ - txn.execute(sql, (user_id, device_id, up_to_stream_id)) + txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id)) return txn.rowcount count = await self.db_pool.runInteraction( -- cgit 1.5.1 From 49c9745b4516dec8728c260f1a6784f2c510110c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 18 Oct 2023 12:26:01 -0400 Subject: Avoid sending massive replication updates when purging a room. (#16510) --- changelog.d/16510.misc | 1 + synapse/replication/tcp/streams/events.py | 45 +++++++++++++- synapse/storage/databases/main/cache.py | 8 +++ tests/replication/tcp/streams/test_events.py | 91 +++++++++++++++++++--------- 4 files changed, 115 insertions(+), 30 deletions(-) create mode 100644 changelog.d/16510.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/16510.misc b/changelog.d/16510.misc new file mode 100644 index 0000000000..5556b5d74c --- /dev/null +++ b/changelog.d/16510.misc @@ -0,0 +1 @@ +Improve replication performance when purging rooms. diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index ad9b760713..da6d948e1b 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import heapq +from collections import defaultdict from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast import attr @@ -51,8 +52,19 @@ data part are: * The state_key of the state which has changed * The event id of the new state +A "state-all" row is sent whenever the "current state" in a room changes, but there are +too many state updates for a particular room in the same update. This replaces any +"state" rows on a per-room basis. The fields in the data part are: + +* The room id for the state changes + """ +# Any room with more than _MAX_STATE_UPDATES_PER_ROOM will send a EventsStreamAllStateRow +# instead of individual EventsStreamEventRow. This is predominantly useful when +# purging large rooms. +_MAX_STATE_UPDATES_PER_ROOM = 150 + @attr.s(slots=True, frozen=True, auto_attribs=True) class EventsStreamRow: @@ -111,9 +123,17 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): event_id: Optional[str] +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventsStreamAllStateRow(BaseEventsStreamRow): + TypeId = "state-all" + + room_id: str + + _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = ( EventsStreamEventRow, EventsStreamCurrentStateRow, + EventsStreamAllStateRow, ) TypeToRow = {Row.TypeId: Row for Row in _EventRows} @@ -213,9 +233,28 @@ class EventsStream(Stream): if stream_id <= upper_limit ) + # Separate out rooms that have many state updates, listeners should clear + # all state for those rooms. + state_updates_by_room = defaultdict(list) + for stream_id, room_id, _type, _state_key, _event_id in state_rows: + state_updates_by_room[room_id].append(stream_id) + + state_all_rows = [ + (stream_ids[-1], room_id) + for room_id, stream_ids in state_updates_by_room.items() + if len(stream_ids) >= _MAX_STATE_UPDATES_PER_ROOM + ] + state_all_updates: Iterable[Tuple[int, Tuple]] = ( + (max_stream_id, (EventsStreamAllStateRow.TypeId, (room_id,))) + for (max_stream_id, room_id) in state_all_rows + ) + + # Any remaining state updates are sent individually. + state_all_rooms = {room_id for _, room_id in state_all_rows} state_updates: Iterable[Tuple[int, Tuple]] = ( (stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) for (stream_id, *rest) in state_rows + if rest[0] not in state_all_rooms ) ex_outliers_updates: Iterable[Tuple[int, Tuple]] = ( @@ -224,7 +263,11 @@ class EventsStream(Stream): ) # we need to return a sorted list, so merge them together. - updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) + updates = list( + heapq.merge( + event_updates, state_all_updates, state_updates, ex_outliers_updates + ) + ) return updates, upper_limit, limited @classmethod diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2fbd389c71..4d0470ffd9 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.replication.tcp.streams import BackfillStream, CachesStream from synapse.replication.tcp.streams.events import ( EventsStream, + EventsStreamAllStateRow, EventsStreamCurrentStateRow, EventsStreamEventRow, EventsStreamRow, @@ -264,6 +265,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore): (data.state_key,) ) self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined] + elif row.type == EventsStreamAllStateRow.TypeId: + assert isinstance(data, EventsStreamAllStateRow) + # Similar to the above, but the entire caches are invalidated. This is + # unfortunate for the membership caches, but should recover quickly. + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined] + self.get_rooms_for_user_with_stream_ordering.invalidate_all() # type: ignore[attr-defined] + self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined] else: raise Exception("Unknown events stream row type %s" % (row.type,)) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 128fc3e046..b8ab4ee54b 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -14,6 +14,8 @@ from typing import Any, List, Optional +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership @@ -21,6 +23,8 @@ from synapse.events import EventBase from synapse.replication.tcp.commands import RdataCommand from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT from synapse.replication.tcp.streams.events import ( + _MAX_STATE_UPDATES_PER_ROOM, + EventsStreamAllStateRow, EventsStreamCurrentStateRow, EventsStreamEventRow, EventsStreamRow, @@ -106,11 +110,21 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual([], received_rows) - def test_update_function_huge_state_change(self) -> None: + @parameterized.expand( + [(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)] + ) + def test_update_function_huge_state_change( + self, num_state_changes: int, collapse_state_changes: bool + ) -> None: """Test replication with many state events Ensures that all events are correctly replicated when there are lots of state change rows to be replicated. + + Args: + num_state_changes: The number of state changes to create. + collapse_state_changes: Whether the state changes are expected to be + collapsed or not. """ # we want to generate lots of state changes at a single stream ID. @@ -145,7 +159,7 @@ class EventsStreamTestCase(BaseStreamTestCase): events = [ self._inject_state_event(sender=OTHER_USER) - for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + for _ in range(num_state_changes) ] self.replicate() @@ -202,8 +216,7 @@ class EventsStreamTestCase(BaseStreamTestCase): row for row in self.test_handler.received_rdata_rows if row[0] == "events" ] - # first check the first two rows, which should be state1 - + # first check the first two rows, which should be the state1 event. stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) @@ -217,7 +230,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertIsInstance(row.data, EventsStreamCurrentStateRow) self.assertEqual(row.data.event_id, state1.event_id) - # now the last two rows, which should be state2 + # now the last two rows, which should be the state2 event. stream_name, token, row = received_rows.pop(-2) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) @@ -231,34 +244,54 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertIsInstance(row.data, EventsStreamCurrentStateRow) self.assertEqual(row.data.event_id, state2.event_id) - # that should leave us with the rows for the PL event - self.assertEqual(len(received_rows), len(events) + 2) + # Based on the number of + if collapse_state_changes: + # that should leave us with the rows for the PL event, the state changes + # get collapsed into a single row. + self.assertEqual(len(received_rows), 2) - stream_name, token, row = received_rows.pop(0) - self.assertEqual("events", stream_name) - self.assertIsInstance(row, EventsStreamRow) - self.assertEqual(row.type, "ev") - self.assertIsInstance(row.data, EventsStreamEventRow) - self.assertEqual(row.data.event_id, pl_event.event_id) + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) - # the state rows are unsorted - state_rows: List[EventsStreamCurrentStateRow] = [] - for stream_name, _, row in received_rows: + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state-all") + self.assertIsInstance(row.data, EventsStreamAllStateRow) + self.assertEqual(row.data.room_id, state2.room_id) + + else: + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) - self.assertEqual(row.type, "state") - self.assertIsInstance(row.data, EventsStreamCurrentStateRow) - state_rows.append(row.data) - - state_rows.sort(key=lambda r: r.state_key) - - sr = state_rows.pop(0) - self.assertEqual(sr.type, EventTypes.PowerLevels) - self.assertEqual(sr.event_id, pl_event.event_id) - for sr in state_rows: - self.assertEqual(sr.type, "test_state_event") - # "None" indicates the state has been deleted - self.assertIsNone(sr.event_id) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows: List[EventsStreamCurrentStateRow] = [] + for stream_name, _, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) def test_update_function_state_row_limit(self) -> None: """Test replication with many state events over several stream ids.""" -- cgit 1.5.1 From e9069c9f919685606506f04527332e83fbfa44d9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 19 Oct 2023 15:04:18 +0100 Subject: Mark sync as limited if there is a gap in the timeline (#16485) This splits thinsg into two queries, but most of the time we won't have new event backwards extremities so this shouldn't actually add an extra RTT for the majority of cases. Note this removes the check for events with no prev events, but that was part of MSC2716 work that has since been removed. --- changelog.d/16485.bugfix | 1 + synapse/handlers/sync.py | 52 ++++++++++++++--- synapse/storage/databases/main/events.py | 74 ++++++++++++++++--------- synapse/storage/databases/main/stream.py | 47 ++++++++++++++++ synapse/storage/schema/main/delta/82/05gaps.sql | 25 +++++++++ 5 files changed, 166 insertions(+), 33 deletions(-) create mode 100644 changelog.d/16485.bugfix create mode 100644 synapse/storage/schema/main/delta/82/05gaps.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/16485.bugfix b/changelog.d/16485.bugfix new file mode 100644 index 0000000000..3cd7e1877f --- /dev/null +++ b/changelog.d/16485.bugfix @@ -0,0 +1 @@ +Fix long-standing bug where `/sync` incorrectly did not mark a room as `limited` in a sync requests when there were missing remote events. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 60b4d95cd7..f131c0e8e0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -500,12 +500,27 @@ class SyncHandler: async def _load_filtered_recents( self, room_id: str, + sync_result_builder: "SyncResultBuilder", sync_config: SyncConfig, - now_token: StreamToken, + upto_token: StreamToken, since_token: Optional[StreamToken] = None, potential_recents: Optional[List[EventBase]] = None, newly_joined_room: bool = False, ) -> TimelineBatch: + """Create a timeline batch for the room + + Args: + room_id + sync_result_builder + sync_config + upto_token: The token up to which we should fetch (more) events. + If `potential_results` is non-empty then this is *start* of + the the list. + since_token + potential_recents: If non-empty, the events between the since token + and current token to send down to clients. + newly_joined_room + """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() block_all_timeline = ( @@ -521,6 +536,20 @@ class SyncHandler: else: limited = False + # Check if there is a gap, if so we need to mark this as limited and + # recalculate which events to send down. + gap_token = await self.store.get_timeline_gaps( + room_id, + since_token.room_key if since_token else None, + sync_result_builder.now_token.room_key, + ) + if gap_token: + # There's a gap, so we need to ignore the passed in + # `potential_recents`, and reset `upto_token` to match. + potential_recents = None + upto_token = sync_result_builder.now_token + limited = True + log_kv({"limited": limited}) if potential_recents: @@ -559,10 +588,10 @@ class SyncHandler: recents = [] if not limited or block_all_timeline: - prev_batch_token = now_token + prev_batch_token = upto_token if recents: room_key = recents[0].internal_metadata.before - prev_batch_token = now_token.copy_and_replace( + prev_batch_token = upto_token.copy_and_replace( StreamKeyType.ROOM, room_key ) @@ -573,11 +602,15 @@ class SyncHandler: filtering_factor = 2 load_limit = max(timeline_limit * filtering_factor, 10) max_repeat = 5 # Only try a few times per room, otherwise - room_key = now_token.room_key + room_key = upto_token.room_key end_key = room_key since_key = None - if since_token and not newly_joined_room: + if since_token and gap_token: + # If there is a gap then we need to only include events after + # it. + since_key = gap_token + elif since_token and not newly_joined_room: since_key = since_token.room_key while limited and len(recents) < timeline_limit and max_repeat: @@ -647,7 +680,7 @@ class SyncHandler: recents = recents[-timeline_limit:] room_key = recents[0].internal_metadata.before - prev_batch_token = now_token.copy_and_replace(StreamKeyType.ROOM, room_key) + prev_batch_token = upto_token.copy_and_replace(StreamKeyType.ROOM, room_key) # Don't bother to bundle aggregations if the timeline is unlimited, # as clients will have all the necessary information. @@ -662,7 +695,9 @@ class SyncHandler: return TimelineBatch( events=recents, prev_batch=prev_batch_token, - limited=limited or newly_joined_room, + # Also mark as limited if this is a new room or there has been a gap + # (to force client to paginate the gap). + limited=limited or newly_joined_room or gap_token is not None, bundled_aggregations=bundled_aggregations, ) @@ -2397,8 +2432,9 @@ class SyncHandler: batch = await self._load_filtered_recents( room_id, + sync_result_builder, sync_config, - now_token=upto_token, + upto_token=upto_token, since_token=since_token, potential_recents=events, newly_joined_room=newly_joined, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ef6766b5e0..3c1492e3ad 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2267,35 +2267,59 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ - # From the events passed in, add all of the prev events as backwards extremities. - # Ignore any events that are already backwards extrems or outliers. - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - # 1. Don't add an event as a extremity again if we already persisted it - # as a non-outlier. - # 2. Don't add an outlier as an extremity if it has no prev_events - " AND NOT EXISTS (" - " SELECT 1 FROM events" - " LEFT JOIN event_edges edge" - " ON edge.event_id = events.event_id" - " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)" - " )" + + room_id = events[0].room_id + + potential_backwards_extremities = { + e_id + for ev in events + for e_id in ev.prev_event_ids() + if not ev.internal_metadata.is_outlier() + } + + if not potential_backwards_extremities: + return + + existing_events_outliers = self.db_pool.simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=potential_backwards_extremities, + keyvalues={"outlier": False}, + retcols=("event_id",), ) - txn.execute_batch( - query, - [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id) - for ev in events - for e_id in ev.prev_event_ids() - if not ev.internal_metadata.is_outlier() - ], + potential_backwards_extremities.difference_update( + e for e, in existing_events_outliers ) + if potential_backwards_extremities: + self.db_pool.simple_upsert_many_txn( + txn, + table="event_backward_extremities", + key_names=("room_id", "event_id"), + key_values=[(room_id, ev) for ev in potential_backwards_extremities], + value_names=(), + value_values=(), + ) + + # Record the stream orderings where we have new gaps. + gap_events = [ + (room_id, self._instance_name, ev.internal_metadata.stream_ordering) + for ev in events + if any( + e_id in potential_backwards_extremities + for e_id in ev.prev_event_ids() + ) + ] + + self.db_pool.simple_insert_many_txn( + txn, + table="timeline_gaps", + keys=("room_id", "instance_name", "stream_ordering"), + values=gap_events, + ) + # Delete all these events that we've already fetched and now know that their # prev events are the new backwards extremeties. query = ( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ea06e4eee0..872df6bda1 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1616,3 +1616,50 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcol="instance_name", desc="get_name_from_instance_id", ) + + async def get_timeline_gaps( + self, + room_id: str, + from_token: Optional[RoomStreamToken], + to_token: RoomStreamToken, + ) -> Optional[RoomStreamToken]: + """Check if there is a gap, and return a token that marks the position + of the gap in the stream. + """ + + sql = """ + SELECT instance_name, stream_ordering + FROM timeline_gaps + WHERE room_id = ? AND ? < stream_ordering AND stream_ordering <= ? + ORDER BY stream_ordering + """ + + rows = await self.db_pool.execute( + "get_timeline_gaps", + None, + sql, + room_id, + from_token.stream if from_token else 0, + to_token.get_max_stream_pos(), + ) + + if not rows: + return None + + positions = [ + PersistedEventPosition(instance_name, stream_ordering) + for instance_name, stream_ordering in rows + ] + if from_token: + positions = [p for p in positions if p.persisted_after(from_token)] + + positions = [p for p in positions if not p.persisted_after(to_token)] + + if positions: + # We return a stream token that ensures the event *at* the position + # of the gap is included (as the gap is *before* the persisted + # event). + last_position = positions[-1] + return RoomStreamToken(stream=last_position.stream - 1) + + return None diff --git a/synapse/storage/schema/main/delta/82/05gaps.sql b/synapse/storage/schema/main/delta/82/05gaps.sql new file mode 100644 index 0000000000..6813b488ca --- /dev/null +++ b/synapse/storage/schema/main/delta/82/05gaps.sql @@ -0,0 +1,25 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Records when we see a "gap in the timeline", due to missing events over +-- federation. We record this so that we can tell clients there is a gap (by +-- marking the timeline section of a sync request as limited). +CREATE TABLE IF NOT EXISTS timeline_gaps ( + room_id TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_ordering BIGINT NOT NULL +); + +CREATE INDEX timeline_gaps_room_id ON timeline_gaps(room_id, stream_ordering); -- cgit 1.5.1 From 12ca87f5eac06450abaf024e5f4906147d5322e3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Oct 2023 07:37:45 -0400 Subject: Remove the last reference to event_txn_id. (#16521) This table was no longer used, except for a background process which purged old entries in it. --- changelog.d/16521.misc | 1 + synapse/storage/databases/main/events_worker.py | 6 ------ synapse/storage/schema/__init__.py | 5 ++++- 3 files changed, 5 insertions(+), 7 deletions(-) create mode 100644 changelog.d/16521.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/16521.misc b/changelog.d/16521.misc new file mode 100644 index 0000000000..c6a8ddcf9c --- /dev/null +++ b/changelog.d/16521.misc @@ -0,0 +1 @@ +Stop deleting from an unused table. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8af638d60f..5bf864c1fb 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -2095,12 +2095,6 @@ class EventsWorkerStore(SQLBaseStore): def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 - sql = """ - DELETE FROM event_txn_id - WHERE inserted_ts < ? - """ - txn.execute(sql, (one_day_ago,)) - sql = """ DELETE FROM event_txn_id_device_id WHERE inserted_ts < ? diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 5b50bd66bc..158b528dce 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 82 # remember to update the list below when updating +SCHEMA_VERSION = 83 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -121,6 +121,9 @@ Changes in SCHEMA_VERSION = 81 Changes in SCHEMA_VERSION = 82 - The insertion_events, insertion_event_extremities, insertion_event_edges, and batch_events tables are no longer purged in preparation for their removal. + +Changes in SCHEMA_VERSION = 83 + - The event_txn_id is no longer used. """ -- cgit 1.5.1 From 3bc23cc45cb6a70d53ba4032a9116029bc4f538c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Oct 2023 14:39:25 +0100 Subject: Fix bug that could cause a `/sync` to tightloop with sqlite after restart (#16540) This could happen if the last rows in the account data stream were inserted into `account_data`. After a restart the max account ID would be calculated without looking at the `account_data` table, and so have an old ID. --- changelog.d/16540.bugfix | 1 + synapse/storage/databases/main/account_data.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 changelog.d/16540.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/16540.bugfix b/changelog.d/16540.bugfix new file mode 100644 index 0000000000..34ee9facf9 --- /dev/null +++ b/changelog.d/16540.bugfix @@ -0,0 +1 @@ +Fix long-standing bug where `/sync` could tightloop after restart when using SQLite. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 39498d52c6..84ef8136c2 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -94,7 +94,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) hs.get_replication_notifier(), "room_account_data", "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], + extra_tables=[ + ("account_data", "stream_id"), + ("room_tags_revisions", "stream_id"), + ], is_writer=self._instance_name in hs.config.worker.writers.account_data, ) -- cgit 1.5.1 From ba47fea5286e084ec70d568aa62eb4820b857c47 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Oct 2023 16:16:19 +0100 Subject: Allow multiple workers to write to receipts stream. (#16432) Fixes #16417 --- changelog.d/16432.feature | 1 + synapse/config/workers.py | 4 +- synapse/handlers/appservice.py | 42 ++-- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/receipts.py | 19 +- synapse/handlers/sync.py | 7 +- synapse/notifier.py | 45 +++- synapse/replication/tcp/client.py | 3 +- synapse/storage/databases/main/receipts.py | 148 +++++++++---- synapse/storage/databases/main/relations.py | 4 +- .../delta/83/03_instance_name_receipts.sql.sqlite | 17 ++ synapse/streams/events.py | 4 +- synapse/types/__init__.py | 137 +++++++++++- tests/handlers/test_appservice.py | 17 +- tests/replication/test_sharded_receipts.py | 243 +++++++++++++++++++++ 15 files changed, 604 insertions(+), 89 deletions(-) create mode 100644 changelog.d/16432.feature create mode 100644 synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite create mode 100644 tests/replication/test_sharded_receipts.py (limited to 'synapse/storage/databases') diff --git a/changelog.d/16432.feature b/changelog.d/16432.feature new file mode 100644 index 0000000000..9a76e85592 --- /dev/null +++ b/changelog.d/16432.feature @@ -0,0 +1 @@ +Allow multiple workers to write to receipts stream. diff --git a/synapse/config/workers.py b/synapse/config/workers.py index f1766088fc..6d67a8cd5c 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -358,9 +358,9 @@ class WorkerConfig(Config): "Must only specify one instance to handle `account_data` messages." ) - if len(self.writers.receipts) != 1: + if len(self.writers.receipts) == 0: raise ConfigError( - "Must only specify one instance to handle `receipts` messages." + "Must specify at least one instance to handle `receipts` messages." ) if len(self.writers.events) == 0: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index c200a45f3a..873dadc3bd 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -47,6 +47,7 @@ from synapse.types import ( DeviceListUpdates, JsonDict, JsonMapping, + MultiWriterStreamToken, RoomAlias, RoomStreamToken, StreamKeyType, @@ -217,7 +218,7 @@ class ApplicationServicesHandler: def notify_interested_services_ephemeral( self, stream_key: StreamKeyType, - new_token: Union[int, RoomStreamToken], + new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], users: Collection[Union[str, UserID]], ) -> None: """ @@ -259,19 +260,6 @@ class ApplicationServicesHandler: ): return - # Assert that new_token is an integer (and not a RoomStreamToken). - # All of the supported streams that this function handles use an - # integer to track progress (rather than a RoomStreamToken - a - # vector clock implementation) as they don't support multiple - # stream writers. - # - # As a result, we simply assert that new_token is an integer. - # If we do end up needing to pass a RoomStreamToken down here - # in the future, using RoomStreamToken.stream (the minimum stream - # position) to convert to an ascending integer value should work. - # Additional context: https://github.com/matrix-org/synapse/pull/11137 - assert isinstance(new_token, int) - # Ignore to-device messages if the feature flag is not enabled if ( stream_key == StreamKeyType.TO_DEVICE @@ -286,6 +274,9 @@ class ApplicationServicesHandler: ): return + # We know we're not a `RoomStreamToken` at this point. + assert not isinstance(new_token, RoomStreamToken) + # Check whether there are any appservices which have registered to receive # ephemeral events. # @@ -327,7 +318,7 @@ class ApplicationServicesHandler: self, services: List[ApplicationService], stream_key: StreamKeyType, - new_token: int, + new_token: Union[int, MultiWriterStreamToken], users: Collection[Union[str, UserID]], ) -> None: logger.debug("Checking interested services for %s", stream_key) @@ -340,6 +331,7 @@ class ApplicationServicesHandler: # # Instead we simply grab the latest typing updates in _handle_typing # and, if they apply to this application service, send it off. + assert isinstance(new_token, int) events = await self._handle_typing(service, new_token) if events: self.scheduler.enqueue_for_appservice(service, ephemeral=events) @@ -350,15 +342,23 @@ class ApplicationServicesHandler: (service.id, stream_key) ): if stream_key == StreamKeyType.RECEIPT: + assert isinstance(new_token, MultiWriterStreamToken) + + # We store appservice tokens as integers, so we ignore + # the `instance_map` components and instead simply + # follow the base stream position. + new_token = MultiWriterStreamToken(stream=new_token.stream) + events = await self._handle_receipts(service, new_token) self.scheduler.enqueue_for_appservice(service, ephemeral=events) # Persist the latest handled stream token for this appservice await self.store.set_appservice_stream_type_pos( - service, "read_receipt", new_token + service, "read_receipt", new_token.stream ) elif stream_key == StreamKeyType.PRESENCE: + assert isinstance(new_token, int) events = await self._handle_presence(service, users, new_token) self.scheduler.enqueue_for_appservice(service, ephemeral=events) @@ -368,6 +368,7 @@ class ApplicationServicesHandler: ) elif stream_key == StreamKeyType.TO_DEVICE: + assert isinstance(new_token, int) # Retrieve a list of to-device message events, as well as the # maximum stream token of the messages we were able to retrieve. to_device_messages = await self._get_to_device_messages( @@ -383,6 +384,7 @@ class ApplicationServicesHandler: ) elif stream_key == StreamKeyType.DEVICE_LIST: + assert isinstance(new_token, int) device_list_summary = await self._get_device_list_summary( service, new_token ) @@ -432,7 +434,7 @@ class ApplicationServicesHandler: return typing async def _handle_receipts( - self, service: ApplicationService, new_token: int + self, service: ApplicationService, new_token: MultiWriterStreamToken ) -> List[JsonMapping]: """ Return the latest read receipts that the given application service should receive. @@ -455,15 +457,17 @@ class ApplicationServicesHandler: from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) - if new_token is not None and new_token <= from_key: + if new_token is not None and new_token.stream <= from_key: logger.debug( "Rejecting token lower than or equal to stored: %s" % (new_token,) ) return [] + from_token = MultiWriterStreamToken(stream=from_key) + receipts_source = self.event_sources.sources.receipt receipts, _ = await receipts_source.get_new_events_as( - service=service, from_key=from_key, to_key=new_token + service=service, from_key=from_token, to_key=new_token ) return receipts diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index c34bd7db95..b1d8be866f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -145,7 +145,7 @@ class InitialSyncHandler: joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] receipt = await self.store.get_linearized_receipts_for_rooms( joined_rooms, - to_key=int(now_token.receipt_key), + to_key=now_token.receipt_key, ) receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 69ac468f75..b5f7a8b47e 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -20,6 +20,7 @@ from synapse.streams import EventSource from synapse.types import ( JsonDict, JsonMapping, + MultiWriterStreamToken, ReadReceipt, StreamKeyType, UserID, @@ -200,7 +201,7 @@ class ReceiptsHandler: await self.federation_sender.send_read_receipt(receipt) -class ReceiptEventSource(EventSource[int, JsonMapping]): +class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.config = hs.config @@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): async def get_new_events( self, user: UserID, - from_key: int, + from_key: MultiWriterStreamToken, limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[JsonMapping], int]: - from_key = int(from_key) + ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]: to_key = self.get_current_key() if from_key == to_key: @@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): return events, to_key async def get_new_events_as( - self, from_key: int, to_key: int, service: ApplicationService - ) -> Tuple[List[JsonMapping], int]: + self, + from_key: MultiWriterStreamToken, + to_key: MultiWriterStreamToken, + service: ApplicationService, + ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]: """Returns a set of new read receipt events that an appservice may be interested in. @@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): appservice may be interested in. * The current read receipt stream token. """ - from_key = int(from_key) - if from_key == to_key: return [], to_key @@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): return events, to_key - def get_current_key(self) -> int: + def get_current_key(self) -> MultiWriterStreamToken: return self.store.get_max_receipt_stream_id() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f131c0e8e0..f75c1548ca 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -57,6 +57,7 @@ from synapse.types import ( DeviceListUpdates, JsonDict, JsonMapping, + MultiWriterStreamToken, MutableStateMap, Requester, RoomStreamToken, @@ -477,7 +478,11 @@ class SyncHandler: event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) - receipt_key = since_token.receipt_key if since_token else 0 + receipt_key = ( + since_token.receipt_key + if since_token + else MultiWriterStreamToken(stream=0) + ) receipt_source = self.event_sources.sources.receipt receipts, receipt_key = await receipt_source.get_new_events( diff --git a/synapse/notifier.py b/synapse/notifier.py index 99e7715896..ee0bd84f1e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -21,11 +21,13 @@ from typing import ( Dict, Iterable, List, + Literal, Optional, Set, Tuple, TypeVar, Union, + overload, ) import attr @@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, + MultiWriterStreamToken, PersistedEventPosition, RoomStreamToken, StrCollection, @@ -127,7 +130,7 @@ class _NotifierUserStream: def notify( self, stream_key: StreamKeyType, - stream_id: Union[int, RoomStreamToken], + stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken], time_now_ms: int, ) -> None: """Notify any listeners for this user of a new event from an @@ -452,10 +455,48 @@ class Notifier: except Exception: logger.exception("Error pusher pool of event") + @overload + def on_new_event( + self, + stream_key: Literal[StreamKeyType.ROOM], + new_token: RoomStreamToken, + users: Optional[Collection[Union[str, UserID]]] = None, + rooms: Optional[StrCollection] = None, + ) -> None: + ... + + @overload + def on_new_event( + self, + stream_key: Literal[StreamKeyType.RECEIPT], + new_token: MultiWriterStreamToken, + users: Optional[Collection[Union[str, UserID]]] = None, + rooms: Optional[StrCollection] = None, + ) -> None: + ... + + @overload + def on_new_event( + self, + stream_key: Literal[ + StreamKeyType.ACCOUNT_DATA, + StreamKeyType.DEVICE_LIST, + StreamKeyType.PRESENCE, + StreamKeyType.PUSH_RULES, + StreamKeyType.TO_DEVICE, + StreamKeyType.TYPING, + StreamKeyType.UN_PARTIAL_STATED_ROOMS, + ], + new_token: int, + users: Optional[Collection[Union[str, UserID]]] = None, + rooms: Optional[StrCollection] = None, + ) -> None: + ... + def on_new_event( self, stream_key: StreamKeyType, - new_token: Union[int, RoomStreamToken], + new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], users: Optional[Collection[Union[str, UserID]]] = None, rooms: Optional[StrCollection] = None, ) -> None: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 384355698d..1312b6f21e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -126,8 +126,9 @@ class ReplicationDataHandler: StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] ) elif stream_name == ReceiptsStream.NAME: + new_token = self.store.get_max_receipt_stream_id() self.notifier.on_new_event( - StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows] + StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows] ) await self._pusher_pool.on_new_receipts({row.user_id for row in rows}) elif stream_name == ToDeviceStream.NAME: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index b2645ab43c..56e8eb16a8 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -28,6 +28,8 @@ from typing import ( cast, ) +from immutabledict import immutabledict + from synapse.api.constants import EduTypes from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, JsonMapping +from synapse.types import ( + JsonDict, + JsonMapping, + MultiWriterStreamToken, + PersistedPosition, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "receipts_linearized", entity_column="room_id", stream_column="stream_id", - max_value=max_receipts_stream_id, + max_value=max_receipts_stream_id.stream, limit=10000, ) self._receipts_stream_cache = StreamChangeCache( @@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) - def get_max_receipt_stream_id(self) -> int: + def get_max_receipt_stream_id(self) -> MultiWriterStreamToken: """Get the current max stream ID for receipts stream""" - return self._receipts_id_gen.get_current_token() + + min_pos = self._receipts_id_gen.get_current_token() + + positions = {} + if isinstance(self._receipts_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._receipts_id_gen.get_positions().items() + if p > min_pos + } + + return MultiWriterStreamToken( + stream=min_pos, instance_map=immutabledict(positions) + ) + + def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: + return self._receipts_id_gen.get_current_token_for_writer(instance_name) def get_last_unthreaded_receipt_for_user_txn( self, @@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore): } async def get_linearized_receipts_for_rooms( - self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None + self, + room_ids: Iterable[str], + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> List[JsonMapping]: """Get receipts for multiple rooms for sending to clients. @@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # Only ask the database about rooms where there have been new # receipts added since `from_key` room_ids = self._receipts_stream_cache.get_entities_changed( - room_ids, from_key + room_ids, from_key.stream ) results = await self._get_linearized_receipts_for_rooms( @@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return [ev for res in results.values() for ev in res] async def get_linearized_receipts_for_room( - self, room_id: str, to_key: int, from_key: Optional[int] = None + self, + room_id: str, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Sequence[JsonMapping]: """Get receipts for a single room for sending to clients. @@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore): if from_key is not None: # Check the cache first to see if any new receipts have been added # since`from_key`. If not we can no-op. - if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): + if not self._receipts_stream_cache.has_entity_changed( + room_id, from_key.stream + ): return [] return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) @cached(tree=True) async def _get_linearized_receipts_for_room( - self, room_id: str, to_key: int, from_key: Optional[int] = None + self, + room_id: str, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Sequence[JsonMapping]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]: if from_key: - sql = ( - "SELECT receipt_type, user_id, event_id, data" - " FROM receipts_linearized WHERE" - " room_id = ? AND stream_id > ? AND stream_id <= ?" - ) + sql = """ + SELECT stream_id, instance_name, receipt_type, user_id, event_id, data + FROM receipts_linearized + WHERE room_id = ? AND stream_id > ? AND stream_id <= ? + """ - txn.execute(sql, (room_id, from_key, to_key)) - else: - sql = ( - "SELECT receipt_type, user_id, event_id, data" - " FROM receipts_linearized WHERE" - " room_id = ? AND stream_id <= ?" + txn.execute( + sql, (room_id, from_key.stream, to_key.get_max_stream_pos()) ) + else: + sql = """ + SELECT stream_id, instance_name, receipt_type, user_id, event_id, data + FROM receipts_linearized WHERE + room_id = ? AND stream_id <= ? + """ - txn.execute(sql, (room_id, to_key)) + txn.execute(sql, (room_id, to_key.get_max_stream_pos())) - return cast(List[Tuple[str, str, str, str]], txn.fetchall()) + return [ + (receipt_type, user_id, event_id, data) + for stream_id, instance_name, receipt_type, user_id, event_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) @@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore): num_args=3, ) async def _get_linearized_receipts_for_rooms( - self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + self, + room_ids: Collection[str], + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Mapping[str, Sequence[JsonMapping]]: if not room_ids: return {} @@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -> List[Tuple[str, str, str, str, Optional[str], str]]: if from_key: sql = """ - SELECT room_id, receipt_type, user_id, event_id, thread_id, data + SELECT stream_id, instance_name, room_id, receipt_type, + user_id, event_id, thread_id, data FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? AND """ @@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore): self.database_engine, "room_id", room_ids ) - txn.execute(sql + clause, [from_key, to_key] + list(args)) + txn.execute( + sql + clause, + [from_key.stream, to_key.get_max_stream_pos()] + list(args), + ) else: sql = """ - SELECT room_id, receipt_type, user_id, event_id, thread_id, data + SELECT stream_id, instance_name, room_id, receipt_type, + user_id, event_id, thread_id, data FROM receipts_linearized WHERE stream_id <= ? AND """ @@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore): self.database_engine, "room_id", room_ids ) - txn.execute(sql + clause, [to_key] + list(args)) + txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) - return cast( - List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall() - ) + return [ + (room_id, receipt_type, user_id, event_id, thread_id, data) + for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f @@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore): num_args=2, ) async def get_linearized_receipts_for_all_rooms( - self, to_key: int, from_key: Optional[int] = None + self, + to_key: MultiWriterStreamToken, + from_key: Optional[MultiWriterStreamToken] = None, ) -> Mapping[str, JsonMapping]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. @@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore): def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: if from_key: sql = """ - SELECT room_id, receipt_type, user_id, event_id, data + SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id DESC LIMIT 100 """ - txn.execute(sql, [from_key, to_key]) + txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()]) else: sql = """ - SELECT room_id, receipt_type, user_id, event_id, data + SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized WHERE stream_id <= ? ORDER BY stream_id DESC LIMIT 100 """ - txn.execute(sql, [to_key]) + txn.execute(sql, [to_key.get_max_stream_pos()]) - return cast(List[Tuple[str, str, str, str, str]], txn.fetchall()) + return [ + (room_id, receipt_type, user_id, event_id, data) + for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + from_key, to_key, instance_name, stream_id + ) + ] txn_results = await self.db_pool.runInteraction( "get_linearized_receipts_for_all_rooms", f @@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore): SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? + AND instance_name = ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (last_id, current_id, limit)) + txn.execute(sql, (last_id, current_id, instance_name, limit)) updates = cast( List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], @@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore): keyvalues=keyvalues, values={ "stream_id": stream_id, + "instance_name": self._instance_name, "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), @@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore): event_ids: List[str], thread_id: Optional[str], data: dict, - ) -> Optional[int]: + ) -> Optional[PersistedPosition]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph @@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore): data, ) - return stream_id + return PersistedPosition(self._instance_name, stream_id) async def _insert_graph_receipt( self, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7f40e2c446..ce7bfd5146 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import ( generate_pagination_where_clause, ) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, StreamKeyType, StreamToken +from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore): room_key=next_key, presence_key=0, typing_key=0, - receipt_key=0, + receipt_key=MultiWriterStreamToken(stream=0), account_data_key=0, push_rules_key=0, to_device_key=0, diff --git a/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite new file mode 100644 index 0000000000..6c7ad0fd37 --- /dev/null +++ b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite @@ -0,0 +1,17 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This already exists on Postgres. +ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 609a0978a9..d0bb83b184 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.logging.opentracing import trace from synapse.streams import EventSource -from synapse.types import StreamKeyType, StreamToken +from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer @@ -111,7 +111,7 @@ class EventSources: room_key=await self.sources.room.get_current_key_for_room(room_id), presence_key=0, typing_key=0, - receipt_key=0, + receipt_key=MultiWriterStreamToken(stream=0), account_data_key=0, push_rules_key=0, to_device_key=0, diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 09a88c86a7..4c5b26ad93 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): return "s%d" % (self.stream,) +@attr.s(frozen=True, slots=True, order=False) +class MultiWriterStreamToken(AbstractMultiWriterStreamToken): + """A basic stream token class for streams that supports multiple writers.""" + + @classmethod + async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken": + try: + if string[0].isdigit(): + return cls(stream=int(string)) + if string[0] == "m": + parts = string[1:].split("~") + stream = int(parts[0]) + + instance_map = {} + for part in parts[1:]: + key, value = part.split(".") + instance_id = int(key) + pos = int(value) + + instance_name = await store.get_name_from_instance_id(instance_id) + instance_map[instance_name] = pos + + return cls( + stream=stream, + instance_map=immutabledict(instance_map), + ) + except CancelledError: + raise + except Exception: + pass + raise SynapseError(400, "Invalid stream token %r" % (string,)) + + async def to_string(self, store: "DataStore") -> str: + if self.instance_map: + entries = [] + for name, pos in self.instance_map.items(): + if pos <= self.stream: + # Ignore instances who are below the minimum stream position + # (we might know they've advanced without seeing a recent + # write from them). + continue + + instance_id = await store.get_id_for_instance(name) + entries.append(f"{instance_id}.{pos}") + + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + else: + return str(self.stream) + + @staticmethod + def is_stream_position_in_range( + low: Optional["AbstractMultiWriterStreamToken"], + high: Optional["AbstractMultiWriterStreamToken"], + instance_name: Optional[str], + pos: int, + ) -> bool: + """Checks if a given persisted position is between the two given tokens. + + If `instance_name` is None then the row was persisted before multi + writer support. + """ + + if low: + if instance_name: + low_stream = low.instance_map.get(instance_name, low.stream) + else: + low_stream = low.stream + + if pos <= low_stream: + return False + + if high: + if instance_name: + high_stream = high.instance_map.get(instance_name, high.stream) + else: + high_stream = high.stream + + if high_stream < pos: + return False + + return True + + class StreamKeyType(Enum): """Known stream types. @@ -776,7 +860,9 @@ class StreamToken: ) presence_key: int typing_key: int - receipt_key: int + receipt_key: MultiWriterStreamToken = attr.ib( + validator=attr.validators.instance_of(MultiWriterStreamToken) + ) account_data_key: int push_rules_key: int to_device_key: int @@ -799,8 +885,31 @@ class StreamToken: while len(keys) < len(attr.fields(cls)): # i.e. old token from before receipt_key keys.append("0") + + ( + room_key, + presence_key, + typing_key, + receipt_key, + account_data_key, + push_rules_key, + to_device_key, + device_list_key, + groups_key, + un_partial_stated_rooms_key, + ) = keys + return cls( - await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + room_key=await RoomStreamToken.parse(store, room_key), + presence_key=int(presence_key), + typing_key=int(typing_key), + receipt_key=await MultiWriterStreamToken.parse(store, receipt_key), + account_data_key=int(account_data_key), + push_rules_key=int(push_rules_key), + to_device_key=int(to_device_key), + device_list_key=int(device_list_key), + groups_key=int(groups_key), + un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), ) except CancelledError: raise @@ -813,7 +922,7 @@ class StreamToken: await self.room_key.to_string(store), str(self.presence_key), str(self.typing_key), - str(self.receipt_key), + await self.receipt_key.to_string(store), str(self.account_data_key), str(self.push_rules_key), str(self.to_device_key), @@ -841,6 +950,11 @@ class StreamToken: StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) ) return new_token + elif key == StreamKeyType.RECEIPT: + new_token = self.copy_and_replace( + StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value) + ) + return new_token new_token = self.copy_and_replace(key, new_value) new_id = new_token.get_field(key) @@ -858,6 +972,10 @@ class StreamToken: def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: ... + @overload + def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken: + ... + @overload def get_field( self, @@ -866,7 +984,6 @@ class StreamToken: StreamKeyType.DEVICE_LIST, StreamKeyType.PRESENCE, StreamKeyType.PUSH_RULES, - StreamKeyType.RECEIPT, StreamKeyType.TO_DEVICE, StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, @@ -875,15 +992,21 @@ class StreamToken: ... @overload - def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: + def get_field( + self, key: StreamKeyType + ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: ... - def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: + def get_field( + self, key: StreamKeyType + ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: """Returns the stream ID for the given key.""" return getattr(self, key.value) -StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0) +StreamToken.START = StreamToken( + RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0 +) @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index c888d1ff01..78646cb5dc 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -31,7 +31,12 @@ from synapse.appservice import ( from synapse.handlers.appservice import ApplicationServicesHandler from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.server import HomeServer -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType +from synapse.types import ( + JsonDict, + MultiWriterStreamToken, + RoomStreamToken, + StreamKeyType, +) from synapse.util import Clock from synapse.util.stringutils import random_string @@ -305,7 +310,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.handler.notify_interested_services_ephemeral( - StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] + StreamKeyType.RECEIPT, + MultiWriterStreamToken(stream=580), + ["@fakerecipient:example.com"], ) self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( interested_service, ephemeral=[event] @@ -333,7 +340,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.handler.notify_interested_services_ephemeral( - StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] + StreamKeyType.RECEIPT, + MultiWriterStreamToken(stream=580), + ["@fakerecipient:example.com"], ) # This method will be called, but with an empty list of events self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( @@ -636,7 +645,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.hs.get_application_service_handler()._notify_interested_services_ephemeral( services=[interested_appservice], stream_key=StreamKeyType.RECEIPT, - new_token=stream_token, + new_token=MultiWriterStreamToken(stream=stream_token), users=[self.exclusive_as_user], ) ) diff --git a/tests/replication/test_sharded_receipts.py b/tests/replication/test_sharded_receipts.py new file mode 100644 index 0000000000..41876b36de --- /dev/null +++ b/tests/replication/test_sharded_receipts.py @@ -0,0 +1,243 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import ReceiptTypes +from synapse.rest import admin +from synapse.rest.client import login, receipts, room, sync +from synapse.server import HomeServer +from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.types import StreamToken +from synapse.util import Clock + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import make_request + +logger = logging.getLogger(__name__) + + +class ReceiptsShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks receipts sharding works""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + self.room_creator = self.hs.get_room_creation_handler() + self.store = hs.get_datastores().main + + def default_config(self) -> dict: + conf = super().default_config() + conf["stream_writers"] = {"receipts": ["worker1", "worker2"]} + conf["instance_map"] = { + "main": {"host": "testserv", "port": 8765}, + "worker1": {"host": "testserv", "port": 1001}, + "worker2": {"host": "testserv", "port": 1002}, + } + return conf + + def test_basic(self) -> None: + """Simple test to ensure that receipts can be sent on multiple + workers. + """ + + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "worker1"}, + ) + worker1_site = self._hs_to_site[worker1] + + worker2 = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "worker2"}, + ) + worker2_site = self._hs_to_site[worker2] + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Create a room + room_id = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room_id, user=self.other_user_id, tok=self.other_access_token + ) + + # First user sends a message, the other users sends a receipt. + response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + channel = make_request( + reactor=self.reactor, + site=worker1_site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", + access_token=access_token, + content={}, + ) + self.assertEqual(200, channel.code) + + # Now we do it again using the second worker + response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + channel = make_request( + reactor=self.reactor, + site=worker2_site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", + access_token=access_token, + content={}, + ) + self.assertEqual(200, channel.code) + + def test_vector_clock_token(self) -> None: + """Tests that using a stream token with a vector clock component works + correctly with basic /sync usage. + """ + + worker_hs1 = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "worker1"}, + ) + worker1_site = self._hs_to_site[worker_hs1] + + worker_hs2 = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "worker2"}, + ) + worker2_site = self._hs_to_site[worker_hs2] + + sync_hs = self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "sync"}, + ) + sync_hs_site = self._hs_to_site[sync_hs] + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + store = self.hs.get_datastores().main + + room_id = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room_id, user=self.other_user_id, tok=self.other_access_token + ) + + response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) + first_event = response["event_id"] + + # Do an initial sync so that we're up to date. + channel = make_request( + self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token + ) + next_batch = channel.json_body["next_batch"] + + # We now gut wrench into the events stream MultiWriterIdGenerator on + # worker2 to mimic it getting stuck persisting a receipt. This ensures + # that when we send an event on worker1 we end up in a state where + # worker2 events stream position lags that on worker1, resulting in a + # receipts token with a non-empty instance map component. + # + # Worker2's receipts stream position will not advance until we call + # __aexit__ again. + worker_store2 = worker_hs2.get_datastores().main + assert isinstance(worker_store2._receipts_id_gen, MultiWriterIdGenerator) + + actx = worker_store2._receipts_id_gen.get_next() + self.get_success(actx.__aenter__()) + + channel = make_request( + reactor=self.reactor, + site=worker1_site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{first_event}", + access_token=access_token, + content={}, + ) + self.assertEqual(200, channel.code) + + # Assert that the current stream token has an instance map component, as + # we are trying to test vector clock tokens. + receipts_token = store.get_max_receipt_stream_id() + self.assertGreater(len(receipts_token.instance_map), 0) + + # Check that syncing still gets the new receipt, despite the gap in the + # stream IDs. + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + f"/sync?since={next_batch}", + access_token=access_token, + ) + + # We should only see the new event and nothing else + self.assertIn(room_id, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"] + self.assertEqual(len(events), 1) + self.assertIn(first_event, events[0]["content"]) + + # Get the next batch and makes sure its a vector clock style token. + vector_clock_token = channel.json_body["next_batch"] + parsed_token = self.get_success( + StreamToken.from_string(store, vector_clock_token) + ) + self.assertGreaterEqual(len(parsed_token.receipt_key.instance_map), 1) + + # Now that we've got a vector clock token we finish the fake persisting + # a receipt we started above. + self.get_success(actx.__aexit__(None, None, None)) + + # Now try and send another receipts to the other worker. + response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) + second_event = response["event_id"] + + channel = make_request( + reactor=self.reactor, + site=worker2_site, + method="POST", + path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{second_event}", + access_token=access_token, + content={}, + ) + + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + f"/sync?since={vector_clock_token}", + access_token=access_token, + ) + + self.assertIn(room_id, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"] + self.assertEqual(len(events), 1) + self.assertIn(second_event, events[0]["content"]) -- cgit 1.5.1 From 9407d5ba78d1e5275b5817ae9e6aedf7d1ca14f7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Oct 2023 13:01:36 -0400 Subject: Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505) This should use fewer allocations and improves type hints. --- changelog.d/16505.misc | 1 + synapse/handlers/deactivate_account.py | 4 +- synapse/handlers/sso.py | 5 +- synapse/storage/database.py | 31 +-- synapse/storage/databases/main/account_data.py | 18 +- synapse/storage/databases/main/appservice.py | 13 +- synapse/storage/databases/main/client_ips.py | 25 ++- synapse/storage/databases/main/devices.py | 70 +++--- synapse/storage/databases/main/e2e_room_keys.py | 49 ++-- synapse/storage/databases/main/event_federation.py | 18 +- .../databases/main/experimental_features.py | 15 +- synapse/storage/databases/main/keys.py | 35 +-- synapse/storage/databases/main/media_repository.py | 58 +++-- synapse/storage/databases/main/push_rule.py | 52 +++-- synapse/storage/databases/main/pusher.py | 20 +- synapse/storage/databases/main/registration.py | 60 +++-- synapse/storage/databases/main/relations.py | 15 +- synapse/storage/databases/main/room.py | 34 +-- synapse/storage/databases/main/roommember.py | 15 +- synapse/storage/databases/main/tags.py | 28 ++- synapse/storage/databases/main/ui_auth.py | 32 +-- synapse/storage/databases/state/store.py | 18 +- tests/handlers/test_stats.py | 14 +- tests/storage/databases/main/test_receipts.py | 20 +- tests/storage/test__base.py | 16 +- tests/storage/test_background_update.py | 35 +-- tests/storage/test_base.py | 4 +- tests/storage/test_client_ips.py | 250 ++++++++++----------- tests/storage/test_roommember.py | 40 ++-- tests/storage/test_state.py | 62 ++--- tests/storage/test_user_directory.py | 61 ++--- 31 files changed, 609 insertions(+), 509 deletions(-) create mode 100644 changelog.d/16505.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/16505.misc b/changelog.d/16505.misc new file mode 100644 index 0000000000..bd7cdd42af --- /dev/null +++ b/changelog.d/16505.misc @@ -0,0 +1 @@ +Reduce memory allocations. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 6a8f8f2fd1..370f4041fb 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -103,10 +103,10 @@ class DeactivateAccountHandler: # Attempt to unbind any known bound threepids to this account from identity # server(s). bound_threepids = await self.store.user_get_bound_threepids(user_id) - for threepid in bound_threepids: + for medium, address in bound_threepids: try: result = await self._identity_handler.try_unbind_threepid( - user_id, threepid["medium"], threepid["address"], id_server + user_id, medium, address, id_server ) except Exception: # Do we want this to be a fatal error or should we carry on? diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e9a544e754..62f2454f5d 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -1206,10 +1206,7 @@ class SsoHandler: # We have no guarantee that all the devices of that session are for the same # `user_id`. Hence, we have to iterate over the list of devices and log them out # one by one. - for device in devices: - user_id = device["user_id"] - device_id = device["device_id"] - + for user_id, device_id in devices: # If the user_id associated with that device/session is not the one we got # out of the `sub` claim, skip that device and show log an error. if expected_user_id is not None and user_id != expected_user_id: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 81f661160c..774d5c12f0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -606,13 +606,16 @@ class DatabasePool: If the background updates have not completed, wait 15 sec and check again. """ - updates = await self.simple_select_list( - "background_updates", - keyvalues=None, - retcols=["update_name"], - desc="check_background_updates", + updates = cast( + List[Tuple[str]], + await self.simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ), ) - background_update_names = [x["update_name"] for x in updates] + background_update_names = [x[0] for x in updates] for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): if update_name not in background_update_names: @@ -1804,9 +1807,9 @@ class DatabasePool: keyvalues: Optional[Dict[str, Any]], retcols: Collection[str], desc: str = "simple_select_list", - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. + more rows, returning the result as a list of tuples. Args: table: the table name @@ -1817,8 +1820,7 @@ class DatabasePool: desc: description of the transaction, for logging and metrics Returns: - A list of dictionaries, one per result row, each a mapping between the - column names from `retcols` and that column's value for the row. + A list of tuples, one per result row, each the retcolumn's value for the row. """ return await self.runInteraction( desc, @@ -1836,9 +1838,9 @@ class DatabasePool: table: str, keyvalues: Optional[Dict[str, Any]], retcols: Iterable[str], - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. + more rows, returning the result as a list of tuples. Args: txn: Transaction object @@ -1849,8 +1851,7 @@ class DatabasePool: retcols: the names of the columns to return Returns: - A list of dictionaries, one per result row, each a mapping between the - column names from `retcols` and that column's value for the row. + A list of tuples, one per result row, each the retcolumn's value for the row. """ if keyvalues: sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1863,7 +1864,7 @@ class DatabasePool: sql = "SELECT %s FROM %s" % (", ".join(retcols), table) txn.execute(sql) - return cls.cursor_to_dict(txn) + return txn.fetchall() async def simple_select_many_batch( self, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 84ef8136c2..d7482a1f4e 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_account_data_for_room_txn( txn: LoggingTransaction, - ) -> Dict[str, JsonDict]: - rows = self.db_pool.simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id, "room_id": room_id}, - ["account_data_type", "content"], + ) -> Dict[str, JsonMapping]: + rows = cast( + List[Tuple[str, str]], + self.db_pool.simple_select_list_txn( + txn, + table="room_account_data", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcols=["account_data_type", "content"], + ), ) return { - row["account_data_type"]: db_to_json(row["content"]) for row in rows + account_data_type: db_to_json(content) + for account_data_type, content in rows } return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 073a99cd84..fa7d1c469a 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore( Returns: A list of ApplicationServices, which may be empty. """ - results = await self.db_pool.simple_select_list( - "application_services_state", {"state": state.value}, ["as_id"] + results = cast( + List[Tuple[str]], + await self.db_pool.simple_select_list( + table="application_services_state", + keyvalues={"state": state.value}, + retcols=("as_id",), + ), ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() services = [] - for res in results: + for (as_id,) in results: for service in as_list: - if service.id == res["as_id"]: + if service.id == as_id: services.append(service) return services diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 8be1511859..c006129625 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke if device_id is not None: keyvalues["device_id"] = device_id - res = await self.db_pool.simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + res = cast( + List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], + await self.db_pool.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), ) return { - (d["user_id"], d["device_id"]): DeviceLastConnectionInfo( - user_id=d["user_id"], - device_id=d["device_id"], - ip=d["ip"], - user_agent=d["user_agent"], - last_seen=d["last_seen"], + (user_id, device_id): DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip=ip, + user_agent=user_agent, + last_seen=last_seen, ) - for d in res + for user_id, ip, user_agent, device_id, last_seen in res } async def _get_user_ip_and_agents_from_database( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fc23d18eba..0b75f6763a 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): allow_none=True, ) - async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: + async def get_devices_by_user( + self, user_id: str + ) -> Dict[str, Dict[str, Optional[str]]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. @@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_id: Returns: A mapping from device_id to a dict containing "device_id", "user_id" - and "display_name" for each device. + and "display_name" for each device. Display name may be null. """ - devices = await self.db_pool.simple_select_list( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name"), - desc="get_devices_by_user", + devices = cast( + List[Tuple[str, str, Optional[str]]], + await self.db_pool.simple_select_list( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user", + ), ) - return {d["device_id"]: d for d in devices} + return { + d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]} + for d in devices + } async def get_devices_by_auth_provider_session_id( self, auth_provider_id: str, auth_provider_session_id: str - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, str]]: """Retrieve the list of devices associated with a SSO IdP session ID. Args: @@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): Returns: A list of dicts containing the device_id and the user_id of each device """ - return await self.db_pool.simple_select_list( - table="device_auth_providers", - keyvalues={ - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - retcols=("user_id", "device_id"), - desc="get_devices_by_auth_provider_session_id", + return cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="device_auth_providers", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + retcols=("user_id", "device_id"), + desc="get_devices_by_auth_provider_session_id", + ), ) @trace @@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): async def get_cached_devices_for_user( self, user_id: str ) -> Mapping[str, JsonMapping]: - devices = await self.db_pool.simple_select_list( - table="device_lists_remote_cache", - keyvalues={"user_id": user_id}, - retcols=("device_id", "content"), - desc="get_cached_devices_for_user", + devices = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="device_lists_remote_cache", + keyvalues={"user_id": user_id}, + retcols=("device_id", "content"), + desc="get_cached_devices_for_user", + ), ) - return { - device["device_id"]: db_to_json(device["content"]) for device in devices - } + return {device[0]: db_to_json(device[1]) for device in devices} def get_cached_device_list_changes( self, @@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): The IDs of users whose device lists need resync. """ if user_ids: - row_tuples = cast( + rows = cast( List[Tuple[str]], await self.db_pool.simple_select_many_batch( table="device_lists_remote_resync", @@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): desc="get_user_ids_requiring_device_list_resync_with_iterable", ), ) - - return {row[0] for row in row_tuples} else: rows = cast( - List[Dict[str, str]], + List[Tuple[str]], await self.db_pool.simple_select_list( table="device_lists_remote_resync", keyvalues=None, @@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ), ) - return {row["user_id"] for row in rows} + return {row[0] for row in rows} async def mark_remote_users_device_caches_as_stale( self, user_ids: StrCollection diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index aac4cfb054..ad904a26a6 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast from typing_extensions import Literal, TypedDict @@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): if session_id: keyvalues["session_id"] = session_id - rows = await self.db_pool.simple_select_list( - table="e2e_room_keys", - keyvalues=keyvalues, - retcols=( - "user_id", - "room_id", - "session_id", - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", + rows = cast( + List[Tuple[str, str, int, int, int, str]], + await self.db_pool.simple_select_list( + table="e2e_room_keys", + keyvalues=keyvalues, + retcols=( + "room_id", + "session_id", + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_keys", ), - desc="get_e2e_room_keys", ) sessions: Dict[ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] ] = {"rooms": {}} - for row in rows: - room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) - room_entry["sessions"][row["session_id"]] = { - "first_message_index": row["first_message_index"], - "forwarded_count": row["forwarded_count"], + for ( + room_id, + session_id, + first_message_index, + forwarded_count, + is_verified, + session_data, + ) in rows: + room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}}) + room_entry["sessions"][session_id] = { + "first_message_index": first_message_index, + "forwarded_count": forwarded_count, # is_verified must be returned to the client as a boolean - "is_verified": bool(row["is_verified"]), - "session_data": db_to_json(row["session_data"]), + "is_verified": bool(is_verified), + "session_data": db_to_json(session_data), } return sessions diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4f80ce75cc..f1b0991503 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # keeping only the forward extremities (i.e. the events not referenced # by other events in the queue). We do this so that we can always # backpaginate in all the events we have dropped. - rows = await self.db_pool.simple_select_list( - table="federation_inbound_events_staging", - keyvalues={"room_id": room_id}, - retcols=("event_id", "event_json"), - desc="prune_staged_events_in_room_fetch", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + retcols=("event_id", "event_json"), + desc="prune_staged_events_in_room_fetch", + ), ) # Find the set of events referenced by those in the queue, as well as # collecting all the event IDs in the queue. referenced_events: Set[str] = set() seen_events: Set[str] = set() - for row in rows: - event_id = row["event_id"] + for event_id, event_json in rows: seen_events.add(event_id) - event_d = db_to_json(row["event_json"]) + event_d = db_to_json(event_json) # We don't bother parsing the dicts into full blown event objects, # as that is needlessly expensive. diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py index 654f924019..60621edeef 100644 --- a/synapse/storage/databases/main/experimental_features.py +++ b/synapse/storage/databases/main/experimental_features.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, FrozenSet +from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main import CacheInvalidationWorkerStore @@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): Returns: the features currently enabled for the user """ - enabled = await self.db_pool.simple_select_list( - "per_user_experimental_features", - {"user_id": user_id, "enabled": True}, - ["feature"], + enabled = cast( + List[Tuple[str]], + await self.db_pool.simple_select_list( + table="per_user_experimental_features", + keyvalues={"user_id": user_id, "enabled": True}, + retcols=("feature",), + ), ) - return frozenset(feature["feature"] for feature in enabled) + return frozenset(feature[0] for feature in enabled) async def set_features_for_user( self, diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index ea797864b9..ce88772f9e 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore): If we have multiple entries for a given key ID, returns the most recent. """ - rows = await self.db_pool.simple_select_list( - table="server_keys_json", - keyvalues={"server_name": server_name}, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", + rows = cast( + List[Tuple[str, str, int, int, Union[bytes, memoryview]]], + await self.db_pool.simple_select_list( + table="server_keys_json", + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", ), - desc="get_server_keys_json_for_remote", ) if not rows: @@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore): # We sort the rows by ts_added_ms so that the most recently added entry # will stomp over older entries in the dictionary. - rows.sort(key=lambda r: r["ts_added_ms"]) + rows.sort(key=lambda r: r[2]) return { - row["key_id"]: FetchKeyResultForRemote( + key_id: FetchKeyResultForRemote( # Cast to bytes since postgresql returns a memoryview. - key_json=bytes(row["key_json"]), - valid_until_ts=row["ts_valid_until_ms"], - added_ts=row["ts_added_ms"], + key_json=bytes(key_json), + valid_until_ts=ts_valid_until_ms, + added_ts=ts_added_ms, ) - for row in rows + for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows } diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 2e6b176bd2..f82140b2e8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]: - rows = await self.db_pool.simple_select_list( - "local_media_repository_thumbnails", - {"media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", + rows = cast( + List[Tuple[int, int, str, str, int]], + await self.db_pool.simple_select_list( + "local_media_repository_thumbnails", + {"media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_local_media_thumbnails", ), - desc="get_local_media_thumbnails", ) return [ ThumbnailInfo( - width=row["thumbnail_width"], - height=row["thumbnail_height"], - method=row["thumbnail_method"], - type=row["thumbnail_type"], - length=row["thumbnail_length"], + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] ) for row in rows ] @@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_remote_media_thumbnails( self, origin: str, media_id: str ) -> List[ThumbnailInfo]: - rows = await self.db_pool.simple_select_list( - "remote_media_cache_thumbnails", - {"media_origin": origin, "media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", + rows = cast( + List[Tuple[int, int, str, str, int]], + await self.db_pool.simple_select_list( + "remote_media_cache_thumbnails", + {"media_origin": origin, "media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_remote_media_thumbnails", ), - desc="get_remote_media_thumbnails", ) return [ ThumbnailInfo( - width=row["thumbnail_width"], - height=row["thumbnail_height"], - method=row["thumbnail_method"], - type=row["thumbnail_type"], - length=row["thumbnail_length"], + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] ) for row in rows ] diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index f5356e7f80..22025eca56 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -179,46 +179,44 @@ class PushRulesWorkerStore( @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: - rows = await self.db_pool.simple_select_list( - table="push_rules", - keyvalues={"user_name": user_id}, - retcols=( - "user_name", - "rule_id", - "priority_class", - "priority", - "conditions", - "actions", + rows = cast( + List[Tuple[str, int, int, str, str]], + await self.db_pool.simple_select_list( + table="push_rules", + keyvalues={"user_name": user_id}, + retcols=( + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="get_push_rules_for_user", ), - desc="get_push_rules_for_user", ) - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + # Sort by highest priority_class, then highest priority. + rows.sort(key=lambda row: (-int(row[1]), -int(row[2]))) enabled_map = await self.get_push_rules_enabled_for_user(user_id) return _load_rules( - [ - ( - row["rule_id"], - row["priority_class"], - row["conditions"], - row["actions"], - ) - for row in rows - ], + [(row[0], row[1], row[3], row[4]) for row in rows], enabled_map, self.hs.config.experimental, ) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: - results = await self.db_pool.simple_select_list( - table="push_rules_enable", - keyvalues={"user_name": user_id}, - retcols=("rule_id", "enabled"), - desc="get_push_rules_enabled_for_user", + results = cast( + List[Tuple[str, Optional[Union[int, bool]]]], + await self.db_pool.simple_select_list( + table="push_rules_enable", + keyvalues={"user_name": user_id}, + retcols=("rule_id", "enabled"), + desc="get_push_rules_enabled_for_user", + ), ) - return {r["rule_id"]: bool(r["enabled"]) for r in results} + return {r[0]: bool(r[1]) for r in results} async def have_push_rules_changed_for_user( self, user_id: str, last_id: int diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c7eb7fc478..a6a1671bd6 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore): async def get_throttle_params_by_room( self, pusher_id: int ) -> Dict[str, ThrottleParams]: - res = await self.db_pool.simple_select_list( - "pusher_throttle", - {"pusher": pusher_id}, - ["room_id", "last_sent_ts", "throttle_ms"], - desc="get_throttle_params_by_room", + res = cast( + List[Tuple[str, Optional[int], Optional[int]]], + await self.db_pool.simple_select_list( + "pusher_throttle", + {"pusher": pusher_id}, + ["room_id", "last_sent_ts", "throttle_ms"], + desc="get_throttle_params_by_room", + ), ) params_by_room = {} - for row in res: - params_by_room[row["room_id"]] = ThrottleParams( - row["last_sent_ts"], - row["throttle_ms"], + for room_id, last_sent_ts, throttle_ms in res: + params_by_room[room_id] = ThrottleParams( + last_sent_ts or 0, throttle_ms or 0 ) return params_by_room diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 9e8643ae4d..b0ef7be155 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: Tuples of (auth_provider, external_id) """ - res = await self.db_pool.simple_select_list( - table="user_external_ids", - keyvalues={"user_id": mxid}, - retcols=("auth_provider", "external_id"), - desc="get_external_ids_by_user", + return cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="user_external_ids", + keyvalues={"user_id": mxid}, + retcols=("auth_provider", "external_id"), + desc="get_external_ids_by_user", + ), ) - return [(r["auth_provider"], r["external_id"]) for r in res] async def count_all_users(self) -> int: """Counts all users registered on the homeserver.""" @@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]: - results = await self.db_pool.simple_select_list( - "user_threepids", - keyvalues={"user_id": user_id}, - retcols=["medium", "address", "validated_at", "added_at"], - desc="user_get_threepids", + results = cast( + List[Tuple[str, str, int, int]], + await self.db_pool.simple_select_list( + "user_threepids", + keyvalues={"user_id": user_id}, + retcols=["medium", "address", "validated_at", "added_at"], + desc="user_get_threepids", + ), ) - return [ThreepidResult(**r) for r in results] + return [ + ThreepidResult( + medium=r[0], + address=r[1], + validated_at=r[2], + added_at=r[3], + ) + for r in results + ] async def user_delete_threepid( self, user_id: str, medium: str, address: str @@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="add_user_bound_threepid", ) - async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: + async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]: """Get the threepids that a user has bound to an identity server through the homeserver The homeserver remembers where binds to an identity server occurred. Using this method can retrieve those threepids. @@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): user_id: The ID of the user to retrieve threepids for Returns: - List of dictionaries containing the following keys: - medium (str): The medium of the threepid (e.g "email") - address (str): The address of the threepid (e.g "bob@example.com") - """ - return await self.db_pool.simple_select_list( - table="user_threepid_id_server", - keyvalues={"user_id": user_id}, - retcols=["medium", "address"], - desc="user_get_bound_threepids", + List of tuples of two strings: + medium: The medium of the threepid (e.g "email") + address: The address of the threepid (e.g "bob@example.com") + """ + return cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="user_threepid_id_server", + keyvalues={"user_id": user_id}, + retcols=["medium", "address"], + desc="user_get_bound_threepids", + ), ) async def remove_user_bound_threepid( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ce7bfd5146..419b2c7a22 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore): def get_all_relation_ids_for_event_txn( txn: LoggingTransaction, ) -> List[str]: - rows = self.db_pool.simple_select_list_txn( - txn=txn, - table="event_relations", - keyvalues={"relates_to_id": event_id}, - retcols=["event_id"], + rows = cast( + List[Tuple[str]], + self.db_pool.simple_select_list_txn( + txn=txn, + table="event_relations", + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ), ) - return [row["event_id"] for row in rows] + return [row[0] for row in rows] return await self.db_pool.runInteraction( desc="get_all_relation_ids_for_event", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 9d24d2c347..3e8fcf1975 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): """ room_servers: Dict[str, PartialStateResyncInfo] = {} - rows = await self.db_pool.simple_select_list( - table="partial_state_rooms", - keyvalues={}, - retcols=("room_id", "joined_via"), - desc="get_server_which_served_partial_join", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="partial_state_rooms", + keyvalues={}, + retcols=("room_id", "joined_via"), + desc="get_server_which_served_partial_join", + ), ) - for row in rows: - room_id = row["room_id"] - joined_via = row["joined_via"] + for room_id, joined_via in rows: room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) - rows = await self.db_pool.simple_select_list( - "partial_state_rooms_servers", - keyvalues=None, - retcols=("room_id", "server_name"), - desc="get_partial_state_rooms", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + "partial_state_rooms_servers", + keyvalues=None, + retcols=("room_id", "server_name"), + desc="get_partial_state_rooms", + ), ) - for row in rows: - room_id = row["room_id"] - server_name = row["server_name"] + for room_id, server_name in rows: entry = room_servers.get(room_id) if entry is None: # There is a foreign key constraint which enforces that every room_id in diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 3a87eba430..a1627dffb7 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): for fully-joined rooms. """ - rows = await self.db_pool.simple_select_list( - "current_state_events", - keyvalues={"room_id": room_id}, - retcols=("event_id", "membership"), - desc="has_completed_background_updates", + rows = cast( + List[Tuple[str, Optional[str]]], + await self.db_pool.simple_select_list( + "current_state_events", + keyvalues={"room_id": room_id}, + retcols=("event_id", "membership"), + desc="has_completed_background_updates", + ), ) - return {row["event_id"]: row["membership"] for row in rows} + return dict(rows) # TODO This returns a mutable object, which is generally confusing when using a cache. @cached(max_entries=10000) # type: ignore[synapse-@cached-mutable] diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 61403a98cf..7deda7790e 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore): tag content. """ - rows = await self.db_pool.simple_select_list( - "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] + rows = cast( + List[Tuple[str, str, str]], + await self.db_pool.simple_select_list( + "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] + ), ) tags_by_room: Dict[str, Dict[str, JsonDict]] = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = db_to_json(row["content"]) + for room_id, tag, content in rows: + room_tags = tags_by_room.setdefault(room_id, {}) + room_tags[tag] = db_to_json(content) return tags_by_room async def get_all_updated_tags( @@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A mapping of tags to tag content. """ - rows = await self.db_pool.simple_select_list( - table="room_tags", - keyvalues={"user_id": user_id, "room_id": room_id}, - retcols=("tag", "content"), - desc="get_tags_for_room", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="room_tags", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcols=("tag", "content"), + desc="get_tags_for_room", + ), ) - return {row["tag"]: db_to_json(row["content"]) for row in rows} + return {tag: db_to_json(content) for tag, content in rows} async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 919c66f553..8ab7c42c4a 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore): that auth-type. """ results = {} - for row in await self.db_pool.simple_select_list( - table="ui_auth_sessions_credentials", - keyvalues={"session_id": session_id}, - retcols=("stage_type", "result"), - desc="get_completed_ui_auth_stages", - ): - results[row["stage_type"]] = db_to_json(row["result"]) + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id}, + retcols=("stage_type", "result"), + desc="get_completed_ui_auth_stages", + ), + ) + for stage_type, result in rows: + results[stage_type] = db_to_json(result) return results @@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore): Returns: List of user_agent/ip pairs """ - rows = await self.db_pool.simple_select_list( - table="ui_auth_sessions_ips", - keyvalues={"session_id": session_id}, - retcols=("user_agent", "ip"), - desc="get_user_agents_ips_to_ui_auth_session", + return cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ), ) - return [(row["user_agent"], row["ip"]) for row in rows] async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 09d2a8c5b3..182e429174 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self.db_pool.simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), + delta_ids = cast( + List[Tuple[str, str, str]], + self.db_pool.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ), ) return _GetStateGroupDelta( prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + { + (event_type, state_key): event_id + for event_type, state_key, event_id in delta_ids + }, ) return await self.db_pool.runInteraction( diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index d11ded6c5b..76c56d5434 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, cast from twisted.test.proto_helpers import MemoryReactor @@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - async def get_all_room_state(self) -> List[Dict[str, Any]]: - return await self.store.db_pool.simple_select_list( - "room_stats_state", None, retcols=("name", "topic", "canonical_alias") + async def get_all_room_state(self) -> List[Optional[str]]: + rows = cast( + List[Tuple[Optional[str]]], + await self.store.db_pool.simple_select_list( + "room_stats_state", None, retcols=("topic",) + ), ) + return [r[0] for r in rows] def _get_current_stats( self, stats_type: str, stat_id: str @@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 1) - self.assertEqual(r[0]["topic"], "foo") + self.assertEqual(r[0], "foo") def test_create_user(self) -> None: """ diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index 71db47405e..98b01086bc 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): if expected_row is not None: columns += expected_row.keys() - rows = self.get_success( + row_tuples = self.get_success( self.store.db_pool.simple_select_list( table=table, keyvalues={ @@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): if expected_row is not None: self.assertEqual( - len(rows), + len(row_tuples), 1, f"Background update did not leave behind latest receipt in {table}", ) self.assertEqual( - rows[0], - { - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - **expected_row, - }, + row_tuples[0], + ( + room_id, + receipt_type, + user_id, + *expected_row.values(), + ), ) else: self.assertEqual( - len(rows), + len(row_tuples), 0, f"Background update did not remove all duplicate receipts from {table}", ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8bbf936ae9..8cbc974ac4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,7 +14,7 @@ # limitations under the License. import secrets -from typing import Generator, Tuple +from typing import Generator, List, Tuple, cast from twisted.test.proto_helpers import MemoryReactor @@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase): ) def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]: - res = self.get_success( - self.storage.db_pool.simple_select_list( - self.table_name, None, ["id, username, value"] - ) + yield from cast( + List[Tuple[int, str, str]], + self.get_success( + self.storage.db_pool.simple_select_list( + self.table_name, None, ["id, username, value"] + ) + ), ) - for i in res: - yield (i["id"], i["username"], i["value"]) - def test_upsert_many(self) -> None: """ Upsert_many will perform the upsert operation across a batch of data. diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index abf7d0564d..3f5bfa09d4 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Tuple, cast from unittest.mock import AsyncMock, Mock import yaml @@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): self.wait_for_background_updates() # Check the correct values are in the new table. - rows = self.get_success( - self.store.db_pool.simple_select_list( - table="test_constraint", - keyvalues={}, - retcols=("a", "b"), - ) + rows = cast( + List[Tuple[int, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ), ) - self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + self.assertCountEqual(rows, [(1, 1), (3, 3)]) # And check that invalid rows get correctly rejected. self.get_failure( @@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): self.wait_for_background_updates() # Check the correct values are in the new table. - rows = self.get_success( - self.store.db_pool.simple_select_list( - table="test_constraint", - keyvalues={}, - retcols=("a", "b"), - ) + rows = cast( + List[Tuple[int, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ), ) - self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + self.assertCountEqual(rows, [(1, 1), (3, 3)]) # And check that invalid rows get correctly rejected. self.get_failure( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 256d28e4c9..e4a52c301e 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 3 - self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) + self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)] self.mock_txn.description = (("colA", None, None, None, None, None, None),) ret = yield defer.ensureDeferred( @@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) + self.assertEqual([(1,), (2,), (3,)], ret) self.mock_txn.execute.assert_called_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 0c054a598f..8e4393d843 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple, cast from unittest.mock import AsyncMock from parameterized import parameterized @@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) self.pump(0) - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual( - result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": None, - "last_seen": 12345678000, - } - ], + result, [("access_token", "ip", "user_agent", None, 12345678000)] ) # Add another & trigger the storage loop @@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) self.pump(0) - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) # Only one result, has been upserted. self.assertEqual( - result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": None, - "last_seen": 12345878000, - } - ], + result, [("access_token", "ip", "user_agent", None, 12345878000)] ) @parameterized.expand([(False,), (True,)]) @@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) else: # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="devices", - keyvalues={}, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + db_result = cast( + List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], + self.get_success( + self.store.db_pool.simple_select_list( + table="devices", + keyvalues={}, + retcols=( + "user_id", + "ip", + "user_agent", + "device_id", + "last_seen", + ), + ), ), ) - self.assertEqual( - db_result, - [ - { - "user_id": user_id, - "device_id": device_id, - "ip": None, - "user_agent": None, - "last_seen": None, - }, - ], - ) + self.assertEqual(db_result, [(user_id, None, None, device_id, None)]) result = self.get_success( self.store.get_last_client_ip_by_device(user_id, device_id) @@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="devices", - keyvalues={}, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + db_result = cast( + List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], + self.get_success( + self.store.db_pool.simple_select_list( + table="devices", + keyvalues={}, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), ), ) self.assertCountEqual( db_result, [ - { - "user_id": user_id, - "device_id": device_id_1, - "ip": "ip_1", - "user_agent": "user_agent_1", - "last_seen": 12345678000, - }, - { - "user_id": user_id, - "device_id": device_id_2, - "ip": "ip_2", - "user_agent": "user_agent_2", - "last_seen": 12345678000, - }, + (user_id, "ip_1", "user_agent_1", device_id_1, 12345678000), + (user_id, "ip_2", "user_agent_2", device_id_2, 12345678000), ], ) @@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={}, - retcols=("access_token", "ip", "user_agent", "last_seen"), + db_result = cast( + List[Tuple[str, str, str, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=("access_token", "ip", "user_agent", "last_seen"), + ), ), ) self.assertEqual( db_result, [ - { - "access_token": "access_token", - "ip": "ip_1", - "user_agent": "user_agent_1", - "last_seen": 12345678000, - }, - { - "access_token": "access_token", - "ip": "ip_2", - "user_agent": "user_agent_2", - "last_seen": 12345678000, - }, + ("access_token", "ip_1", "user_agent_1", 12345678000), + ("access_token", "ip_2", "user_agent_2", 12345678000), ], ) @@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) # We should see that in the DB - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual( result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": device_id, - "last_seen": 0, - } - ], + [("access_token", "ip", "user_agent", device_id, 0)], ) # Now advance by a couple of months self.reactor.advance(60 * 24 * 60 * 60) # We should get no results. - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual(result, []) @@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) # We should see that in the DB - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) # ensure user1 is filtered out - self.assertEqual( - result, - [ - { - "access_token": access_token2, - "ip": "ip", - "user_agent": "user_agent", - "device_id": device_id2, - "last_seen": 0, - } - ], - ) + self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)]) class ClientIpAuthTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index f4c4661aaf..36fcab06b5 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional, Tuple, cast + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import Membership @@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): def test__null_byte_in_display_name_properly_handled(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - res = self.get_success( - self.store.db_pool.simple_select_list( - "room_memberships", - {"user_id": "@alice:test"}, - ["display_name", "event_id"], - ) + res = cast( + List[Tuple[Optional[str], str]], + self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ), ) # Check that we only got one result back self.assertEqual(len(res), 1) # Check that alice's display name is "alice" - self.assertEqual(res[0]["display_name"], "alice") + self.assertEqual(res[0][0], "alice") # Grab the event_id to use later - event_id = res[0]["event_id"] + event_id = res[0][1] # Create a profile with the offending null byte in the display name new_profile = {"displayname": "ali\u0000ce"} @@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): tok=self.t_alice, ) - res2 = self.get_success( - self.store.db_pool.simple_select_list( - "room_memberships", - {"user_id": "@alice:test"}, - ["display_name", "event_id"], - ) + res2 = cast( + List[Tuple[Optional[str], str]], + self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ), ) # Check that we only have two results self.assertEqual(len(res2), 2) # Filter out the previous event using the event_id we grabbed above - row = [row for row in res2 if row["event_id"] != event_id] + row = [row for row in res2 if row[1] != event_id] # Check that alice's display name is now None - self.assertEqual(row[0]["display_name"], None) + self.assertIsNone(row[0][0]) def test_room_is_locally_forgotten(self) -> None: """Test that when the last local user has forgotten a room it is known as forgotten.""" diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 0b9446c36c..2715c73f16 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import List, Tuple, cast from immutabledict import immutabledict @@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase): ) # check that only state events are in state_groups, and all state events are in state_groups - res = self.get_success( - self.store.db_pool.simple_select_list( - table="state_groups", - keyvalues=None, - retcols=("event_id",), - ) + res = cast( + List[Tuple[str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ), ) events = [] for result in res: - self.assertNotIn(event3.event_id, result) - events.append(result.get("event_id")) + self.assertNotIn(event3.event_id, result) # XXX + events.append(result[0]) for event, _ in processed_events_and_context: if event.is_state(): @@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase): # has an entry and prev event in state_group_edges for event, context in processed_events_and_context: if event.is_state(): - state = self.get_success( - self.store.db_pool.simple_select_list( - table="state_groups_state", - keyvalues={"state_group": context.state_group_after_event}, - retcols=("type", "state_key"), - ) - ) - self.assertEqual(event.type, state[0].get("type")) - self.assertEqual(event.state_key, state[0].get("state_key")) - - groups = self.get_success( - self.store.db_pool.simple_select_list( - table="state_group_edges", - keyvalues={"state_group": str(context.state_group_after_event)}, - retcols=("*",), - ) + state = cast( + List[Tuple[str, str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ), ) - self.assertEqual( - context.state_group_before_event, groups[0].get("prev_state_group") + self.assertEqual(event.type, state[0][0]) + self.assertEqual(event.state_key, state[0][1]) + + groups = cast( + List[Tuple[str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={ + "state_group": str(context.state_group_after_event) + }, + retcols=("prev_state_group",), + ) + ), ) + self.assertEqual(context.state_group_before_event, groups[0][0]) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 8c72aa1722..822c41dd9f 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Any, Dict, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, cast from unittest import mock from unittest.mock import Mock, patch @@ -62,14 +62,13 @@ class GetUserDirectoryTables: Returns a list of tuples (user_id, room_id) where room_id is public and contains the user with the given id. """ - r = await self.store.db_pool.simple_select_list( - "users_in_public_rooms", None, ("user_id", "room_id") + r = cast( + List[Tuple[str, str]], + await self.store.db_pool.simple_select_list( + "users_in_public_rooms", None, ("user_id", "room_id") + ), ) - - retval = set() - for i in r: - retval.add((i["user_id"], i["room_id"])) - return retval + return set(r) async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]: """Fetch the entire `users_who_share_private_rooms` table. @@ -78,27 +77,30 @@ class GetUserDirectoryTables: to the rows of `users_who_share_private_rooms`. """ - rows = await self.store.db_pool.simple_select_list( - "users_who_share_private_rooms", - None, - ["user_id", "other_user_id", "room_id"], + rows = cast( + List[Tuple[str, str, str]], + await self.store.db_pool.simple_select_list( + "users_who_share_private_rooms", + None, + ["user_id", "other_user_id", "room_id"], + ), ) - rv = set() - for row in rows: - rv.add((row["user_id"], row["other_user_id"], row["room_id"])) - return rv + return set(rows) async def get_users_in_user_directory(self) -> Set[str]: """Fetch the set of users in the `user_directory` table. This is useful when checking we've correctly excluded users from the directory. """ - result = await self.store.db_pool.simple_select_list( - "user_directory", - None, - ["user_id"], + result = cast( + List[Tuple[str]], + await self.store.db_pool.simple_select_list( + "user_directory", + None, + ["user_id"], + ), ) - return {row["user_id"] for row in result} + return {row[0] for row in result} async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]: """Fetch users and their profiles from the `user_directory` table. @@ -107,16 +109,17 @@ class GetUserDirectoryTables: It's almost the entire contents of the `user_directory` table: the only thing missing is an unused room_id column. """ - rows = await self.store.db_pool.simple_select_list( - "user_directory", - None, - ("user_id", "display_name", "avatar_url"), + rows = cast( + List[Tuple[str, Optional[str], Optional[str]]], + await self.store.db_pool.simple_select_list( + "user_directory", + None, + ("user_id", "display_name", "avatar_url"), + ), ) return { - row["user_id"]: ProfileInfo( - display_name=row["display_name"], avatar_url=row["avatar_url"] - ) - for row in rows + user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url) + for user_id, display_name, avatar_url in rows } async def get_tables( -- cgit 1.5.1 From 679c691f6f7c4f7901e6d075a645a8ade20f44d5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Oct 2023 15:12:28 -0400 Subject: Remove more usages of cursor_to_dict. (#16551) Mostly to improve type safety. --- changelog.d/16551.misc | 1 + synapse/handlers/identity.py | 18 ++++---- synapse/handlers/ui_auth/checkers.py | 6 +-- synapse/media/media_repository.py | 5 +-- synapse/rest/admin/federation.py | 14 +++++- synapse/rest/admin/rooms.py | 12 ++++- synapse/rest/admin/statistics.py | 13 +++++- synapse/storage/database.py | 30 ++----------- synapse/storage/databases/main/censor_events.py | 2 +- synapse/storage/databases/main/devices.py | 3 +- synapse/storage/databases/main/end_to_end_keys.py | 1 - .../storage/databases/main/events_bg_updates.py | 7 +-- .../databases/main/events_forward_extremities.py | 15 ++++--- synapse/storage/databases/main/media_repository.py | 19 ++++---- synapse/storage/databases/main/registration.py | 43 ++++++++++++------ synapse/storage/databases/main/roommember.py | 4 +- synapse/storage/databases/main/search.py | 52 +++++++++++++--------- synapse/storage/databases/main/stats.py | 15 ++++--- synapse/storage/databases/main/stream.py | 3 +- synapse/storage/databases/main/transactions.py | 28 ++++++++++-- synapse/storage/databases/main/user_directory.py | 14 +++--- synapse/storage/databases/state/bg_updates.py | 1 - tests/federation/test_federation_catch_up.py | 1 - tests/storage/test_background_update.py | 16 +++---- tests/storage/test_profile.py | 2 +- tests/storage/test_user_filters.py | 2 +- 26 files changed, 193 insertions(+), 134 deletions(-) create mode 100644 changelog.d/16551.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/16551.misc b/changelog.d/16551.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16551.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 472879c964..c041b67993 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -19,6 +19,8 @@ import logging import urllib.parse from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple +import attr + from synapse.api.errors import ( CodeMessageException, Codes, @@ -357,9 +359,9 @@ class IdentityHandler: # Check to see if a session already exists and that it is not yet # marked as validated - if session and session.get("validated_at") is None: - session_id = session["session_id"] - last_send_attempt = session["last_send_attempt"] + if session and session.validated_at is None: + session_id = session.session_id + last_send_attempt = session.last_send_attempt # Check that the send_attempt is higher than previous attempts if send_attempt <= last_send_attempt: @@ -480,7 +482,6 @@ class IdentityHandler: # We don't actually know which medium this 3PID is. Thus we first assume it's email, # and if validation fails we try msisdn - validation_session = None # Try to validate as email if self.hs.config.email.can_verify_email: @@ -488,19 +489,18 @@ class IdentityHandler: validation_session = await self.store.get_threepid_validation_session( "email", client_secret, sid=sid, validated=True ) - - if validation_session: - return validation_session + if validation_session: + return attr.asdict(validation_session) # Try to validate as msisdn if self.hs.config.registration.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server - validation_session = await self.threepid_from_creds( + return await self.threepid_from_creds( self.hs.config.registration.account_threepid_delegate_msisdn, threepid_creds, ) - return validation_session + return None async def proxy_msisdn_submit_token( self, id_server: str, client_secret: str, sid: str, token: str diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 78a75bfed6..ab8f7610e9 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker: if row: threepid = { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], + "medium": row.medium, + "address": row.address, + "validated_at": row.validated_at, } # Valid threepid returned, delete from the db diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 7fd46901f7..72b0f1c5de 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -949,10 +949,7 @@ class MediaRepository: deleted = 0 - for media in old_media: - origin = media["media_origin"] - media_id = media["media_id"] - file_id = media["filesystem_id"] + for origin, media_id, file_id in old_media: key = (origin, media_id) logger.info("Deleting: %r", key) diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 8a617af599..a6ce787da1 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet): destinations, total = await self._store.get_destinations_paginate( start, limit, destination, order_by, direction ) - response = {"destinations": destinations, "total": total} + response = { + "destinations": [ + { + "destination": r[0], + "retry_last_ts": r[1], + "retry_interval": r[2], + "failure_ts": r[3], + "last_successful_stream_ordering": r[4], + } + for r in destinations + ], + "total": total, + } if (start + limit) < total: response["next_token"] = str(start + len(destinations)) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 436718c8b2..2d4da38db9 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): room_id, _ = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return HTTPStatus.OK, {"count": len(extremities), "results": extremities} + result = [ + { + "event_id": ex[0], + "state_group": ex[1], + "depth": ex[2], + "received_ts": ex[3], + } + for ex in extremities + ] + + return HTTPStatus.OK, {"count": len(extremities), "results": result} class RoomEventContextServlet(RestServlet): diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 19780e4b4c..75d8a37ccf 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet): users_media, total = await self.store.get_users_media_usage_paginate( start, limit, from_ts, until_ts, order_by, direction, search_term ) - ret = {"users": users_media, "total": total} + ret = { + "users": [ + { + "user_id": r[0], + "displayname": r[1], + "media_count": r[2], + "media_length": r[3], + } + for r in users_media + ], + "total": total, + } if (start + limit) < total: ret["next_token"] = start + len(users_media) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 774d5c12f0..b1ece63845 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -35,7 +35,6 @@ from typing import ( Tuple, Type, TypeVar, - Union, cast, overload, ) @@ -1047,43 +1046,20 @@ class DatabasePool: results = [dict(zip(col_headers, row)) for row in cursor] return results - @overload - async def execute( - self, desc: str, decoder: Literal[None], query: str, *args: Any - ) -> List[Tuple[Any, ...]]: - ... - - @overload - async def execute( - self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any - ) -> R: - ... - - async def execute( - self, - desc: str, - decoder: Optional[Callable[[Cursor], R]], - query: str, - *args: Any, - ) -> Union[List[Tuple[Any, ...]], R]: + async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]: """Runs a single query for a result set. Args: desc: description of the transaction, for logging and metrics - decoder - The function which can resolve the cursor results to - something meaningful. query - The query string to execute *args - Query args. Returns: The result of decoder(results) """ - def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]: + def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]: txn.execute(query, args) - if decoder: - return decoder(txn) - else: - return txn.fetchall() + return txn.fetchall() return await self.runInteraction(desc, interaction) diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 58177ecec1..711fdddd4e 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase """ rows = await self.db_pool.execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 + "_censor_redactions_fetch", sql, before_ts, 100 ) updates = [] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 0b75f6763a..49edbb9e06 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): rows = await self.db_pool.execute( "get_all_devices_changed", - None, sql, from_key, to_key, @@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): WHERE from_user_id = ? AND stream_id > ? """ rows = await self.db_pool.execute( - "get_users_whose_signatures_changed", None, sql, user_id, from_key + "get_users_whose_signatures_changed", sql, user_id, from_key ) return {user for row in rows for user in db_to_json(row[0])} else: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f13d776b0d..f70f95eeba 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ rows = await self.db_pool.execute( "get_e2e_device_keys_for_federation_query_check", - None, sql, now_stream_id, user_id, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index c5fce1c82b..0061805150 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the # indexes on it. - # We need to pass execute a dummy function to handle the txn's result otherwise - # it tries to call fetchall() on it and fails because there's no result to fetch. - await self.db_pool.execute( + await self.db_pool.runInteraction( "background_analyze_new_stream_ordering_column", - lambda txn: None, - "ANALYZE events(stream_ordering2)", + lambda txn: txn.execute("ANALYZE events(stream_ordering2)"), ) await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index f851bff604..0ba84b1469 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Any, Dict, List +from typing import List, Optional, Tuple, cast from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction @@ -91,12 +91,17 @@ class EventForwardExtremitiesStore( async def get_forward_extremities_for_room( self, room_id: str - ) -> List[Dict[str, Any]]: - """Get list of forward extremities for a room.""" + ) -> List[Tuple[str, int, int, Optional[int]]]: + """ + Get list of forward extremities for a room. + + Returns: + A list of tuples of event_id, state_group, depth, and received_ts. + """ def get_forward_extremities_for_room_txn( txn: LoggingTransaction, - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, int, int, Optional[int]]]: sql = """ SELECT event_id, state_group, depth, received_ts FROM event_forward_extremities @@ -106,7 +111,7 @@ class EventForwardExtremitiesStore( """ txn.execute(sql, (room_id,)) - return self.db_pool.cursor_to_dict(txn) + return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall()) return await self.db_pool.runInteraction( "get_forward_extremities_for_room", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index f82140b2e8..aeb3db596c 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_remote_media_ids( self, before_ts: int, include_quarantined_media: bool - ) -> List[Dict[str, str]]: + ) -> List[Tuple[str, str, str]]: """ Retrieve a list of server name, media ID tuples from the remote media cache. @@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): A list of tuples containing: * The server name of homeserver where the media originates from, * The ID of the media. + * The filesystem ID. + """ + + sql = """ + SELECT media_origin, media_id, filesystem_id + FROM remote_media_cache + WHERE last_access_ts < ? """ - sql = ( - "SELECT media_origin, media_id, filesystem_id" - " FROM remote_media_cache" - " WHERE last_access_ts < ?" - ) if include_quarantined_media is False: # Only include media that has not been quarantined @@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): AND quarantined_by IS NULL """ - return await self.db_pool.execute( - "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts + return cast( + List[Tuple[str, str, str]], + await self.db_pool.execute("get_remote_media_ids", sql, before_ts), ) async def delete_remote_media(self, media_origin: str, media_id: str) -> None: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index b0ef7be155..e09ab21593 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -151,6 +151,22 @@ class ThreepidResult: added_at: int +@attr.s(frozen=True, slots=True, auto_attribs=True) +class ThreepidValidationSession: + address: str + """address of the 3pid""" + medium: str + """medium of the 3pid""" + client_secret: str + """a secret provided by the client for this validation session""" + session_id: str + """ID of the validation session""" + last_send_attempt: int + """a number serving to dedupe send attempts for this session""" + validated_at: Optional[int] + """timestamp of when this session was validated if so""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): address: Optional[str] = None, sid: Optional[str] = None, validated: Optional[bool] = True, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[ThreepidValidationSession]: """Gets a session_id and last_send_attempt (if available) for a combination of validation metadata @@ -1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): perform no filtering Returns: - A dict containing the following: - * address - address of the 3pid - * medium - medium of the 3pid - * client_secret - a secret provided by the client for this validation session - * session_id - ID of the validation session - * send_attempt - a number serving to dedupe send attempts for this session - * validated_at - timestamp of when this session was validated if so - - Otherwise None if a validation session is not found + A ThreepidValidationSession or None if a validation session is not found """ if not client_secret: raise SynapseError( @@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): def get_threepid_validation_session_txn( txn: LoggingTransaction, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[ThreepidValidationSession]: sql = """ SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at @@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.db_pool.cursor_to_dict(txn) - if not rows: + row = txn.fetchone() + if not row: return None - return rows[0] + return ThreepidValidationSession( + address=row[0], + session_id=row[1], + medium=row[2], + client_secret=row[3], + last_send_attempt=row[4], + validated_at=row[5], + ) return await self.db_pool.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index a1627dffb7..67e149b586 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): like_clause = "%:" + host rows = await self.db_pool.execute( - "is_host_joined", None, sql, membership, room_id, like_clause + "is_host_joined", sql, membership, room_id, like_clause ) if not rows: @@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND forgotten = 0; """ - rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id) + rows = await self.db_pool.execute("is_forgotten_room", sql, room_id) # `count(*)` returns always an integer # If any rows still exist it means someone has not forgotten this room yet diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 1d69c4a5f0..dbde9130c6 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -26,6 +26,7 @@ from typing import ( Set, Tuple, Union, + cast, ) import attr @@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = await self.db_pool.execute( - "search_msgs", self.db_pool.cursor_to_dict, sql, *args + # List of tuples of (rank, room_id, event_id). + results = cast( + List[Tuple[Union[int, float], str, str]], + await self.db_pool.execute("search_msgs", sql, *args), ) - results = list(filter(lambda row: row["room_id"] in room_ids, results)) + results = list(filter(lambda row: row[1] in room_ids, results)) # We set redact_behaviour to block here to prevent redacted events being returned in # search results (which is a data leak) events = await self.get_events_as_list( # type: ignore[attr-defined] - [r["event_id"] for r in results], + [r[2] for r in results], redact_behaviour=EventRedactBehaviour.block, ) @@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = await self.db_pool.execute( - "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args + # List of tuples of (room_id, count). + count_results = cast( + List[Tuple[str, int]], + await self.db_pool.execute("search_rooms_count", count_sql, *count_args), ) - count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + count = sum(row[1] for row in count_results if row[0] in room_ids) return { "results": [ - {"event": event_map[r["event_id"]], "rank": r["rank"]} + {"event": event_map[r[2]], "rank": r[0]} for r in results - if r["event_id"] in event_map + if r[2] in event_map ], "highlights": highlights, "count": count, @@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore): search_query = search_term sql = """ SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank, - origin_server_ts, stream_ordering, room_id, event_id + room_id, event_id, origin_server_ts, stream_ordering FROM event_search WHERE vector @@ websearch_to_tsquery('english', ?) AND """ @@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore): # mypy expects to append only a `str`, not an `int` args.append(limit) - results = await self.db_pool.execute( - "search_rooms", self.db_pool.cursor_to_dict, sql, *args + # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering). + results = cast( + List[Tuple[Union[int, float], str, str, int, int]], + await self.db_pool.execute("search_rooms", sql, *args), ) - results = list(filter(lambda row: row["room_id"] in room_ids, results)) + results = list(filter(lambda row: row[1] in room_ids, results)) # We set redact_behaviour to block here to prevent redacted events being returned in # search results (which is a data leak) events = await self.get_events_as_list( # type: ignore[attr-defined] - [r["event_id"] for r in results], + [r[2] for r in results], redact_behaviour=EventRedactBehaviour.block, ) @@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = await self.db_pool.execute( - "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args + # List of tuples of (room_id, count). + count_results = cast( + List[Tuple[str, int]], + await self.db_pool.execute("search_rooms_count", count_sql, *count_args), ) - count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + count = sum(row[1] for row in count_results if row[0] in room_ids) return { "results": [ { - "event": event_map[r["event_id"]], - "rank": r["rank"], - "pagination_token": "%s,%s" - % (r["origin_server_ts"], r["stream_ordering"]), + "event": event_map[r[2]], + "rank": r[0], + "pagination_token": "%s,%s" % (r[3], r[4]), } for r in results - if r["event_id"] in event_map + if r[2] in event_map ], "highlights": highlights, "count": count, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 5b2d0ba870..e96c9b0486 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore): order_by: Optional[str] = UserSortOrder.USER_ID.value, direction: Direction = Direction.FORWARDS, search_term: Optional[str] = None, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]: """Function to retrieve a paginated list of users and their uploaded local media (size and number). This will return a json list of users and the total number of users matching the filter criteria. @@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore): order_by: the sort order of the returned list direction: sort ascending or descending search_term: a string to filter user names by + Returns: - A list of user dicts and an integer representing the total number of - users that exist given this query + A tuple of: + A list of tuples of user information (the user ID, displayname, + total number of media, total length of media) and + + An integer representing the total number of users that exist + given this query """ def get_users_media_usage_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]: filters = [] args: list = [] @@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore): args += [limit, start] txn.execute(sql, args) - users = self.db_pool.cursor_to_dict(txn) + users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall()) return users, count diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 872df6bda1..2225f8272d 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ row = await self.db_pool.execute( - "get_current_topological_token", None, sql, room_id, room_id, stream_key + "get_current_topological_token", sql, room_id, room_id, stream_key ) return row[0][0] if row else 0 @@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = await self.db_pool.execute( "get_timeline_gaps", - None, sql, room_id, from_token.stream if from_token else 0, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index c4a6475060..fecddb4144 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): destination: Optional[str] = None, order_by: str = DestinationSortOrder.DESTINATION.value, direction: Direction = Direction.FORWARDS, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[ + List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]], + int, + ]: """Function to retrieve a paginated list of destinations. This will return a json list of destinations and the total number of destinations matching the filter criteria. @@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): order_by: the sort order of the returned list direction: sort ascending or descending Returns: - A tuple of a list of mappings from destination to information + A tuple of a list of tuples of destination information: + * destination + * retry_last_ts + * retry_interval + * failure_ts + * last_successful_stream_ordering and a count of total destinations. """ def get_destinations_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[ + List[ + Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]] + ], + int, + ]: order_by_column = DestinationSortOrder(order_by).value if direction == Direction.BACKWARDS: @@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): LIMIT ? OFFSET ? """ txn.execute(sql, args + [limit, start]) - destinations = self.db_pool.cursor_to_dict(txn) + destinations = cast( + List[ + Tuple[ + str, Optional[int], Optional[int], Optional[int], Optional[int] + ] + ], + txn.fetchall(), + ) return destinations, count return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 23eb92c514..a9f5d68b63 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): raise Exception("Unrecognized database engine") results = cast( - List[UserProfile], - await self.db_pool.execute( - "search_user_dir", self.db_pool.cursor_to_dict, sql, *args - ), + List[Tuple[str, Optional[str], Optional[str]]], + await self.db_pool.execute("search_user_dir", sql, *args), ) limited = len(results) > limit - return {"limited": limited, "results": results[0:limit]} + return { + "limited": limited, + "results": [ + {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]} + for r in results[0:limit] + ], + } def _filter_text_for_index(text: str) -> str: diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 6ff533a129..0f9c550b27 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): if max_group is None: rows = await self.db_pool.execute( "_background_deduplicate_state", - None, "SELECT coalesce(max(id), 0) FROM state_groups", ) max_group = rows[0][0] diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 75ae740b43..08214b0013 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -100,7 +100,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): event_id, stream_ordering = self.get_success( self.hs.get_datastores().main.db_pool.execute( "test:get_destination_rooms", - None, """ SELECT event_id, stream_ordering FROM destination_rooms dr diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 3f5bfa09d4..67ea640902 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -457,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): ); """ self.get_success( - self.store.db_pool.execute( - "test_not_null_constraint", lambda _: None, table_sql + self.store.db_pool.runInteraction( + "test_not_null_constraint", lambda txn: txn.execute(table_sql) ) ) @@ -466,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): # using SQLite. index_sql = "CREATE INDEX test_index ON test_constraint(a)" self.get_success( - self.store.db_pool.execute( - "test_not_null_constraint", lambda _: None, index_sql + self.store.db_pool.runInteraction( + "test_not_null_constraint", lambda txn: txn.execute(index_sql) ) ) @@ -574,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): ); """ self.get_success( - self.store.db_pool.execute( - "test_foreign_key_constraint", lambda _: None, base_sql + self.store.db_pool.runInteraction( + "test_foreign_key_constraint", lambda txn: txn.execute(base_sql) ) ) self.get_success( - self.store.db_pool.execute( - "test_foreign_key_constraint", lambda _: None, table_sql + self.store.db_pool.runInteraction( + "test_foreign_key_constraint", lambda txn: txn.execute(table_sql) ) ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 95f99f4130..6afb5403bd 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): res = self.get_success( self.store.db_pool.execute( - "", None, "SELECT full_user_id from profiles ORDER BY full_user_id" + "", "SELECT full_user_id from profiles ORDER BY full_user_id" ) ) self.assertEqual(len(res), len(expected_values)) diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py index d4637d9d1e..2da6a018e8 100644 --- a/tests/storage/test_user_filters.py +++ b/tests/storage/test_user_filters.py @@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase): res = self.get_success( self.store.db_pool.execute( - "", None, "SELECT full_user_id from user_filters ORDER BY full_user_id" + "", "SELECT full_user_id from user_filters ORDER BY full_user_id" ) ) self.assertEqual(len(res), len(expected_values)) -- cgit 1.5.1