From dc51d8ffaf4d392be2f36c4d36625352b09c55c9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 11:22:25 -0500 Subject: Add a background task to purge unused chain IDs. (#9542) This is a companion change to apply the fix in #9498 / 922788c6043138165c025c78effeda87de842bab to previously purged rooms. --- .../storage/databases/main/events_bg_updates.py | 79 ++++++++++++++++++++++ synapse/storage/databases/main/purge_events.py | 8 +-- .../delta/59/10delete_purged_chain_cover.sql | 17 +++++ 3 files changed, 98 insertions(+), 6 deletions(-) create mode 100644 synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index cb6b1f8a0c..73e69d4cb1 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -135,6 +135,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._chain_cover_index, ) + self.db_pool.updates.register_background_update_handler( + "purged_chain_cover", + self._purged_chain_cover_index, + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -932,3 +937,77 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): processed_count=count, finished_room_map=finished_rooms, ) + + async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int: + """ + A background updates that iterates over the chain cover and deletes the + chain cover for events that have been purged. + + This may be due to fully purging a room or via setting a retention policy. + """ + current_event_id = progress.get("current_event_id", "") + + def purged_chain_cover_txn(txn) -> int: + # The event ID from events will be null if the chain ID / sequence + # number points to a purged event. + sql = """ + SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL + FROM event_auth_chains + LEFT JOIN events AS e USING (event_id) + WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ? + """ + txn.execute(sql, (current_event_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + # The event IDs and chain IDs / sequence numbers where the event has + # been purged. + unreferenced_event_ids = [] + unreferenced_chain_id_tuples = [] + event_id = "" + for event_id, chain_id, sequence_number, has_event in rows: + if not has_event: + unreferenced_event_ids.append(event_id) + unreferenced_chain_id_tuples.append((chain_id, sequence_number)) + + # Delete the unreferenced auth chains from event_auth_chain_links and + # event_auth_chains. + txn.executemany( + """ + DELETE FROM event_auth_chains WHERE event_id = ? + """, + unreferenced_event_ids, + ) + # We should also delete matching target_*, but there is no index on + # target_chain_id. Hopefully any purged events are due to a room + # being fully purged and they will be removed from the origin_* + # searches. + txn.executemany( + """ + DELETE FROM event_auth_chain_links WHERE + origin_chain_id = ? AND origin_sequence_number = ? + """, + unreferenced_chain_id_tuples, + ) + + progress = { + "current_event_id": event_id, + } + + self.db_pool.updates._background_update_progress_txn( + txn, "purged_chain_cover", progress + ) + + return len(rows) + + result = await self.db_pool.runInteraction( + "_purged_chain_cover_index", + purged_chain_cover_txn, + ) + + if not result: + await self.db_pool.updates._end_background_update("purged_chain_cover") + + return result diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 0836e4af49..41f4fe7f95 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -331,13 +331,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): txn.executemany( """ DELETE FROM event_auth_chain_links WHERE - (origin_chain_id = ? AND origin_sequence_number = ?) OR - (target_chain_id = ? AND target_sequence_number = ?) + origin_chain_id = ? AND origin_sequence_number = ? """, - ( - (chain_id, seq_num, chain_id, seq_num) - for (chain_id, seq_num) in referenced_chain_id_tuples - ), + referenced_chain_id_tuples, ) # Now we delete tables which lack an index on room_id but have one on event_id diff --git a/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql new file mode 100644 index 0000000000..87cb1f3cfd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5910, 'purged_chain_cover', '{}'); -- cgit 1.5.1 From 918f6ed827ceabc052eddba15d2e6aeefe36be23 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 10 Mar 2021 08:55:52 -0500 Subject: Fix a bug in the background task for purging chain cover. (#9583) --- changelog.d/9583.bugfix | 1 + synapse/storage/databases/main/events_bg_updates.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9583.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/9583.bugfix b/changelog.d/9583.bugfix new file mode 100644 index 0000000000..51b1876f3b --- /dev/null +++ b/changelog.d/9583.bugfix @@ -0,0 +1 @@ +Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 73e69d4cb1..78367ea58d 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -969,7 +969,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): event_id = "" for event_id, chain_id, sequence_number, has_event in rows: if not has_event: - unreferenced_event_ids.append(event_id) + unreferenced_event_ids.append((event_id,)) unreferenced_chain_id_tuples.append((chain_id, sequence_number)) # Delete the unreferenced auth chains from event_auth_chain_links and -- cgit 1.5.1 From 2a99cc6524808380d2353ffff013cfa6fdfc09db Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 10 Mar 2021 09:57:59 -0500 Subject: Use the chain cover index in get_auth_chain_ids. (#9576) This uses a simplified version of get_chain_cover_difference to calculate auth chain of events. --- changelog.d/9576.misc | 1 + synapse/federation/federation_server.py | 6 +- synapse/handlers/federation.py | 6 +- synapse/storage/databases/main/event_federation.py | 148 ++++++++++++++++++++- tests/storage/test_event_federation.py | 76 ++++++++++- 5 files changed, 226 insertions(+), 11 deletions(-) create mode 100644 changelog.d/9576.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/9576.misc b/changelog.d/9576.misc new file mode 100644 index 0000000000..bc257d05b7 --- /dev/null +++ b/changelog.d/9576.misc @@ -0,0 +1 @@ +Improve efficiency of calculating the auth chain in large rooms. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ffc735ba25..06c5e7a9e0 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -447,7 +447,7 @@ class FederationServer(FederationBase): async def _on_state_ids_request_compute(self, room_id, event_id): state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) + auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( @@ -460,7 +460,9 @@ class FederationServer(FederationBase): else: pdus = (await self.state.get_current_state(room_id)).values() - auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + auth_chain = await self.store.get_auth_chain( + room_id, [pdu.event_id for pdu in pdus] + ) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2ead626a4d..3fe02b7195 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1317,7 +1317,7 @@ class FederationHandler(BaseHandler): async def on_event_auth(self, event_id: str) -> List[EventBase]: event = await self.store.get_event(event_id) auth = await self.store.get_auth_chain( - list(event.auth_event_ids()), include_given=True + event.room_id, list(event.auth_event_ids()), include_given=True ) return list(auth) @@ -1580,7 +1580,7 @@ class FederationHandler(BaseHandler): prev_state_ids = await context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) - auth_chain = await self.store.get_auth_chain(state_ids) + auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) state = await self.store.get_events(list(prev_state_ids.values())) @@ -2219,7 +2219,7 @@ class FederationHandler(BaseHandler): # Now get the current auth_chain for the event. local_auth_chain = await self.store.get_auth_chain( - list(event.auth_event_ids()), include_given=True + room_id, list(event.auth_event_ids()), include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 18ddb92fcc..332193ad1c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) # type: LruCache[str, List[Tuple[str, int]]] async def get_auth_chain( - self, event_ids: Collection[str], include_given: bool = False + self, room_id: str, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: """Get auth events for given event_ids. The events *must* be state events. Args: + room_id: The room the event is in. event_ids: state events include_given: include the given events in result @@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas list of events """ event_ids = await self.get_auth_chain_ids( - event_ids, include_given=include_given + room_id, event_ids, include_given=include_given ) return await self.get_events_as_list(event_ids) async def get_auth_chain_ids( self, + room_id: str, event_ids: Collection[str], include_given: bool = False, ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. Args: + room_id: The room the event is in. event_ids: state events include_given: include the given events in result Returns: - An awaitable which resolve to a list of event_ids + list of event_ids """ + + # Check if we have indexed the room so we can use the chain cover + # algorithm. + room = await self.get_room(room_id) + if room["has_auth_chain_index"]: + try: + return await self.db_pool.runInteraction( + "get_auth_chain_ids_chains", + self._get_auth_chain_ids_using_cover_index_txn, + room_id, + event_ids, + include_given, + ) + except _NoChainCoverIndex: + # For whatever reason we don't actually have a chain cover index + # for the events in question, so we fall back to the old method. + pass + return await self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, @@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas include_given, ) + def _get_auth_chain_ids_using_cover_index_txn( + self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool + ) -> List[str]: + """Calculates the auth chain IDs using the chain index.""" + + # First we look up the chain ID/sequence numbers for the given events. + + initial_events = set(event_ids) + + # All the events that we've found that are reachable from the events. + seen_events = set() # type: Set[str] + + # A map from chain ID to max sequence number of the given events. + event_chains = {} # type: Dict[int, int] + + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + for event_id, chain_id, sequence_number in txn: + seen_events.add(event_id) + event_chains[chain_id] = max( + sequence_number, event_chains.get(chain_id, 0) + ) + + # Check that we actually have a chain ID for all the events. + events_missing_chain_info = initial_events.difference(seen_events) + if events_missing_chain_info: + # This can happen due to e.g. downgrade/upgrade of the server. We + # raise an exception and fall back to the previous algorithm. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info, + ) + raise _NoChainCoverIndex(room_id) + + # Now we look up all links for the chains we have, adding chains that + # are reachable from any event. + sql = """ + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE %s + """ + + # A map from chain ID to max sequence number *reachable* from any event ID. + chains = {} # type: Dict[int, int] + + # Add all linked chains reachable from initial set of chains. + for batch in batch_iter(event_chains, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch + ) + txn.execute(sql % (clause,), args) + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + # chains are only reachable if the origin sequence number of + # the link is less than the max sequence number in the + # origin chain. + if origin_sequence_number <= event_chains.get(origin_chain_id, 0): + chains[target_chain_id] = max( + target_sequence_number, + chains.get(target_chain_id, 0), + ) + + # Add the initial set of chains, excluding the sequence corresponding to + # initial event. + for chain_id, seq_no in event_chains.items(): + chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0)) + + # Now for each chain we figure out the maximum sequence number reachable + # from *any* event ID. Events with a sequence less than that are in the + # auth chain. + if include_given: + results = initial_events + else: + results = set() + + if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) + WHERE + c.chain_id = l.chain_id + AND sequence_number <= max_seq + """ + + rows = txn.execute_values(sql, chains.items()) + results.update(r for r, in rows) + else: + # For SQLite we just fall back to doing a noddy for loop. + sql = """ + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND sequence_number <= ? + """ + for chain_id, max_no in chains.items(): + txn.execute(sql, (chain_id, max_no)) + results.update(r for r, in txn) + + return list(results) + def _get_auth_chain_ids_txn( self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool ) -> List[str]: + """Calculates the auth chain IDs. + + This is used when we don't have a cover index for the room. + """ if include_given: results = set(event_ids) else: diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 06000f81a6..d597d712d6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - @parameterized.expand([(True,), (False,)]) - def test_auth_difference(self, use_chain_cover_index: bool): + def _setup_auth_chain(self, use_chain_cover_index: bool) -> str: room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "j": 1, } - # Mark the room as not having a cover index + # Mark the room as maybe having a cover index. def store_room(txn): self.store.db_pool.simple_insert_txn( @@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) + return room_id + + @parameterized.expand([(True,), (False,)]) + def test_auth_chain_ids(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + + # a and b have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["a", "b"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + # d and e have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"])) + self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) + self.assertEqual(auth_chain_ids, ["k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) + self.assertEqual(auth_chain_ids, ["j"]) + + # j and k have no parents. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) + self.assertEqual(auth_chain_ids, []) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) + self.assertEqual(auth_chain_ids, []) + + # More complex input sequences. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "c", "d"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["h", "i"]) + ) + self.assertCountEqual(auth_chain_ids, ["k", "j"]) + + # e gets returned even though include_given is false, but it is in the + # auth chain of b. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "e"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + # Test include_given. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["i"], include_given=True) + ) + self.assertCountEqual(auth_chain_ids, ["i", "j"]) + + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + # Now actually test that various combinations give the right result: difference = self.get_success( -- cgit 1.5.1 From a7a379006651ea219c2618c0a47e8f4053a9e769 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 10 Mar 2021 18:15:56 +0000 Subject: Convert Requester to attrs (#9586) ... because namedtuples suck Fix up a couple of other annotations to keep mypy happy. --- changelog.d/9586.misc | 1 + synapse/handlers/auth.py | 5 ++- synapse/rest/media/v1/media_repository.py | 3 +- synapse/storage/databases/main/registration.py | 6 +-- synapse/types.py | 57 +++++++++++++------------- 5 files changed, 37 insertions(+), 35 deletions(-) create mode 100644 changelog.d/9586.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/9586.misc b/changelog.d/9586.misc new file mode 100644 index 0000000000..2def9d5f55 --- /dev/null +++ b/changelog.d/9586.misc @@ -0,0 +1 @@ +Convert `synapse.types.Requester` to an `attrs` class. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bec0c615d4..fb5f8118f0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -337,7 +337,8 @@ class AuthHandler(BaseHandler): user is too high to proceed """ - + if not requester.access_token_id: + raise ValueError("Cannot validate a user without an access token") if self._ui_auth_session_timeout: last_validated = await self.store.get_access_token_last_validated( requester.access_token_id @@ -1213,7 +1214,7 @@ class AuthHandler(BaseHandler): async def delete_access_tokens_for_user( self, user_id: str, - except_token_id: Optional[str] = None, + except_token_id: Optional[int] = None, device_id: Optional[str] = None, ): """Invalidate access tokens belonging to a user diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 0641924f18..8b4841ed5d 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -35,6 +35,7 @@ from synapse.api.errors import ( from synapse.config._base import ConfigError from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string @@ -145,7 +146,7 @@ class MediaRepository: upload_name: Optional[str], content: IO, content_length: int, - auth_user: str, + auth_user: UserID, ) -> str: """Store uploaded content for a local user and return the mxc URL diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 61a7556e56..eba66ff352 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import attr @@ -1510,7 +1510,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): async def user_delete_access_tokens( self, user_id: str, - except_token_id: Optional[str] = None, + except_token_id: Optional[int] = None, device_id: Optional[str] = None, ) -> List[Tuple[str, int, Optional[str]]]: """ @@ -1533,7 +1533,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) - values = [v for _, v in items] + values = [v for _, v in items] # type: List[Union[str, int]] if except_token_id: where_clause += " AND id != ?" values.append(except_token_id) diff --git a/synapse/types.py b/synapse/types.py index 0216d213c7..b08ce90140 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -83,33 +83,32 @@ class ISynapseReactor( """The interfaces necessary for Synapse to function.""" -class Requester( - namedtuple( - "Requester", - [ - "user", - "access_token_id", - "is_guest", - "shadow_banned", - "device_id", - "app_service", - "authenticated_entity", - ], - ) -): +@attr.s(frozen=True, slots=True) +class Requester: """ Represents the user making a request Attributes: - user (UserID): id of the user making the request - access_token_id (int|None): *ID* of the access token used for this + user: id of the user making the request + access_token_id: *ID* of the access token used for this request, or None if it came via the appservice API or similar - is_guest (bool): True if the user making this request is a guest user - shadow_banned (bool): True if the user making this request has been shadow-banned. - device_id (str|None): device_id which was set at authentication time - app_service (ApplicationService|None): the AS requesting on behalf of the user + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request has been shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. """ + user = attr.ib(type="UserID") + access_token_id = attr.ib(type=Optional[int]) + is_guest = attr.ib(type=bool) + shadow_banned = attr.ib(type=bool) + device_id = attr.ib(type=Optional[str]) + app_service = attr.ib(type=Optional["ApplicationService"]) + authenticated_entity = attr.ib(type=str) + def serialize(self): """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -157,23 +156,23 @@ class Requester( def create_requester( user_id: Union[str, "UserID"], access_token_id: Optional[int] = None, - is_guest: Optional[bool] = False, - shadow_banned: Optional[bool] = False, + is_guest: bool = False, + shadow_banned: bool = False, device_id: Optional[str] = None, app_service: Optional["ApplicationService"] = None, authenticated_entity: Optional[str] = None, -): +) -> Requester: """ Create a new ``Requester`` object Args: - user_id (str|UserID): id of the user making the request - access_token_id (int|None): *ID* of the access token used for this + user_id: id of the user making the request + access_token_id: *ID* of the access token used for this request, or None if it came via the appservice API or similar - is_guest (bool): True if the user making this request is a guest user - shadow_banned (bool): True if the user making this request is shadow-banned. - device_id (str|None): device_id which was set at authentication time - app_service (ApplicationService|None): the AS requesting on behalf of the user + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request is shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user authenticated_entity: The entity that authenticated when making the request. This is different to the user_id when an admin user or the server is "puppeting" the user. -- cgit 1.5.1 From af2248f8bf1cf11a230577650e84885387f1920f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 15 Mar 2021 13:51:02 +0000 Subject: Optimise missing prev_event handling (#9601) Background: When we receive incoming federation traffic, and notice that we are missing prev_events from the incoming traffic, first we do a `/get_missing_events` request, and then if we still have missing prev_events, we set up new backwards-extremities. To do that, we need to make a `/state_ids` request to ask the remote server for the state at those prev_events, and then we may need to then ask the remote server for any events in that state which we don't already have, as well as the auth events for those missing state events, so that we can auth them. This PR attempts to optimise the processing of that state request. The `state_ids` API returns a list of the state events, as well as a list of all the auth events for *all* of those state events. The optimisation comes from the observation that we are currently loading all of those auth events into memory at the start of the operation, but we almost certainly aren't going to need *all* of the auth events. Rather, we can check that we have them, and leave the actual load into memory for later. (Ideally the federation API would tell us which auth events we're actually going to need, but it doesn't.) The effect of this is to reduce the number of events that I need to load for an event in Matrix HQ from about 60000 to about 22000, which means it can stay in my in-memory cache, whereas previously the sheer number of events meant that all 60K events had to be loaded from db for each request, due to the amount of cache churn. (NB I've already tripled the size of the cache from its default of 10K). Unfortunately I've ended up basically C&Ping `_get_state_for_room` and `_get_events_from_store_or_dest` into a new method, because `_get_state_for_room` is also called during backfill, which expects the auth events to be returned, so the same tricks don't work. That said, I don't really know why that codepath is completely different (ultimately we're doing the same thing in setting up a new backwards extremity) so I've left a TODO suggesting that we clean it up. --- changelog.d/9601.feature | 1 + synapse/handlers/federation.py | 152 ++++++++++++++++++++---- synapse/storage/databases/main/events_worker.py | 12 +- 3 files changed, 137 insertions(+), 28 deletions(-) create mode 100644 changelog.d/9601.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/9601.feature b/changelog.d/9601.feature new file mode 100644 index 0000000000..5078d63ffa --- /dev/null +++ b/changelog.d/9601.feature @@ -0,0 +1 @@ +Optimise handling of incomplete room history for incoming federation. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1d20c441f3..598a66f74c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -353,17 +353,16 @@ class FederationHandler(BaseHandler): # Ask the remote server for the states we don't # know about for p in prevs - seen: - logger.info( - "Requesting state at missing prev_event %s", - event_id, - ) + logger.info("Requesting state after missing prev_event %s", p) with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - (remote_state, _,) = await self._get_state_for_room( - origin, room_id, p, include_event_in_state=True + remote_state = ( + await self._get_state_after_missing_prev_event( + origin, room_id, p + ) ) remote_state_map = { @@ -539,7 +538,6 @@ class FederationHandler(BaseHandler): destination: str, room_id: str, event_id: str, - include_event_in_state: bool = False, ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. @@ -547,11 +545,9 @@ class FederationHandler(BaseHandler): destination: The remote homeserver to query for the state. room_id: The id of the room we're interested in. event_id: The id of the event we want the state at. - include_event_in_state: if true, the event itself will be included in the - returned state event list. Returns: - A list of events in the state, possibly including the event itself, and + A list of events in the state, not including the event itself, and a list of events in the auth chain for the given event. """ ( @@ -563,9 +559,6 @@ class FederationHandler(BaseHandler): desired_events = set(state_event_ids + auth_event_ids) - if include_event_in_state: - desired_events.add(event_id) - event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -582,13 +575,6 @@ class FederationHandler(BaseHandler): event_map[e_id] for e_id in state_event_ids if e_id in event_map ] - if include_event_in_state: - remote_event = event_map.get(event_id) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] auth_chain.sort(key=lambda e: e.depth) @@ -662,6 +648,131 @@ class FederationHandler(BaseHandler): return fetched_events + async def _get_state_after_missing_prev_event( + self, + destination: str, + room_id: str, + event_id: str, + ) -> List[EventBase]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + + Returns: + A list of events in the state, including the event itself + """ + # TODO: This function is basically the same as _get_state_for_room. Can + # we make backfill() use it, rather than having two code paths? I think the + # only difference is that backfill() persists the prev events separately. + + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + logger.debug( + "state_ids returned %i state events, %i auth events", + len(state_event_ids), + len(auth_event_ids), + ) + + # start by just trying to fetch the events from the store + desired_events = set(state_event_ids) + desired_events.add(event_id) + logger.debug("Fetching %i events from cache/store", len(desired_events)) + fetched_events = await self.store.get_events( + desired_events, allow_rejected=True + ) + + missing_desired_events = desired_events - fetched_events.keys() + logger.debug( + "We are missing %i events (got %i)", + len(missing_desired_events), + len(fetched_events), + ) + + # We probably won't need most of the auth events, so let's just check which + # we have for now, rather than thrashing the event cache with them all + # unnecessarily. + + # TODO: we probably won't actually need all of the auth events, since we + # already have a bunch of the state events. It would be nice if the + # federation api gave us a way of finding out which we actually need. + + missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events.difference_update( + await self.store.have_seen_events(missing_auth_events) + ) + logger.debug("We are also missing %i auth events", len(missing_auth_events)) + + missing_events = missing_desired_events | missing_auth_events + logger.debug("Fetching %i events from remote", len(missing_events)) + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (await self.store.get_events(missing_desired_events, allow_rejected=True)) + ) + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + # if we couldn't get the prev event in question, that's a problem. + remote_event = fetched_events.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + + # missing state at that event is a warning, not a blocker + # XXX: this doesn't sound right? it means that we'll end up with incomplete + # state. + failed_to_fetch = desired_events - fetched_events.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events + ] + + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + return remote_state + async def _process_received_pdu( self, origin: str, @@ -841,7 +952,6 @@ class FederationHandler(BaseHandler): destination=dest, room_id=room_id, event_id=e_id, - include_event_in_state=False, ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index edbe42f2bf..c04e162ccc 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools + import logging import threading from collections import namedtuple @@ -1044,7 +1044,8 @@ class EventsWorkerStore(SQLBaseStore): Returns: set[str]: The events we have already seen. """ - results = set() + # if the event cache contains the event, obviously we've seen it. + results = {x for x in event_ids if self._get_event_cache.contains(x)} def have_seen_events_txn(txn, chunk): sql = "SELECT event_id FROM events as e WHERE " @@ -1052,12 +1053,9 @@ class EventsWorkerStore(SQLBaseStore): txn.database_engine, "e.event_id", chunk ) txn.execute(sql + clause, args) - for (event_id,) in txn: - results.add(event_id) + results.update(row[0] for row in txn) - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + for chunk in batch_iter((x for x in event_ids if x not in results), 100): await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) -- cgit 1.5.1 From 026503fa3b90c03996a64be95432e345434b4a82 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Mar 2021 14:42:40 +0000 Subject: Don't go into federation catch up mode so easily (#9561) Federation catch up mode is very inefficient if the number of events that the remote server has missed is small, since handling gaps can be very expensive, c.f. #9492. Instead of going into catch up mode whenever we see an error, we instead do so only if we've backed off from trying the remote for more than an hour (the assumption being that in such a case it is more than a transient failure). --- changelog.d/9561.misc | 1 + synapse/federation/sender/per_destination_queue.py | 287 ++++++++++++--------- synapse/federation/sender/transaction_manager.py | 48 +--- synapse/storage/databases/main/transactions.py | 10 +- tests/federation/test_federation_catch_up.py | 3 +- 5 files changed, 190 insertions(+), 159 deletions(-) create mode 100644 changelog.d/9561.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/9561.misc b/changelog.d/9561.misc new file mode 100644 index 0000000000..6c529a82ee --- /dev/null +++ b/changelog.d/9561.misc @@ -0,0 +1 @@ +Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index deb519f3ef..cc0d765e5f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -17,6 +17,7 @@ import datetime import logging from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast +import attr from prometheus_client import Counter from synapse.api.errors import ( @@ -93,6 +94,10 @@ class PerDestinationQueue: self._destination = destination self.transmission_loop_running = False + # Flag to signal to any running transmission loop that there is new data + # queued up to be sent. + self._new_data_to_send = False + # True whilst we are sending events that the remote homeserver missed # because it was unreachable. We start in this state so we can perform # catch-up at startup. @@ -108,7 +113,7 @@ class PerDestinationQueue: # destination (we are the only updater so this is safe) self._last_successful_stream_ordering = None # type: Optional[int] - # a list of pending PDUs + # a queue of pending PDUs self._pending_pdus = [] # type: List[EventBase] # XXX this is never actually used: see @@ -208,6 +213,10 @@ class PerDestinationQueue: transaction in the background. """ + # Mark that we (may) have new things to send, so that any running + # transmission loop will recheck whether there is stuff to send. + self._new_data_to_send = True + if self.transmission_loop_running: # XXX: this can get stuck on by a never-ending # request at which point pending_pdus just keeps growing. @@ -250,125 +259,41 @@ class PerDestinationQueue: pending_pdus = [] while True: - # We have to keep 2 free slots for presence and rr_edus - limit = MAX_EDUS_PER_TRANSACTION - 2 - - device_update_edus, dev_list_id = await self._get_device_update_edus( - limit - ) - - limit -= len(device_update_edus) - - ( - to_device_edus, - device_stream_id, - ) = await self._get_to_device_message_edus(limit) - - pending_edus = device_update_edus + to_device_edus - - # BEGIN CRITICAL SECTION - # - # In order to avoid a race condition, we need to make sure that - # the following code (from popping the queues up to the point - # where we decide if we actually have any pending messages) is - # atomic - otherwise new PDUs or EDUs might arrive in the - # meantime, but not get sent because we hold the - # transmission_loop_running flag. - - pending_pdus = self._pending_pdus + self._new_data_to_send = False - # We can only include at most 50 PDUs per transactions - pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:] + async with _TransactionQueueManager(self) as ( + pending_pdus, + pending_edus, + ): + if not pending_pdus and not pending_edus: + logger.debug("TX [%s] Nothing to send", self._destination) + + # If we've gotten told about new things to send during + # checking for things to send, we try looking again. + # Otherwise new PDUs or EDUs might arrive in the meantime, + # but not get sent because we hold the + # `transmission_loop_running` flag. + if self._new_data_to_send: + continue + else: + return - pending_edus.extend(self._get_rr_edus(force_flush=False)) - pending_presence = self._pending_presence - self._pending_presence = {} - if pending_presence: - pending_edus.append( - Edu( - origin=self._server_name, - destination=self._destination, - edu_type="m.presence", - content={ - "push": [ - format_user_presence_state( - presence, self._clock.time_msec() - ) - for presence in pending_presence.values() - ] - }, + if pending_pdus: + logger.debug( + "TX [%s] len(pending_pdus_by_dest[dest]) = %d", + self._destination, + len(pending_pdus), ) - ) - pending_edus.extend( - self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) - ) - while ( - len(pending_edus) < MAX_EDUS_PER_TRANSACTION - and self._pending_edus_keyed - ): - _, val = self._pending_edus_keyed.popitem() - pending_edus.append(val) - - if pending_pdus: - logger.debug( - "TX [%s] len(pending_pdus_by_dest[dest]) = %d", - self._destination, - len(pending_pdus), + await self._transaction_manager.send_new_transaction( + self._destination, pending_pdus, pending_edus ) - if not pending_pdus and not pending_edus: - logger.debug("TX [%s] Nothing to send", self._destination) - self._last_device_stream_id = device_stream_id - return - - # if we've decided to send a transaction anyway, and we have room, we - # may as well send any pending RRs - if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: - pending_edus.extend(self._get_rr_edus(force_flush=True)) - - # END CRITICAL SECTION - - success = await self._transaction_manager.send_new_transaction( - self._destination, pending_pdus, pending_edus - ) - if success: sent_transactions_counter.inc() sent_edus_counter.inc(len(pending_edus)) for edu in pending_edus: sent_edus_by_type.labels(edu.edu_type).inc() - # Remove the acknowledged device messages from the database - # Only bother if we actually sent some device messages - if to_device_edus: - await self._store.delete_device_msgs_for_remote( - self._destination, device_stream_id - ) - # also mark the device updates as sent - if device_update_edus: - logger.info( - "Marking as sent %r %r", self._destination, dev_list_id - ) - await self._store.mark_as_sent_devices_by_remote( - self._destination, dev_list_id - ) - - self._last_device_stream_id = device_stream_id - self._last_device_list_stream_id = dev_list_id - - if pending_pdus: - # we sent some PDUs and it was successful, so update our - # last_successful_stream_ordering in the destinations table. - final_pdu = pending_pdus[-1] - last_successful_stream_ordering = ( - final_pdu.internal_metadata.stream_ordering - ) - assert last_successful_stream_ordering - await self._store.set_destination_last_successful_stream_ordering( - self._destination, last_successful_stream_ordering - ) - else: - break except NotRetryingDestination as e: logger.debug( "TX [%s] not ready for retry yet (next retry at %s) - " @@ -401,7 +326,7 @@ class PerDestinationQueue: self._pending_presence = {} self._pending_rrs = {} - self._start_catching_up() + self._start_catching_up() except FederationDeniedError as e: logger.info(e) except HttpResponseException as e: @@ -412,7 +337,6 @@ class PerDestinationQueue: e, ) - self._start_catching_up() except RequestSendFailed as e: logger.warning( "TX [%s] Failed to send transaction: %s", self._destination, e @@ -422,16 +346,12 @@ class PerDestinationQueue: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) - - self._start_catching_up() except Exception: logger.exception("TX [%s] Failed to send transaction", self._destination) for p in pending_pdus: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) - - self._start_catching_up() finally: # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False @@ -499,13 +419,10 @@ class PerDestinationQueue: rooms = [p.room_id for p in catchup_pdus] logger.info("Catching up rooms to %s: %r", self._destination, rooms) - success = await self._transaction_manager.send_new_transaction( + await self._transaction_manager.send_new_transaction( self._destination, catchup_pdus, [] ) - if not success: - return - sent_transactions_counter.inc() final_pdu = catchup_pdus[-1] self._last_successful_stream_ordering = cast( @@ -584,3 +501,135 @@ class PerDestinationQueue: """ self._catching_up = True self._pending_pdus = [] + + +@attr.s(slots=True) +class _TransactionQueueManager: + """A helper async context manager for pulling stuff off the queues and + tracking what was last successfully sent, etc. + """ + + queue = attr.ib(type=PerDestinationQueue) + + _device_stream_id = attr.ib(type=Optional[int], default=None) + _device_list_id = attr.ib(type=Optional[int], default=None) + _last_stream_ordering = attr.ib(type=Optional[int], default=None) + _pdus = attr.ib(type=List[EventBase], factory=list) + + async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: + # First we calculate the EDUs we want to send, if any. + + # We start by fetching device related EDUs, i.e device updates and to + # device messages. We have to keep 2 free slots for presence and rr_edus. + limit = MAX_EDUS_PER_TRANSACTION - 2 + + device_update_edus, dev_list_id = await self.queue._get_device_update_edus( + limit + ) + + if device_update_edus: + self._device_list_id = dev_list_id + else: + self.queue._last_device_list_stream_id = dev_list_id + + limit -= len(device_update_edus) + + ( + to_device_edus, + device_stream_id, + ) = await self.queue._get_to_device_message_edus(limit) + + if to_device_edus: + self._device_stream_id = device_stream_id + else: + self.queue._last_device_stream_id = device_stream_id + + pending_edus = device_update_edus + to_device_edus + + # Now add the read receipt EDU. + pending_edus.extend(self.queue._get_rr_edus(force_flush=False)) + + # And presence EDU. + if self.queue._pending_presence: + pending_edus.append( + Edu( + origin=self.queue._server_name, + destination=self.queue._destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.queue._clock.time_msec() + ) + for presence in self.queue._pending_presence.values() + ] + }, + ) + ) + self.queue._pending_presence = {} + + # Finally add any other types of EDUs if there is room. + pending_edus.extend( + self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) + ) + while ( + len(pending_edus) < MAX_EDUS_PER_TRANSACTION + and self.queue._pending_edus_keyed + ): + _, val = self.queue._pending_edus_keyed.popitem() + pending_edus.append(val) + + # Now we look for any PDUs to send, by getting up to 50 PDUs from the + # queue + self._pdus = self.queue._pending_pdus[:50] + + if not self._pdus and not pending_edus: + return [], [] + + # if we've decided to send a transaction anyway, and we have room, we + # may as well send any pending RRs + if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: + pending_edus.extend(self.queue._get_rr_edus(force_flush=True)) + + if self._pdus: + self._last_stream_ordering = self._pdus[ + -1 + ].internal_metadata.stream_ordering + assert self._last_stream_ordering + + return self._pdus, pending_edus + + async def __aexit__(self, exc_type, exc, tb): + if exc_type is not None: + # Failed to send transaction, so we bail out. + return + + # Successfully sent transactions, so we remove pending PDUs from the queue + if self._pdus: + self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :] + + # Succeeded to send the transaction so we record where we have sent up + # to in the various streams + + if self._device_stream_id: + await self.queue._store.delete_device_msgs_for_remote( + self.queue._destination, self._device_stream_id + ) + self.queue._last_device_stream_id = self._device_stream_id + + # also mark the device updates as sent + if self._device_list_id: + logger.info( + "Marking as sent %r %r", self.queue._destination, self._device_list_id + ) + await self.queue._store.mark_as_sent_devices_by_remote( + self.queue._destination, self._device_list_id + ) + self.queue._last_device_list_stream_id = self._device_list_id + + if self._last_stream_ordering: + # we sent some PDUs and it was successful, so update our + # last_successful_stream_ordering in the destinations table. + await self.queue._store.set_destination_last_successful_stream_ordering( + self.queue._destination, self._last_stream_ordering + ) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 2a9cd063c4..07b740c2f2 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -69,15 +69,12 @@ class TransactionManager: destination: str, pdus: List[EventBase], edus: List[Edu], - ) -> bool: + ) -> None: """ Args: destination: The destination to send to (e.g. 'example.org') pdus: In-order list of PDUs to send edus: List of EDUs to send - - Returns: - True iff the transaction was successful """ # Make a transaction-sending opentracing span. This span follows on from @@ -96,8 +93,6 @@ class TransactionManager: edu.strip_context() with start_active_span_follows_from("send_transaction", span_contexts): - success = True - logger.debug("TX [%s] _attempt_new_transaction", destination) txn_id = str(self._next_txn_id) @@ -152,44 +147,29 @@ class TransactionManager: response = await self._transport_layer.send_transaction( transaction, json_data_cb ) - code = 200 except HttpResponseException as e: code = e.code response = e.response - if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", destination, txn_id, code - ) - raise e - - logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) - - if code == 200: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warning( - "TX [%s] {%s} Remote returned error for %s: %s", - destination, - txn_id, - e_id, - r, - ) - else: - for p in pdus: + set_tag(tags.ERROR, True) + + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) + raise + + logger.info("TX [%s] {%s} got 200 response", destination, txn_id) + + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: logger.warning( - "TX [%s] {%s} Failed to send event %s", + "TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, - p.event_id, + e_id, + r, ) - success = False - if success and pdus and destination in self._federation_metrics_domains: + if pdus and destination in self._federation_metrics_domains: last_pdu = pdus[-1] last_pdu_ts_metric.labels(server_name=destination).set( last_pdu.origin_server_ts / 1000 ) - - set_tag(tags.ERROR, not success) - return success diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index b921d63d30..0309661841 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -350,11 +350,11 @@ class TransactionStore(TransactionWorkerStore): self.db_pool.simple_upsert_many_txn( txn, - "destination_rooms", - ["destination", "room_id"], - rows, - ["stream_ordering"], - [(stream_ordering,)] * len(rows), + table="destination_rooms", + key_names=("destination", "room_id"), + key_values=rows, + value_names=["stream_ordering"], + value_values=[(stream_ordering,)] * len(rows), ) async def get_destination_last_successful_stream_ordering( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 1a3ccb263d..6f96cd7940 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable from tests.unittest import FederatingHomeserverTestCase, override_config @@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): else: data = json_cb() self.failed_pdus.extend(data["pdus"]) - raise IOError("Failed to connect because this is a test!") + raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination) def get_destination_room(self, room: str, destination: str = "host2") -> dict: """ -- cgit 1.5.1