From dc51d8ffaf4d392be2f36c4d36625352b09c55c9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 11:22:25 -0500 Subject: Add a background task to purge unused chain IDs. (#9542) This is a companion change to apply the fix in #9498 / 922788c6043138165c025c78effeda87de842bab to previously purged rooms. --- changelog.d/9542.bugfix | 1 + .../storage/databases/main/events_bg_updates.py | 79 ++++++++++++++++++++++ synapse/storage/databases/main/purge_events.py | 8 +-- .../delta/59/10delete_purged_chain_cover.sql | 17 +++++ 4 files changed, 99 insertions(+), 6 deletions(-) create mode 100644 changelog.d/9542.bugfix create mode 100644 synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql diff --git a/changelog.d/9542.bugfix b/changelog.d/9542.bugfix new file mode 100644 index 0000000000..51b1876f3b --- /dev/null +++ b/changelog.d/9542.bugfix @@ -0,0 +1 @@ +Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index cb6b1f8a0c..73e69d4cb1 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -135,6 +135,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._chain_cover_index, ) + self.db_pool.updates.register_background_update_handler( + "purged_chain_cover", + self._purged_chain_cover_index, + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -932,3 +937,77 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): processed_count=count, finished_room_map=finished_rooms, ) + + async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int: + """ + A background updates that iterates over the chain cover and deletes the + chain cover for events that have been purged. + + This may be due to fully purging a room or via setting a retention policy. + """ + current_event_id = progress.get("current_event_id", "") + + def purged_chain_cover_txn(txn) -> int: + # The event ID from events will be null if the chain ID / sequence + # number points to a purged event. + sql = """ + SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL + FROM event_auth_chains + LEFT JOIN events AS e USING (event_id) + WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ? + """ + txn.execute(sql, (current_event_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + # The event IDs and chain IDs / sequence numbers where the event has + # been purged. + unreferenced_event_ids = [] + unreferenced_chain_id_tuples = [] + event_id = "" + for event_id, chain_id, sequence_number, has_event in rows: + if not has_event: + unreferenced_event_ids.append(event_id) + unreferenced_chain_id_tuples.append((chain_id, sequence_number)) + + # Delete the unreferenced auth chains from event_auth_chain_links and + # event_auth_chains. + txn.executemany( + """ + DELETE FROM event_auth_chains WHERE event_id = ? + """, + unreferenced_event_ids, + ) + # We should also delete matching target_*, but there is no index on + # target_chain_id. Hopefully any purged events are due to a room + # being fully purged and they will be removed from the origin_* + # searches. + txn.executemany( + """ + DELETE FROM event_auth_chain_links WHERE + origin_chain_id = ? AND origin_sequence_number = ? + """, + unreferenced_chain_id_tuples, + ) + + progress = { + "current_event_id": event_id, + } + + self.db_pool.updates._background_update_progress_txn( + txn, "purged_chain_cover", progress + ) + + return len(rows) + + result = await self.db_pool.runInteraction( + "_purged_chain_cover_index", + purged_chain_cover_txn, + ) + + if not result: + await self.db_pool.updates._end_background_update("purged_chain_cover") + + return result diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 0836e4af49..41f4fe7f95 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -331,13 +331,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): txn.executemany( """ DELETE FROM event_auth_chain_links WHERE - (origin_chain_id = ? AND origin_sequence_number = ?) OR - (target_chain_id = ? AND target_sequence_number = ?) + origin_chain_id = ? AND origin_sequence_number = ? """, - ( - (chain_id, seq_num, chain_id, seq_num) - for (chain_id, seq_num) in referenced_chain_id_tuples - ), + referenced_chain_id_tuples, ) # Now we delete tables which lack an index on room_id but have one on event_id diff --git a/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql new file mode 100644 index 0000000000..87cb1f3cfd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5910, 'purged_chain_cover', '{}'); -- cgit 1.5.1 From 67b979bfa185bd986afda21f27a0bf4fbcd49f9b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 14:41:02 -0500 Subject: Do not ignore the unpaddedbase64 module when type checking. (#9568) --- changelog.d/9568.misc | 1 + mypy.ini | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) create mode 100644 changelog.d/9568.misc diff --git a/changelog.d/9568.misc b/changelog.d/9568.misc new file mode 100644 index 0000000000..561963de93 --- /dev/null +++ b/changelog.d/9568.misc @@ -0,0 +1 @@ +Do not have mypy ignore type hints from unpaddedbase64. diff --git a/mypy.ini b/mypy.ini index f31cd432e6..e0685e097c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -117,9 +117,6 @@ ignore_missing_imports = True [mypy-saml2.*] ignore_missing_imports = True -[mypy-unpaddedbase64] -ignore_missing_imports = True - [mypy-canonicaljson] ignore_missing_imports = True -- cgit 1.5.1 From 918f6ed827ceabc052eddba15d2e6aeefe36be23 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 10 Mar 2021 08:55:52 -0500 Subject: Fix a bug in the background task for purging chain cover. (#9583) --- changelog.d/9583.bugfix | 1 + synapse/storage/databases/main/events_bg_updates.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9583.bugfix diff --git a/changelog.d/9583.bugfix b/changelog.d/9583.bugfix new file mode 100644 index 0000000000..51b1876f3b --- /dev/null +++ b/changelog.d/9583.bugfix @@ -0,0 +1 @@ +Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 73e69d4cb1..78367ea58d 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -969,7 +969,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): event_id = "" for event_id, chain_id, sequence_number, has_event in rows: if not has_event: - unreferenced_event_ids.append(event_id) + unreferenced_event_ids.append((event_id,)) unreferenced_chain_id_tuples.append((chain_id, sequence_number)) # Delete the unreferenced auth chains from event_auth_chain_links and -- cgit 1.5.1 From 2a99cc6524808380d2353ffff013cfa6fdfc09db Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 10 Mar 2021 09:57:59 -0500 Subject: Use the chain cover index in get_auth_chain_ids. (#9576) This uses a simplified version of get_chain_cover_difference to calculate auth chain of events. --- changelog.d/9576.misc | 1 + synapse/federation/federation_server.py | 6 +- synapse/handlers/federation.py | 6 +- synapse/storage/databases/main/event_federation.py | 148 ++++++++++++++++++++- tests/storage/test_event_federation.py | 76 ++++++++++- 5 files changed, 226 insertions(+), 11 deletions(-) create mode 100644 changelog.d/9576.misc diff --git a/changelog.d/9576.misc b/changelog.d/9576.misc new file mode 100644 index 0000000000..bc257d05b7 --- /dev/null +++ b/changelog.d/9576.misc @@ -0,0 +1 @@ +Improve efficiency of calculating the auth chain in large rooms. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ffc735ba25..06c5e7a9e0 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -447,7 +447,7 @@ class FederationServer(FederationBase): async def _on_state_ids_request_compute(self, room_id, event_id): state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) + auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( @@ -460,7 +460,9 @@ class FederationServer(FederationBase): else: pdus = (await self.state.get_current_state(room_id)).values() - auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + auth_chain = await self.store.get_auth_chain( + room_id, [pdu.event_id for pdu in pdus] + ) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2ead626a4d..3fe02b7195 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1317,7 +1317,7 @@ class FederationHandler(BaseHandler): async def on_event_auth(self, event_id: str) -> List[EventBase]: event = await self.store.get_event(event_id) auth = await self.store.get_auth_chain( - list(event.auth_event_ids()), include_given=True + event.room_id, list(event.auth_event_ids()), include_given=True ) return list(auth) @@ -1580,7 +1580,7 @@ class FederationHandler(BaseHandler): prev_state_ids = await context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) - auth_chain = await self.store.get_auth_chain(state_ids) + auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) state = await self.store.get_events(list(prev_state_ids.values())) @@ -2219,7 +2219,7 @@ class FederationHandler(BaseHandler): # Now get the current auth_chain for the event. local_auth_chain = await self.store.get_auth_chain( - list(event.auth_event_ids()), include_given=True + room_id, list(event.auth_event_ids()), include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 18ddb92fcc..332193ad1c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) # type: LruCache[str, List[Tuple[str, int]]] async def get_auth_chain( - self, event_ids: Collection[str], include_given: bool = False + self, room_id: str, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: """Get auth events for given event_ids. The events *must* be state events. Args: + room_id: The room the event is in. event_ids: state events include_given: include the given events in result @@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas list of events """ event_ids = await self.get_auth_chain_ids( - event_ids, include_given=include_given + room_id, event_ids, include_given=include_given ) return await self.get_events_as_list(event_ids) async def get_auth_chain_ids( self, + room_id: str, event_ids: Collection[str], include_given: bool = False, ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. Args: + room_id: The room the event is in. event_ids: state events include_given: include the given events in result Returns: - An awaitable which resolve to a list of event_ids + list of event_ids """ + + # Check if we have indexed the room so we can use the chain cover + # algorithm. + room = await self.get_room(room_id) + if room["has_auth_chain_index"]: + try: + return await self.db_pool.runInteraction( + "get_auth_chain_ids_chains", + self._get_auth_chain_ids_using_cover_index_txn, + room_id, + event_ids, + include_given, + ) + except _NoChainCoverIndex: + # For whatever reason we don't actually have a chain cover index + # for the events in question, so we fall back to the old method. + pass + return await self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, @@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas include_given, ) + def _get_auth_chain_ids_using_cover_index_txn( + self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool + ) -> List[str]: + """Calculates the auth chain IDs using the chain index.""" + + # First we look up the chain ID/sequence numbers for the given events. + + initial_events = set(event_ids) + + # All the events that we've found that are reachable from the events. + seen_events = set() # type: Set[str] + + # A map from chain ID to max sequence number of the given events. + event_chains = {} # type: Dict[int, int] + + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + for event_id, chain_id, sequence_number in txn: + seen_events.add(event_id) + event_chains[chain_id] = max( + sequence_number, event_chains.get(chain_id, 0) + ) + + # Check that we actually have a chain ID for all the events. + events_missing_chain_info = initial_events.difference(seen_events) + if events_missing_chain_info: + # This can happen due to e.g. downgrade/upgrade of the server. We + # raise an exception and fall back to the previous algorithm. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info, + ) + raise _NoChainCoverIndex(room_id) + + # Now we look up all links for the chains we have, adding chains that + # are reachable from any event. + sql = """ + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE %s + """ + + # A map from chain ID to max sequence number *reachable* from any event ID. + chains = {} # type: Dict[int, int] + + # Add all linked chains reachable from initial set of chains. + for batch in batch_iter(event_chains, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch + ) + txn.execute(sql % (clause,), args) + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + # chains are only reachable if the origin sequence number of + # the link is less than the max sequence number in the + # origin chain. + if origin_sequence_number <= event_chains.get(origin_chain_id, 0): + chains[target_chain_id] = max( + target_sequence_number, + chains.get(target_chain_id, 0), + ) + + # Add the initial set of chains, excluding the sequence corresponding to + # initial event. + for chain_id, seq_no in event_chains.items(): + chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0)) + + # Now for each chain we figure out the maximum sequence number reachable + # from *any* event ID. Events with a sequence less than that are in the + # auth chain. + if include_given: + results = initial_events + else: + results = set() + + if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) + WHERE + c.chain_id = l.chain_id + AND sequence_number <= max_seq + """ + + rows = txn.execute_values(sql, chains.items()) + results.update(r for r, in rows) + else: + # For SQLite we just fall back to doing a noddy for loop. + sql = """ + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND sequence_number <= ? + """ + for chain_id, max_no in chains.items(): + txn.execute(sql, (chain_id, max_no)) + results.update(r for r, in txn) + + return list(results) + def _get_auth_chain_ids_txn( self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool ) -> List[str]: + """Calculates the auth chain IDs. + + This is used when we don't have a cover index for the room. + """ if include_given: results = set(event_ids) else: diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 06000f81a6..d597d712d6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - @parameterized.expand([(True,), (False,)]) - def test_auth_difference(self, use_chain_cover_index: bool): + def _setup_auth_chain(self, use_chain_cover_index: bool) -> str: room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "j": 1, } - # Mark the room as not having a cover index + # Mark the room as maybe having a cover index. def store_room(txn): self.store.db_pool.simple_insert_txn( @@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) + return room_id + + @parameterized.expand([(True,), (False,)]) + def test_auth_chain_ids(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + + # a and b have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["a", "b"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + # d and e have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"])) + self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) + self.assertEqual(auth_chain_ids, ["k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) + self.assertEqual(auth_chain_ids, ["j"]) + + # j and k have no parents. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) + self.assertEqual(auth_chain_ids, []) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) + self.assertEqual(auth_chain_ids, []) + + # More complex input sequences. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "c", "d"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["h", "i"]) + ) + self.assertCountEqual(auth_chain_ids, ["k", "j"]) + + # e gets returned even though include_given is false, but it is in the + # auth chain of b. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "e"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + # Test include_given. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["i"], include_given=True) + ) + self.assertCountEqual(auth_chain_ids, ["i", "j"]) + + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + # Now actually test that various combinations give the right result: difference = self.get_success( -- cgit 1.5.1 From 17cd48fe5171d50da4cb59db647b993168e7dfab Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Wed, 10 Mar 2021 17:42:51 +0200 Subject: Fix spam checker modules documentation example (#9580) Mention that parse_config must exist and note the check_media_file_for_spam method. --- changelog.d/9580.doc | 1 + docs/spam_checker.md | 10 ++++++++++ 2 files changed, 11 insertions(+) create mode 100644 changelog.d/9580.doc diff --git a/changelog.d/9580.doc b/changelog.d/9580.doc new file mode 100644 index 0000000000..f9c8b328b3 --- /dev/null +++ b/changelog.d/9580.doc @@ -0,0 +1 @@ +Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index e615ac9910..2020eb9006 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -14,6 +14,7 @@ The Python class is instantiated with two objects: * An instance of `synapse.module_api.ModuleApi`. It then implements methods which return a boolean to alter behavior in Synapse. +All the methods must be defined. There's a generic method for checking every event (`check_event_for_spam`), as well as some specific methods: @@ -24,6 +25,7 @@ well as some specific methods: * `user_may_publish_room` * `check_username_for_spam` * `check_registration_for_spam` +* `check_media_file_for_spam` The details of each of these methods (as well as their inputs and outputs) are documented in the `synapse.events.spamcheck.SpamChecker` class. @@ -31,6 +33,10 @@ are documented in the `synapse.events.spamcheck.SpamChecker` class. The `ModuleApi` class provides a way for the custom spam checker class to call back into the homeserver internals. +Additionally, a `parse_config` method is mandatory and receives the plugin config +dictionary. After parsing, It must return an object which will be +passed to `__init__` later. + ### Example ```python @@ -41,6 +47,10 @@ class ExampleSpamChecker: self.config = config self.api = api + @staticmethod + def parse_config(config): + return config + async def check_event_for_spam(self, foo): return False # allow all events -- cgit 1.5.1 From 1107214a1d93280076ad547603b4a1396d709ad0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 10 Mar 2021 18:15:03 +0000 Subject: Fix the auth provider on the logins metric (#9573) We either need to pass the auth provider over the replication api, or make sure we report the auth provider on the worker that received the request. I've gone with the latter. --- changelog.d/9573.feature | 1 + synapse/handlers/register.py | 46 +++++++++++++++++++++++++-------------- synapse/replication/http/login.py | 4 ++-- 3 files changed, 33 insertions(+), 18 deletions(-) create mode 100644 changelog.d/9573.feature diff --git a/changelog.d/9573.feature b/changelog.d/9573.feature new file mode 100644 index 0000000000..5214b50d41 --- /dev/null +++ b/changelog.d/9573.feature @@ -0,0 +1 @@ +Add prometheus metrics for number of users successfully registering and logging in. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b66f8756b8..cd001e87c7 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -16,7 +16,7 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from prometheus_client import Counter @@ -82,6 +82,7 @@ class RegistrationHandler(BaseHandler): ) else: self.device_handler = hs.get_device_handler() + self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.session_lifetime @@ -678,17 +679,35 @@ class RegistrationHandler(BaseHandler): Returns: Tuple of device ID and access token """ + res = await self._register_device_client( + user_id=user_id, + device_id=device_id, + initial_display_name=initial_display_name, + is_guest=is_guest, + is_appservice_ghost=is_appservice_ghost, + ) - if self.hs.config.worker_app: - r = await self._register_device_client( - user_id=user_id, - device_id=device_id, - initial_display_name=initial_display_name, - is_guest=is_guest, - is_appservice_ghost=is_appservice_ghost, - ) - return r["device_id"], r["access_token"] + login_counter.labels( + guest=is_guest, + auth_provider=(auth_provider_id or ""), + ).inc() + + return res["device_id"], res["access_token"] + + async def register_device_inner( + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + is_guest: bool = False, + is_appservice_ghost: bool = False, + ) -> Dict[str, str]: + """Helper for register_device + Does the bits that need doing on the main process. Not for use outside this + class and RegisterDeviceReplicationServlet. + """ + assert not self.hs.config.worker_app valid_until_ms = None if self.session_lifetime is not None: if is_guest: @@ -713,12 +732,7 @@ class RegistrationHandler(BaseHandler): is_appservice_ghost=is_appservice_ghost, ) - login_counter.labels( - guest=is_guest, - auth_provider=(auth_provider_id or ""), - ).inc() - - return (registered_device_id, access_token) + return {"device_id": registered_device_id, "access_token": access_token} async def post_registration_actions( self, user_id: str, auth_result: dict, access_token: Optional[str] diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 36071feb36..4ec1bfa6ea 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] - device_id, access_token = await self.registration_handler.register_device( + res = await self.registration_handler.register_device_inner( user_id, device_id, initial_display_name, @@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_appservice_ghost=is_appservice_ghost, ) - return 200, {"device_id": device_id, "access_token": access_token} + return 200, res def register_servlets(hs, http_server): -- cgit 1.5.1 From a7a379006651ea219c2618c0a47e8f4053a9e769 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 10 Mar 2021 18:15:56 +0000 Subject: Convert Requester to attrs (#9586) ... because namedtuples suck Fix up a couple of other annotations to keep mypy happy. --- changelog.d/9586.misc | 1 + synapse/handlers/auth.py | 5 ++- synapse/rest/media/v1/media_repository.py | 3 +- synapse/storage/databases/main/registration.py | 6 +-- synapse/types.py | 57 +++++++++++++------------- 5 files changed, 37 insertions(+), 35 deletions(-) create mode 100644 changelog.d/9586.misc diff --git a/changelog.d/9586.misc b/changelog.d/9586.misc new file mode 100644 index 0000000000..2def9d5f55 --- /dev/null +++ b/changelog.d/9586.misc @@ -0,0 +1 @@ +Convert `synapse.types.Requester` to an `attrs` class. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bec0c615d4..fb5f8118f0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -337,7 +337,8 @@ class AuthHandler(BaseHandler): user is too high to proceed """ - + if not requester.access_token_id: + raise ValueError("Cannot validate a user without an access token") if self._ui_auth_session_timeout: last_validated = await self.store.get_access_token_last_validated( requester.access_token_id @@ -1213,7 +1214,7 @@ class AuthHandler(BaseHandler): async def delete_access_tokens_for_user( self, user_id: str, - except_token_id: Optional[str] = None, + except_token_id: Optional[int] = None, device_id: Optional[str] = None, ): """Invalidate access tokens belonging to a user diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 0641924f18..8b4841ed5d 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -35,6 +35,7 @@ from synapse.api.errors import ( from synapse.config._base import ConfigError from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string @@ -145,7 +146,7 @@ class MediaRepository: upload_name: Optional[str], content: IO, content_length: int, - auth_user: str, + auth_user: UserID, ) -> str: """Store uploaded content for a local user and return the mxc URL diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 61a7556e56..eba66ff352 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import attr @@ -1510,7 +1510,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): async def user_delete_access_tokens( self, user_id: str, - except_token_id: Optional[str] = None, + except_token_id: Optional[int] = None, device_id: Optional[str] = None, ) -> List[Tuple[str, int, Optional[str]]]: """ @@ -1533,7 +1533,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) - values = [v for _, v in items] + values = [v for _, v in items] # type: List[Union[str, int]] if except_token_id: where_clause += " AND id != ?" values.append(except_token_id) diff --git a/synapse/types.py b/synapse/types.py index 0216d213c7..b08ce90140 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -83,33 +83,32 @@ class ISynapseReactor( """The interfaces necessary for Synapse to function.""" -class Requester( - namedtuple( - "Requester", - [ - "user", - "access_token_id", - "is_guest", - "shadow_banned", - "device_id", - "app_service", - "authenticated_entity", - ], - ) -): +@attr.s(frozen=True, slots=True) +class Requester: """ Represents the user making a request Attributes: - user (UserID): id of the user making the request - access_token_id (int|None): *ID* of the access token used for this + user: id of the user making the request + access_token_id: *ID* of the access token used for this request, or None if it came via the appservice API or similar - is_guest (bool): True if the user making this request is a guest user - shadow_banned (bool): True if the user making this request has been shadow-banned. - device_id (str|None): device_id which was set at authentication time - app_service (ApplicationService|None): the AS requesting on behalf of the user + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request has been shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. """ + user = attr.ib(type="UserID") + access_token_id = attr.ib(type=Optional[int]) + is_guest = attr.ib(type=bool) + shadow_banned = attr.ib(type=bool) + device_id = attr.ib(type=Optional[str]) + app_service = attr.ib(type=Optional["ApplicationService"]) + authenticated_entity = attr.ib(type=str) + def serialize(self): """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -157,23 +156,23 @@ class Requester( def create_requester( user_id: Union[str, "UserID"], access_token_id: Optional[int] = None, - is_guest: Optional[bool] = False, - shadow_banned: Optional[bool] = False, + is_guest: bool = False, + shadow_banned: bool = False, device_id: Optional[str] = None, app_service: Optional["ApplicationService"] = None, authenticated_entity: Optional[str] = None, -): +) -> Requester: """ Create a new ``Requester`` object Args: - user_id (str|UserID): id of the user making the request - access_token_id (int|None): *ID* of the access token used for this + user_id: id of the user making the request + access_token_id: *ID* of the access token used for this request, or None if it came via the appservice API or similar - is_guest (bool): True if the user making this request is a guest user - shadow_banned (bool): True if the user making this request is shadow-banned. - device_id (str|None): device_id which was set at authentication time - app_service (ApplicationService|None): the AS requesting on behalf of the user + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request is shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user authenticated_entity: The entity that authenticated when making the request. This is different to the user_id when an admin user or the server is "puppeting" the user. -- cgit 1.5.1 From 70d1b6abff125c0dd5022d1394548ee12367b371 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 11 Mar 2021 14:52:32 +0100 Subject: Re-Activating account when local passwords are disabled (#9587) Fixes: #8393 --- changelog.d/9587.bugfix | 1 + synapse/rest/admin/users.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9587.bugfix diff --git a/changelog.d/9587.bugfix b/changelog.d/9587.bugfix new file mode 100644 index 0000000000..d8f04c4f21 --- /dev/null +++ b/changelog.d/9587.bugfix @@ -0,0 +1 @@ +Re-Activating account with admin API when local passwords are disabled. \ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 267a993430..2c89b62e25 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -269,7 +269,10 @@ class UserRestServletV2(RestServlet): target_user.to_string(), False, requester, by_admin=True ) elif not deactivate and user["deactivated"]: - if "password" not in body: + if ( + "password" not in body + and self.hs.config.password_localdb_enabled + ): raise SynapseError( 400, "Must provide a password to re-activate an account." ) -- cgit 1.5.1 From e55bd0e11012ee14a8b18ece16f4837b86212f5c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 11 Mar 2021 09:15:22 -0500 Subject: Add tests for blacklisting reactor/agent. (#9563) --- changelog.d/9563.misc | 1 + synapse/http/client.py | 26 +++++----- tests/http/test_client.py | 126 +++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 139 insertions(+), 14 deletions(-) create mode 100644 changelog.d/9563.misc diff --git a/changelog.d/9563.misc b/changelog.d/9563.misc new file mode 100644 index 0000000000..7a3493e4a1 --- /dev/null +++ b/changelog.d/9563.misc @@ -0,0 +1 @@ +Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. diff --git a/synapse/http/client.py b/synapse/http/client.py index af34d583ad..8f3da486b3 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -39,6 +39,7 @@ from zope.interface import implementer, provider from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, error as twisted_error, protocol, ssl +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.interfaces import ( IAddress, IHostResolution, @@ -151,16 +152,17 @@ class _IPBlacklistingResolver: def resolveHostName( self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 ) -> IResolutionReceiver: - - r = recv() addresses = [] # type: List[IAddress] def _callback() -> None: - r.resolutionBegan(None) - has_bad_ip = False - for i in addresses: - ip_address = IPAddress(i.host) + for address in addresses: + # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups + # should go through this path. + if not isinstance(address, (IPv4Address, IPv6Address)): + continue + + ip_address = IPAddress(address.host) if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist @@ -175,15 +177,15 @@ class _IPBlacklistingResolver: # request, but all we can really do from here is claim that there were no # valid results. if not has_bad_ip: - for i in addresses: - r.addressResolved(i) - r.resolutionComplete() + for address in addresses: + recv.addressResolved(address) + recv.resolutionComplete() @provider(IResolutionReceiver) class EndpointReceiver: @staticmethod def resolutionBegan(resolutionInProgress: IHostResolution) -> None: - pass + recv.resolutionBegan(resolutionInProgress) @staticmethod def addressResolved(address: IAddress) -> None: @@ -197,7 +199,7 @@ class _IPBlacklistingResolver: EndpointReceiver, hostname, portNumber=portNumber ) - return r + return recv @implementer(ISynapseReactor) @@ -346,7 +348,7 @@ class SimpleHttpClient: contextFactory=self.hs.get_http_client_context_factory(), pool=pool, use_proxy=use_proxy, - ) + ) # type: IAgent if self._ip_blacklist: # If we have an IP blacklist, we then install the blacklisting Agent diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 21ecb81c99..0ce181a51e 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -16,12 +16,23 @@ from io import BytesIO from mock import Mock +from netaddr import IPSet + +from twisted.internet.error import DNSLookupError from twisted.python.failure import Failure -from twisted.web.client import ResponseDone +from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.web.client import Agent, ResponseDone from twisted.web.iweb import UNKNOWN_LENGTH -from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size +from synapse.api.errors import SynapseError +from synapse.http.client import ( + BlacklistingAgentWrapper, + BlacklistingReactorWrapper, + BodyExceededMaxSize, + read_body_with_max_size, +) +from tests.server import FakeTransport, get_clock from tests.unittest import TestCase @@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase): # The data is never consumed. self.assertEqual(result.getvalue(), b"") + + +class BlacklistingAgentTest(TestCase): + def setUp(self): + self.reactor, self.clock = get_clock() + + self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" + self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8" + self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" + + # Configure the reactor's DNS resolver. + for (domain, ip) in ( + (self.safe_domain, self.safe_ip), + (self.unsafe_domain, self.unsafe_ip), + (self.allowed_domain, self.allowed_ip), + ): + self.reactor.lookups[domain.decode()] = ip.decode() + self.reactor.lookups[ip.decode()] = ip.decode() + + self.ip_whitelist = IPSet([self.allowed_ip.decode()]) + self.ip_blacklist = IPSet(["5.0.0.0/8"]) + + def test_reactor(self): + """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" + agent = Agent( + BlacklistingReactorWrapper( + self.reactor, + ip_whitelist=self.ip_whitelist, + ip_blacklist=self.ip_blacklist, + ), + ) + + # The unsafe domains and IPs should be rejected. + for domain in (self.unsafe_domain, self.unsafe_ip): + self.failureResultOf( + agent.request(b"GET", b"http://" + domain), DNSLookupError + ) + + # The safe domains IPs should be accepted. + for domain in ( + self.safe_domain, + self.allowed_domain, + self.safe_ip, + self.allowed_ip, + ): + d = agent.request(b"GET", b"http://" + domain) + + # Grab the latest TCP connection. + ( + host, + port, + client_factory, + _timeout, + _bindAddress, + ) = self.reactor.tcpClients[-1] + + # Make the connection and pump data through it. + client = client_factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" + ) + + response = self.successResultOf(d) + self.assertEqual(response.code, 200) + + def test_agent(self): + """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" + agent = BlacklistingAgentWrapper( + Agent(self.reactor), + ip_whitelist=self.ip_whitelist, + ip_blacklist=self.ip_blacklist, + ) + + # The unsafe IPs should be rejected. + self.failureResultOf( + agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError + ) + + # The safe and unsafe domains and safe IPs should be accepted. + for domain in ( + self.safe_domain, + self.unsafe_domain, + self.allowed_domain, + self.safe_ip, + self.allowed_ip, + ): + d = agent.request(b"GET", b"http://" + domain) + + # Grab the latest TCP connection. + ( + host, + port, + client_factory, + _timeout, + _bindAddress, + ) = self.reactor.tcpClients[-1] + + # Make the connection and pump data through it. + client = client_factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" + ) + + response = self.successResultOf(d) + self.assertEqual(response.code, 200) -- cgit 1.5.1 From 464e5da7b28f201375aec64de9821b8b232e3046 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 11 Mar 2021 18:35:09 +0000 Subject: Add logging for redis connection setup (#9590) --- changelog.d/9590.misc | 1 + stubs/txredisapi.pyi | 4 +++- synapse/replication/tcp/redis.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9590.misc diff --git a/changelog.d/9590.misc b/changelog.d/9590.misc new file mode 100644 index 0000000000..186396c45b --- /dev/null +++ b/changelog.d/9590.misc @@ -0,0 +1 @@ +Add logging for redis connection setup. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 618548a305..34787e0b1e 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -17,6 +17,8 @@ """ from typing import Any, List, Optional, Type, Union +from twisted.internet import protocol + class RedisProtocol: def publish(self, channel: str, message: bytes): ... async def ping(self) -> None: ... @@ -52,7 +54,7 @@ def lazyConnection( class ConnectionHandler: ... -class RedisFactory: +class RedisFactory(protocol.ReconnectingClientFactory): continueTrying: bool handler: RedisProtocol pool: List[RedisProtocol] diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 7560706b4b..574eaea1eb 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -20,6 +20,10 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast import attr import txredisapi +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.interfaces import IAddress, IConnector +from twisted.python.failure import Failure + from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, @@ -253,6 +257,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory): except Exception: logger.warning("Failed to send ping to a redis connection") + # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but + # it's rubbish. We add our own here. + + def startedConnecting(self, connector: IConnector): + logger.info( + "Connecting to redis server %s", format_address(connector.getDestination()) + ) + super().startedConnecting(connector) + + def clientConnectionFailed(self, connector: IConnector, reason: Failure): + logger.info( + "Connection to redis server %s failed: %s", + format_address(connector.getDestination()), + reason.value, + ) + super().clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector: IConnector, reason: Failure): + logger.info( + "Connection to redis server %s lost: %s", + format_address(connector.getDestination()), + reason.value, + ) + super().clientConnectionLost(connector, reason) + + +def format_address(address: IAddress) -> str: + if isinstance(address, (IPv4Address, IPv6Address)): + return "%s:%i" % (address.host, address.port) + return str(address) + class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): """This is a reconnecting factory that connects to redis and immediately -- cgit 1.5.1 From 2b328d7e0233bc331ce567e593b4c2f94a7d6e2e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 12 Mar 2021 15:08:03 +0000 Subject: Improve logging when processing incoming transactions (#9596) Put the room id in the logcontext, to make it easier to understand what's going on. --- changelog.d/9596.misc | 1 + synapse/federation/federation_server.py | 61 ++++++++++++++++++-------------- synapse/handlers/federation.py | 62 +++++++++------------------------ 3 files changed, 51 insertions(+), 73 deletions(-) create mode 100644 changelog.d/9596.misc diff --git a/changelog.d/9596.misc b/changelog.d/9596.misc new file mode 100644 index 0000000000..fc19a95f75 --- /dev/null +++ b/changelog.d/9596.misc @@ -0,0 +1 @@ +Improve logging when processing incoming transactions. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 06c5e7a9e0..89fe6def4b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -335,34 +335,41 @@ class FederationServer(FederationBase): # impose a limit to avoid going too crazy with ram/cpu. async def process_pdus_for_room(room_id: str): - logger.debug("Processing PDUs for %s", room_id) - try: - await self.check_server_matches_acl(origin_host, room_id) - except AuthError as e: - logger.warning("Ignoring PDUs for room %s from banned server", room_id) - for pdu in pdus_by_room[room_id]: - event_id = pdu.event_id - pdu_results[event_id] = e.error_dict() - return + with nested_logging_context(room_id): + logger.debug("Processing PDUs for %s", room_id) - for pdu in pdus_by_room[room_id]: - event_id = pdu.event_id - with pdu_process_time.time(): - with nested_logging_context(event_id): - try: - await self._handle_received_pdu(origin, pdu) - pdu_results[event_id] = {} - except FederationError as e: - logger.warning("Error handling PDU %s: %s", event_id, e) - pdu_results[event_id] = {"error": str(e)} - except Exception as e: - f = failure.Failure() - pdu_results[event_id] = {"error": str(e)} - logger.error( - "Failed to handle PDU %s", - event_id, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore - ) + try: + await self.check_server_matches_acl(origin_host, room_id) + except AuthError as e: + logger.warning( + "Ignoring PDUs for room %s from banned server", room_id + ) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + pdu_results[event_id] = e.error_dict() + return + + for pdu in pdus_by_room[room_id]: + pdu_results[pdu.event_id] = await process_pdu(pdu) + + async def process_pdu(pdu: EventBase) -> JsonDict: + event_id = pdu.event_id + with pdu_process_time.time(): + with nested_logging_context(event_id): + try: + await self._handle_received_pdu(origin, pdu) + return {} + except FederationError as e: + logger.warning("Error handling PDU %s: %s", event_id, e) + return {"error": str(e)} + except Exception as e: + f = failure.Failure() + logger.error( + "Failed to handle PDU %s", + event_id, + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + return {"error": str(e)} await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3fe02b7195..1d20c441f3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -201,7 +201,7 @@ class FederationHandler(BaseHandler): or pdu.internal_metadata.is_outlier() ) if already_seen: - logger.debug("[%s %s]: Already seen pdu", room_id, event_id) + logger.debug("Already seen pdu") return # do some initial sanity-checking of the event. In particular, make @@ -210,18 +210,14 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warning( - "[%s %s] Received event failed sanity checks", room_id, event_id - ) + logger.warning("Received event failed sanity checks") raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) # If we are currently in the process of joining this room, then we # queue up events for later processing. if room_id in self.room_queues: logger.info( - "[%s %s] Queuing PDU from %s for now: join in progress", - room_id, - event_id, + "Queuing PDU from %s for now: join in progress", origin, ) self.room_queues[room_id].append((pdu, origin)) @@ -236,9 +232,7 @@ class FederationHandler(BaseHandler): is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( - "[%s %s] Ignoring PDU from %s as we're not in the room", - room_id, - event_id, + "Ignoring PDU from %s as we're not in the room", origin, ) return None @@ -250,7 +244,7 @@ class FederationHandler(BaseHandler): # We only backfill backwards to the min depth. min_depth = await self.get_min_depth_for_context(pdu.room_id) - logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) + logger.debug("min_depth: %d", min_depth) prevs = set(pdu.prev_event_ids()) seen = await self.store.have_events_in_timeline(prevs) @@ -267,17 +261,13 @@ class FederationHandler(BaseHandler): # If we're missing stuff, ensure we only fetch stuff one # at a time. logger.info( - "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", - room_id, - event_id, + "Acquiring room lock to fetch %d missing prev_events: %s", len(missing_prevs), shortstr(missing_prevs), ) with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( - "[%s %s] Acquired room lock to fetch %d missing prev_events", - room_id, - event_id, + "Acquired room lock to fetch %d missing prev_events", len(missing_prevs), ) @@ -297,9 +287,7 @@ class FederationHandler(BaseHandler): if not prevs - seen: logger.info( - "[%s %s] Found all missing prev_events", - room_id, - event_id, + "Found all missing prev_events", ) if prevs - seen: @@ -329,9 +317,7 @@ class FederationHandler(BaseHandler): if sent_to_us_directly: logger.warning( - "[%s %s] Rejecting: failed to fetch %d prev events: %s", - room_id, - event_id, + "Rejecting: failed to fetch %d prev events: %s", len(prevs - seen), shortstr(prevs - seen), ) @@ -414,10 +400,7 @@ class FederationHandler(BaseHandler): state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( - "[%s %s] Error attempting to resolve state at missing " - "prev_events", - room_id, - event_id, + "Error attempting to resolve state at missing " "prev_events", exc_info=True, ) raise FederationError( @@ -454,9 +437,7 @@ class FederationHandler(BaseHandler): latest |= seen logger.info( - "[%s %s]: Requesting missing events between %s and %s", - room_id, - event_id, + "Requesting missing events between %s and %s", shortstr(latest), event_id, ) @@ -523,15 +504,11 @@ class FederationHandler(BaseHandler): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warning( - "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e - ) + logger.warning("Failed to get prev_events: %s", e) return logger.info( - "[%s %s]: Got %d prev_events: %s", - room_id, - event_id, + "Got %d prev_events: %s", len(missing_events), shortstr(missing_events), ) @@ -542,9 +519,7 @@ class FederationHandler(BaseHandler): for ev in missing_events: logger.info( - "[%s %s] Handling received prev_event %s", - room_id, - event_id, + "Handling received prev_event %s", ev.event_id, ) with nested_logging_context(ev.event_id): @@ -553,9 +528,7 @@ class FederationHandler(BaseHandler): except FederationError as e: if e.code == 403: logger.warning( - "[%s %s] Received prev_event %s failed history check.", - room_id, - event_id, + "Received prev_event %s failed history check.", ev.event_id, ) else: @@ -707,10 +680,7 @@ class FederationHandler(BaseHandler): (ie, we are missing one or more prev_events), the resolved state at the event """ - room_id = event.room_id - event_id = event.event_id - - logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) + logger.debug("Processing event: %s", event) try: await self._handle_new_event(origin, event, state=state) -- cgit 1.5.1 From 1e67bff8333bd358617eb90fa4155ad90b486f28 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 12 Mar 2021 15:14:55 +0000 Subject: Reject concurrent transactions (#9597) If more transactions arrive from an origin while we're still processing the first one, reject them. Hopefully a quick fix to https://github.com/matrix-org/synapse/issues/9489 --- changelog.d/9597.bugfix | 1 + synapse/federation/federation_server.py | 77 ++++++++++++++++++--------------- 2 files changed, 43 insertions(+), 35 deletions(-) create mode 100644 changelog.d/9597.bugfix diff --git a/changelog.d/9597.bugfix b/changelog.d/9597.bugfix new file mode 100644 index 0000000000..349dc9d664 --- /dev/null +++ b/changelog.d/9597.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 89fe6def4b..db6e49dbca 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -112,10 +112,11 @@ class FederationServer(FederationBase): # with FederationHandlerRegistry. hs.get_directory_handler() - self._federation_ratelimiter = hs.get_federation_ratelimiter() - self._server_linearizer = Linearizer("fed_server") - self._transaction_linearizer = Linearizer("fed_txn_handler") + + # origins that we are currently processing a transaction from. + # a dict from origin to txn id. + self._active_transactions = {} # type: Dict[str, str] # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( @@ -169,6 +170,33 @@ class FederationServer(FederationBase): logger.debug("[%s] Got transaction", transaction_id) + # Reject malformed transactions early: reject if too many PDUs/EDUs + if len(transaction.pdus) > 50 or ( # type: ignore + hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore + ): + logger.info("Transaction PDU or EDU count too large. Returning 400") + return 400, {} + + # we only process one transaction from each origin at a time. We need to do + # this check here, rather than in _on_incoming_transaction_inner so that we + # don't cache the rejection in _transaction_resp_cache (so that if the txn + # arrives again later, we can process it). + current_transaction = self._active_transactions.get(origin) + if current_transaction and current_transaction != transaction_id: + logger.warning( + "Received another txn %s from %s while still processing %s", + transaction_id, + origin, + current_transaction, + ) + return 429, { + "errcode": Codes.UNKNOWN, + "error": "Too many concurrent transactions", + } + + # CRITICAL SECTION: we must now not await until we populate _active_transactions + # in _on_incoming_transaction_inner. + # We wrap in a ResponseCache so that we de-duplicate retried # transactions. return await self._transaction_resp_cache.wrap( @@ -182,26 +210,18 @@ class FederationServer(FederationBase): async def _on_incoming_transaction_inner( self, origin: str, transaction: Transaction, request_time: int ) -> Tuple[int, Dict[str, Any]]: - # Use a linearizer to ensure that transactions from a remote are - # processed in order. - with await self._transaction_linearizer.queue(origin): - # We rate limit here *after* we've queued up the incoming requests, - # so that we don't fill up the ratelimiter with blocked requests. - # - # This is important as the ratelimiter allows N concurrent requests - # at a time, and only starts ratelimiting if there are more requests - # than that being processed at a time. If we queued up requests in - # the linearizer/response cache *after* the ratelimiting then those - # queued up requests would count as part of the allowed limit of N - # concurrent requests. - with self._federation_ratelimiter.ratelimit(origin) as d: - await d - - result = await self._handle_incoming_transaction( - origin, transaction, request_time - ) + # CRITICAL SECTION: the first thing we must do (before awaiting) is + # add an entry to _active_transactions. + assert origin not in self._active_transactions + self._active_transactions[origin] = transaction.transaction_id # type: ignore - return result + try: + result = await self._handle_incoming_transaction( + origin, transaction, request_time + ) + return result + finally: + del self._active_transactions[origin] async def _handle_incoming_transaction( self, origin: str, transaction: Transaction, request_time: int @@ -227,19 +247,6 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore - # Reject if PDU count > 50 or EDU count > 100 - if len(transaction.pdus) > 50 or ( # type: ignore - hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore - ): - - logger.info("Transaction PDU or EDU count too large. Returning 400") - - response = {} - await self.transaction_actions.set_response( - origin, transaction, 400, response - ) - return 400, response - # We process PDUs and EDUs in parallel. This is important as we don't # want to block things like to device messages from reaching clients # behind the potentially expensive handling of PDUs. -- cgit 1.5.1 From 55da8df0782f80c43d127cd563cfbb89106319db Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 12 Mar 2021 11:37:57 -0500 Subject: Fix additional type hints from Twisted 21.2.0. (#9591) --- changelog.d/9591.misc | 1 + synapse/api/auth.py | 2 +- synapse/federation/federation_server.py | 8 +- synapse/handlers/oidc_handler.py | 9 ++- synapse/http/client.py | 9 ++- synapse/logging/_remote.py | 23 ++++-- synapse/push/emailpusher.py | 4 +- synapse/replication/tcp/handler.py | 44 ++++++----- synapse/replication/tcp/protocol.py | 24 +++--- synapse/replication/tcp/redis.py | 8 +- synapse/rest/admin/_base.py | 15 ++-- synapse/rest/admin/media.py | 29 ++++--- synapse/rest/client/v2_alpha/groups.py | 105 ++++++++++++++++++-------- synapse/rest/media/v1/config_resource.py | 3 +- synapse/rest/media/v1/preview_url_resource.py | 3 +- synapse/rest/media/v1/upload_resource.py | 3 +- synapse/server.py | 8 +- tests/replication/test_federation_ack.py | 8 +- 18 files changed, 187 insertions(+), 119 deletions(-) create mode 100644 changelog.d/9591.misc diff --git a/changelog.d/9591.misc b/changelog.d/9591.misc new file mode 100644 index 0000000000..14c7b78dd9 --- /dev/null +++ b/changelog.d/9591.misc @@ -0,0 +1 @@ +Fix incorrect type hints. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 968cf6f174..e10e33fd23 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -164,7 +164,7 @@ class Auth: async def get_user_by_req( self, - request: Request, + request: SynapseRequest, allow_guest: bool = False, rights: str = "access", allow_expired: bool = False, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index db6e49dbca..9839d3d016 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -880,7 +880,9 @@ class FederationHandlerRegistry: self.edu_handlers = ( {} ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] - self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] + self.query_handlers = ( + {} + ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]] # Map from type to instance names that we should route EDU handling to. # We randomly choose one instance from the list to route to for each new @@ -914,7 +916,7 @@ class FederationHandlerRegistry: self.edu_handlers[edu_type] = handler def register_query_handler( - self, query_type: str, handler: Callable[[dict], defer.Deferred] + self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] ): """Sets the handler callable that will be used to handle an incoming federation query of the given type. @@ -987,7 +989,7 @@ class FederationHandlerRegistry: # Oh well, let's just log and move on. logger.warning("No handler registered for EDU type %s", edu_type) - async def on_query(self, query_type: str, args: dict): + async def on_query(self, query_type: str, args: dict) -> JsonDict: handler = self.query_handlers.get(query_type) if handler: return await handler(args) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 825fadb76f..f5d1821127 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -34,6 +34,7 @@ from pymacaroons.exceptions import ( from typing_extensions import TypedDict from twisted.web.client import readBody +from twisted.web.http_headers import Headers from synapse.config import ConfigError from synapse.config.oidc_config import ( @@ -538,7 +539,7 @@ class OidcProvider: """ metadata = await self.load_metadata() token_endpoint = metadata.get("token_endpoint") - headers = { + raw_headers = { "Content-Type": "application/x-www-form-urlencoded", "User-Agent": self._http_client.user_agent, "Accept": "application/json", @@ -552,10 +553,10 @@ class OidcProvider: body = urlencode(args, True) # Fill the body/headers with credentials - uri, headers, body = self._client_auth.prepare( - method="POST", uri=token_endpoint, headers=headers, body=body + uri, raw_headers, body = self._client_auth.prepare( + method="POST", uri=token_endpoint, headers=raw_headers, body=body ) - headers = {k: [v] for (k, v) in headers.items()} + headers = Headers({k: [v] for (k, v) in raw_headers.items()}) # Do the actual request # We're not using the SimpleHttpClient util methods as we don't want to diff --git a/synapse/http/client.py b/synapse/http/client.py index 8f3da486b3..d4ab3a2732 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -57,7 +57,13 @@ from twisted.web.client import ( ) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse +from twisted.web.iweb import ( + UNKNOWN_LENGTH, + IAgent, + IBodyProducer, + IPolicyForHTTPS, + IResponse, +) from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri @@ -870,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by return query_str.encode("utf8") +@implementer(IPolicyForHTTPS) class InsecureInterceptableContextFactory(ssl.ContextFactory): """ Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain. diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 174ca7be5a..643492ceaf 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -32,8 +32,9 @@ from twisted.internet.endpoints import ( TCP4ClientEndpoint, TCP6ClientEndpoint, ) -from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport +from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint from twisted.internet.protocol import Factory, Protocol +from twisted.internet.tcp import Connection from twisted.python.failure import Failure logger = logging.getLogger(__name__) @@ -52,7 +53,9 @@ class LogProducer: format: A callable to format the log record to a string. """ - transport = attr.ib(type=ITransport) + # This is essentially ITCPTransport, but that is missing certain fields + # (connected and registerProducer) which are part of the implementation. + transport = attr.ib(type=Connection) _format = attr.ib(type=Callable[[logging.LogRecord], str]) _buffer = attr.ib(type=deque) _paused = attr.ib(default=False, type=bool, init=False) @@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler): if self._connection_waiter: return - self._connection_waiter = self._service.whenConnected(failAfterFailures=1) - def fail(failure: Failure) -> None: # If the Deferred was cancelled (e.g. during shutdown) do not try to # reconnect (this will cause an infinite loop of errors). @@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler): self._connect() def writer(result: Protocol) -> None: + # Force recognising transport as a Connection and not the more + # generic ITransport. + transport = result.transport # type: Connection # type: ignore + # We have a connection. If we already have a producer, and its # transport is the same, just trigger a resumeProducing. - if self._producer and result.transport is self._producer.transport: + if self._producer and transport is self._producer.transport: self._producer.resumeProducing() self._connection_waiter = None return @@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler): # Make a new producer and start it. self._producer = LogProducer( buffer=self._buffer, - transport=result.transport, + transport=transport, format=self.format, ) - result.transport.registerProducer(self._producer, True) + transport.registerProducer(self._producer, True) self._producer.resumeProducing() self._connection_waiter = None - self._connection_waiter.addCallbacks(writer, fail) + deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred + deferred.addCallbacks(writer, fail) + self._connection_waiter = deferred def _handle_pressure(self) -> None: """ diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 5fec2aaf5d..3dc06a79e8 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -16,8 +16,8 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional -from twisted.internet.base import DelayedCall from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from twisted.internet.interfaces import IDelayedCall from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, ThrottleParams @@ -66,7 +66,7 @@ class EmailPusher(Pusher): self.store = self.hs.get_datastore() self.email = pusher_config.pushkey - self.timed_call = None # type: Optional[DelayedCall] + self.timed_call = None # type: Optional[IDelayedCall] self.throttle_params = {} # type: Dict[str, ThrottleParams] self._inited = False diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index a7245da152..ee909f3fc5 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import ( UserIpCommand, UserSyncCommand, ) -from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.protocol import IReplicationConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, AccountDataStream, @@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache" # the type of the entries in _command_queues_by_stream _StreamCommandQueue = Deque[ - Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] + Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection] ] @@ -174,7 +174,7 @@ class ReplicationCommandHandler: # The currently connected connections. (The list of places we need to send # outgoing replication commands to.) - self._connections = [] # type: List[AbstractConnection] + self._connections = [] # type: List[IReplicationConnection] LaterGauge( "synapse_replication_tcp_resource_total_connections", @@ -197,7 +197,7 @@ class ReplicationCommandHandler: # For each connection, the incoming stream names that have received a POSITION # from that connection. - self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]] LaterGauge( "synapse_replication_tcp_command_queue", @@ -220,7 +220,7 @@ class ReplicationCommandHandler: self._server_notices_sender = hs.get_server_notices_sender() def _add_command_to_stream_queue( - self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] ) -> None: """Queue the given received command for processing @@ -267,7 +267,7 @@ class ReplicationCommandHandler: async def _process_command( self, cmd: Union[PositionCommand, RdataCommand], - conn: AbstractConnection, + conn: IReplicationConnection, stream_name: str, ) -> None: if isinstance(cmd, PositionCommand): @@ -321,10 +321,10 @@ class ReplicationCommandHandler: """Get a list of streams that this instances replicates.""" return self._streams_to_replicate - def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand): self.send_positions_to_connection(conn) - def send_positions_to_connection(self, conn: AbstractConnection): + def send_positions_to_connection(self, conn: IReplicationConnection): """Send current position of all streams this process is source of to the connection. """ @@ -347,7 +347,7 @@ class ReplicationCommandHandler: ) def on_USER_SYNC( - self, conn: AbstractConnection, cmd: UserSyncCommand + self, conn: IReplicationConnection, cmd: UserSyncCommand ) -> Optional[Awaitable[None]]: user_sync_counter.inc() @@ -359,21 +359,23 @@ class ReplicationCommandHandler: return None def on_CLEAR_USER_SYNC( - self, conn: AbstractConnection, cmd: ClearUserSyncsCommand + self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand ) -> Optional[Awaitable[None]]: if self._is_master: return self._presence_handler.update_external_syncs_clear(cmd.instance_id) else: return None - def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): + def on_FEDERATION_ACK( + self, conn: IReplicationConnection, cmd: FederationAckCommand + ): federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.instance_name, cmd.token) def on_USER_IP( - self, conn: AbstractConnection, cmd: UserIpCommand + self, conn: IReplicationConnection, cmd: UserIpCommand ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() @@ -395,7 +397,7 @@ class ReplicationCommandHandler: assert self._server_notices_sender is not None await self._server_notices_sender.on_user_ip(cmd.user_id) - def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -412,7 +414,7 @@ class ReplicationCommandHandler: self._add_command_to_stream_queue(conn, cmd) async def _process_rdata( - self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand ) -> None: """Process an RDATA command @@ -486,7 +488,7 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return @@ -496,7 +498,7 @@ class ReplicationCommandHandler: self._add_command_to_stream_queue(conn, cmd) async def _process_position( - self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand ) -> None: """Process a POSITION command @@ -553,7 +555,9 @@ class ReplicationCommandHandler: self._streams_by_connection.setdefault(conn, set()).add(stream_name) - def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): + def on_REMOTE_SERVER_UP( + self, conn: IReplicationConnection, cmd: RemoteServerUpCommand + ): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -576,7 +580,7 @@ class ReplicationCommandHandler: # between two instances, but that is not currently supported). self.send_command(cmd, ignore_conn=conn) - def new_connection(self, connection: AbstractConnection): + def new_connection(self, connection: IReplicationConnection): """Called when we have a new connection.""" self._connections.append(connection) @@ -603,7 +607,7 @@ class ReplicationCommandHandler: UserSyncCommand(self._instance_id, user_id, True, now) ) - def lost_connection(self, connection: AbstractConnection): + def lost_connection(self, connection: IReplicationConnection): """Called when a connection is closed/lost.""" # we no longer need _streams_by_connection for this connection. streams = self._streams_by_connection.pop(connection, None) @@ -624,7 +628,7 @@ class ReplicationCommandHandler: return bool(self._connections) def send_command( - self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None + self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None ): """Send a command to all connected connections. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index e0b4ad314d..8e4734b59c 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct @@ -54,6 +53,7 @@ from inspect import isawaitable from typing import TYPE_CHECKING, List, Optional from prometheus_client import Counter +from zope.interface import Interface, implementer from twisted.internet import task from twisted.protocols.basic import LineOnlyReceiver @@ -121,6 +121,14 @@ class ConnectionStates: CLOSED = "closed" +class IReplicationConnection(Interface): + """An interface for replication connections.""" + + def send_command(cmd: Command): + """Send the command down the connection""" + + +@implementer(IReplicationConnection) class BaseReplicationStreamProtocol(LineOnlyReceiver): """Base replication protocol shared between client and server. @@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ReplicateCommand()) -class AbstractConnection(abc.ABC): - """An interface for replication connections.""" - - @abc.abstractmethod - def send_command(self, cmd: Command): - """Send the command down the connection""" - pass - - -# This tells python that `BaseReplicationStreamProtocol` implements the -# interface. -AbstractConnection.register(BaseReplicationStreamProtocol) - - # The following simply registers metrics for the replication connections pending_commands = LaterGauge( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 574eaea1eb..7cccde097d 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast import attr import txredisapi +from zope.interface import implementer from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.interfaces import IAddress, IConnector @@ -36,7 +37,7 @@ from synapse.replication.tcp.commands import ( parse_command_from_line, ) from synapse.replication.tcp.protocol import ( - AbstractConnection, + IReplicationConnection, tcp_inbound_commands_counter, tcp_outbound_commands_counter, ) @@ -66,7 +67,8 @@ class ConstantProperty(Generic[T, V]): pass -class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): +@implementer(IReplicationConnection) +class RedisSubscriber(txredisapi.SubscriberProtocol): """Connection to redis subscribed to replication stream. This class fulfils two functions: @@ -75,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): connection, parsing *incoming* messages into replication commands, and passing them to `ReplicationCommandHandler` - (b) it implements the AbstractConnection API, where it sends *outgoing* commands + (b) it implements the IReplicationConnection API, where it sends *outgoing* commands onto outbound_redis_connection. Due to the vagaries of `txredisapi` we don't want to have a custom diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index e09234c644..7681e55b58 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -15,10 +15,9 @@ import re -import twisted.web.server - -import synapse.api.auth +from synapse.api.auth import Auth from synapse.api.errors import AuthError +from synapse.http.site import SynapseRequest from synapse.types import UserID @@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"): return patterns -async def assert_requester_is_admin( - auth: synapse.api.auth.Auth, request: twisted.web.server.Request -) -> None: +async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None: """Verify that the requester is an admin user Args: - auth: api.auth.Auth singleton + auth: Auth singleton request: incoming request Raises: @@ -53,11 +50,11 @@ async def assert_requester_is_admin( await assert_user_is_admin(auth, requester.user) -async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None: +async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: """Verify that the given user is an admin user Args: - auth: api.auth.Auth singleton + auth: Auth singleton user_id: user to check Raises: diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 511c859f64..7fcc48a9d7 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,10 +17,9 @@ import logging from typing import TYPE_CHECKING, Tuple -from twisted.web.server import Request - from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_boolean, parse_integer +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, @@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet): self.auth = hs.get_auth() async def on_POST( - self, request: Request, server_name: str, media_id: str + self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, media_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: @@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet): self.media_repository = hs.get_media_repository() async def on_DELETE( - self, request: Request, server_name: str, media_id: str + self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet): self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, server_name: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 7aea4cebf5..5901432fad 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -32,6 +32,7 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.types import GroupID, JsonDict from ._base import client_patterns @@ -70,7 +71,9 @@ class GroupServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -81,7 +84,9 @@ class GroupServlet(RestServlet): return 200, group_description @_validate_group_id - async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, category_id: Optional[str], room_id: str + self, + request: SynapseRequest, + group_id: str, + category_id: Optional[str], + room_id: str, ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, category_id: str, room_id: str + self, request: SynapseRequest, group_id: str, category_id: str, room_id: str ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet): @_validate_group_id async def on_GET( - self, request: Request, group_id: str, category_id: str + self, request: SynapseRequest, group_id: str, category_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, category_id: str + self, request: SynapseRequest, group_id: str, category_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet): @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, category_id: str + self, request: SynapseRequest, group_id: str, category_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet): @_validate_group_id async def on_GET( - self, request: Request, group_id: str, role_id: str + self, request: SynapseRequest, group_id: str, role_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, role_id: str + self, request: SynapseRequest, group_id: str, role_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet): @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, role_id: str + self, request: SynapseRequest, group_id: str, role_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, role_id: Optional[str], user_id: str + self, + request: SynapseRequest, + group_id: str, + role_id: Optional[str], + user_id: str, ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, role_id: str, user_id: str + self, request: SynapseRequest, group_id: str, role_id: str, user_id: str ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, room_id: str + self, request: SynapseRequest, group_id: str, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet): @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, room_id: str + self, request: SynapseRequest, group_id: str, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, room_id: str, config_key: str + self, request: SynapseRequest, group_id: str, room_id: str, config_key: str ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet): self.is_mine_id = hs.is_mine_id @_validate_group_id - async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id, user_id + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id, user_id + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.store = hs.get_datastore() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) result = await self.groups_handler.get_publicised_groups_for_user(user_id) @@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 9039662f7e..1eff98ef14 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING from twisted.web.server import Request from synapse.http.server import DirectServeJsonResource, respond_with_json +from synapse.http.site import SynapseRequest if TYPE_CHECKING: from synapse.app.homeserver import HomeServer @@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource): self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a074e807dc..b8895aeaa9 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -39,6 +39,7 @@ from synapse.http.server import ( respond_with_json_bytes, ) from synapse.http.servlet import parse_integer, parse_string +from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers @@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource): request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 5e104fac40..ae5aef2f7f 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -22,6 +22,7 @@ from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest from synapse.rest.media.v1.media_storage import SpamMediaException if TYPE_CHECKING: @@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource): async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point diff --git a/synapse/server.py b/synapse/server.py index 369cc88026..48ac87a124 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_http_client_context_factory(self) -> IPolicyForHTTPS: - return ( - InsecureInterceptableContextFactory() - if self.config.use_insecure_ssl_client_just_for_testing_do_not_use - else RegularPolicyForHTTPS() - ) + if self.config.use_insecure_ssl_client_just_for_testing_do_not_use: + return InsecureInterceptableContextFactory() + return RegularPolicyForHTTPS() @cache_in_self def get_simple_http_client(self) -> SimpleHttpClient: diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index f235f1bd83..0d9e3bb11d 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -17,7 +17,7 @@ import mock from synapse.app.generic_worker import GenericWorkerServer from synapse.replication.tcp.commands import FederationAckCommand -from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.protocol import IReplicationConnection from synapse.replication.tcp.streams.federation import FederationStream from tests.unittest import HomeserverTestCase @@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase): """ rch = self.hs.get_tcp_replication() - # wire up the ReplicationCommandHandler to a mock connection - mock_connection = mock.Mock(spec=AbstractConnection) + # wire up the ReplicationCommandHandler to a mock connection, which needs + # to implement IReplicationConnection. (Note that Mock doesn't understand + # interfaces, but casing an interface to a list gives the attributes.) + mock_connection = mock.Mock(spec=list(IReplicationConnection)) rch.new_connection(mock_connection) # tell it it received an RDATA row -- cgit 1.5.1 From af2248f8bf1cf11a230577650e84885387f1920f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 15 Mar 2021 13:51:02 +0000 Subject: Optimise missing prev_event handling (#9601) Background: When we receive incoming federation traffic, and notice that we are missing prev_events from the incoming traffic, first we do a `/get_missing_events` request, and then if we still have missing prev_events, we set up new backwards-extremities. To do that, we need to make a `/state_ids` request to ask the remote server for the state at those prev_events, and then we may need to then ask the remote server for any events in that state which we don't already have, as well as the auth events for those missing state events, so that we can auth them. This PR attempts to optimise the processing of that state request. The `state_ids` API returns a list of the state events, as well as a list of all the auth events for *all* of those state events. The optimisation comes from the observation that we are currently loading all of those auth events into memory at the start of the operation, but we almost certainly aren't going to need *all* of the auth events. Rather, we can check that we have them, and leave the actual load into memory for later. (Ideally the federation API would tell us which auth events we're actually going to need, but it doesn't.) The effect of this is to reduce the number of events that I need to load for an event in Matrix HQ from about 60000 to about 22000, which means it can stay in my in-memory cache, whereas previously the sheer number of events meant that all 60K events had to be loaded from db for each request, due to the amount of cache churn. (NB I've already tripled the size of the cache from its default of 10K). Unfortunately I've ended up basically C&Ping `_get_state_for_room` and `_get_events_from_store_or_dest` into a new method, because `_get_state_for_room` is also called during backfill, which expects the auth events to be returned, so the same tricks don't work. That said, I don't really know why that codepath is completely different (ultimately we're doing the same thing in setting up a new backwards extremity) so I've left a TODO suggesting that we clean it up. --- changelog.d/9601.feature | 1 + synapse/handlers/federation.py | 152 ++++++++++++++++++++---- synapse/storage/databases/main/events_worker.py | 12 +- 3 files changed, 137 insertions(+), 28 deletions(-) create mode 100644 changelog.d/9601.feature diff --git a/changelog.d/9601.feature b/changelog.d/9601.feature new file mode 100644 index 0000000000..5078d63ffa --- /dev/null +++ b/changelog.d/9601.feature @@ -0,0 +1 @@ +Optimise handling of incomplete room history for incoming federation. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1d20c441f3..598a66f74c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -353,17 +353,16 @@ class FederationHandler(BaseHandler): # Ask the remote server for the states we don't # know about for p in prevs - seen: - logger.info( - "Requesting state at missing prev_event %s", - event_id, - ) + logger.info("Requesting state after missing prev_event %s", p) with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - (remote_state, _,) = await self._get_state_for_room( - origin, room_id, p, include_event_in_state=True + remote_state = ( + await self._get_state_after_missing_prev_event( + origin, room_id, p + ) ) remote_state_map = { @@ -539,7 +538,6 @@ class FederationHandler(BaseHandler): destination: str, room_id: str, event_id: str, - include_event_in_state: bool = False, ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. @@ -547,11 +545,9 @@ class FederationHandler(BaseHandler): destination: The remote homeserver to query for the state. room_id: The id of the room we're interested in. event_id: The id of the event we want the state at. - include_event_in_state: if true, the event itself will be included in the - returned state event list. Returns: - A list of events in the state, possibly including the event itself, and + A list of events in the state, not including the event itself, and a list of events in the auth chain for the given event. """ ( @@ -563,9 +559,6 @@ class FederationHandler(BaseHandler): desired_events = set(state_event_ids + auth_event_ids) - if include_event_in_state: - desired_events.add(event_id) - event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -582,13 +575,6 @@ class FederationHandler(BaseHandler): event_map[e_id] for e_id in state_event_ids if e_id in event_map ] - if include_event_in_state: - remote_event = event_map.get(event_id) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] auth_chain.sort(key=lambda e: e.depth) @@ -662,6 +648,131 @@ class FederationHandler(BaseHandler): return fetched_events + async def _get_state_after_missing_prev_event( + self, + destination: str, + room_id: str, + event_id: str, + ) -> List[EventBase]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + + Returns: + A list of events in the state, including the event itself + """ + # TODO: This function is basically the same as _get_state_for_room. Can + # we make backfill() use it, rather than having two code paths? I think the + # only difference is that backfill() persists the prev events separately. + + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + logger.debug( + "state_ids returned %i state events, %i auth events", + len(state_event_ids), + len(auth_event_ids), + ) + + # start by just trying to fetch the events from the store + desired_events = set(state_event_ids) + desired_events.add(event_id) + logger.debug("Fetching %i events from cache/store", len(desired_events)) + fetched_events = await self.store.get_events( + desired_events, allow_rejected=True + ) + + missing_desired_events = desired_events - fetched_events.keys() + logger.debug( + "We are missing %i events (got %i)", + len(missing_desired_events), + len(fetched_events), + ) + + # We probably won't need most of the auth events, so let's just check which + # we have for now, rather than thrashing the event cache with them all + # unnecessarily. + + # TODO: we probably won't actually need all of the auth events, since we + # already have a bunch of the state events. It would be nice if the + # federation api gave us a way of finding out which we actually need. + + missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events.difference_update( + await self.store.have_seen_events(missing_auth_events) + ) + logger.debug("We are also missing %i auth events", len(missing_auth_events)) + + missing_events = missing_desired_events | missing_auth_events + logger.debug("Fetching %i events from remote", len(missing_events)) + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (await self.store.get_events(missing_desired_events, allow_rejected=True)) + ) + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + # if we couldn't get the prev event in question, that's a problem. + remote_event = fetched_events.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + + # missing state at that event is a warning, not a blocker + # XXX: this doesn't sound right? it means that we'll end up with incomplete + # state. + failed_to_fetch = desired_events - fetched_events.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events + ] + + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + return remote_state + async def _process_received_pdu( self, origin: str, @@ -841,7 +952,6 @@ class FederationHandler(BaseHandler): destination=dest, room_id=room_id, event_id=e_id, - include_event_in_state=False, ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index edbe42f2bf..c04e162ccc 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools + import logging import threading from collections import namedtuple @@ -1044,7 +1044,8 @@ class EventsWorkerStore(SQLBaseStore): Returns: set[str]: The events we have already seen. """ - results = set() + # if the event cache contains the event, obviously we've seen it. + results = {x for x in event_ids if self._get_event_cache.contains(x)} def have_seen_events_txn(txn, chunk): sql = "SELECT event_id FROM events as e WHERE " @@ -1052,12 +1053,9 @@ class EventsWorkerStore(SQLBaseStore): txn.database_engine, "e.event_id", chunk ) txn.execute(sql + clause, args) - for (event_id,) in txn: - results.add(event_id) + results.update(row[0] for row in txn) - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + for chunk in batch_iter((x for x in event_ids if x not in results), 100): await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) -- cgit 1.5.1 From 026503fa3b90c03996a64be95432e345434b4a82 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Mar 2021 14:42:40 +0000 Subject: Don't go into federation catch up mode so easily (#9561) Federation catch up mode is very inefficient if the number of events that the remote server has missed is small, since handling gaps can be very expensive, c.f. #9492. Instead of going into catch up mode whenever we see an error, we instead do so only if we've backed off from trying the remote for more than an hour (the assumption being that in such a case it is more than a transient failure). --- changelog.d/9561.misc | 1 + synapse/federation/sender/per_destination_queue.py | 287 ++++++++++++--------- synapse/federation/sender/transaction_manager.py | 48 +--- synapse/storage/databases/main/transactions.py | 10 +- tests/federation/test_federation_catch_up.py | 3 +- 5 files changed, 190 insertions(+), 159 deletions(-) create mode 100644 changelog.d/9561.misc diff --git a/changelog.d/9561.misc b/changelog.d/9561.misc new file mode 100644 index 0000000000..6c529a82ee --- /dev/null +++ b/changelog.d/9561.misc @@ -0,0 +1 @@ +Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index deb519f3ef..cc0d765e5f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -17,6 +17,7 @@ import datetime import logging from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast +import attr from prometheus_client import Counter from synapse.api.errors import ( @@ -93,6 +94,10 @@ class PerDestinationQueue: self._destination = destination self.transmission_loop_running = False + # Flag to signal to any running transmission loop that there is new data + # queued up to be sent. + self._new_data_to_send = False + # True whilst we are sending events that the remote homeserver missed # because it was unreachable. We start in this state so we can perform # catch-up at startup. @@ -108,7 +113,7 @@ class PerDestinationQueue: # destination (we are the only updater so this is safe) self._last_successful_stream_ordering = None # type: Optional[int] - # a list of pending PDUs + # a queue of pending PDUs self._pending_pdus = [] # type: List[EventBase] # XXX this is never actually used: see @@ -208,6 +213,10 @@ class PerDestinationQueue: transaction in the background. """ + # Mark that we (may) have new things to send, so that any running + # transmission loop will recheck whether there is stuff to send. + self._new_data_to_send = True + if self.transmission_loop_running: # XXX: this can get stuck on by a never-ending # request at which point pending_pdus just keeps growing. @@ -250,125 +259,41 @@ class PerDestinationQueue: pending_pdus = [] while True: - # We have to keep 2 free slots for presence and rr_edus - limit = MAX_EDUS_PER_TRANSACTION - 2 - - device_update_edus, dev_list_id = await self._get_device_update_edus( - limit - ) - - limit -= len(device_update_edus) - - ( - to_device_edus, - device_stream_id, - ) = await self._get_to_device_message_edus(limit) - - pending_edus = device_update_edus + to_device_edus - - # BEGIN CRITICAL SECTION - # - # In order to avoid a race condition, we need to make sure that - # the following code (from popping the queues up to the point - # where we decide if we actually have any pending messages) is - # atomic - otherwise new PDUs or EDUs might arrive in the - # meantime, but not get sent because we hold the - # transmission_loop_running flag. - - pending_pdus = self._pending_pdus + self._new_data_to_send = False - # We can only include at most 50 PDUs per transactions - pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:] + async with _TransactionQueueManager(self) as ( + pending_pdus, + pending_edus, + ): + if not pending_pdus and not pending_edus: + logger.debug("TX [%s] Nothing to send", self._destination) + + # If we've gotten told about new things to send during + # checking for things to send, we try looking again. + # Otherwise new PDUs or EDUs might arrive in the meantime, + # but not get sent because we hold the + # `transmission_loop_running` flag. + if self._new_data_to_send: + continue + else: + return - pending_edus.extend(self._get_rr_edus(force_flush=False)) - pending_presence = self._pending_presence - self._pending_presence = {} - if pending_presence: - pending_edus.append( - Edu( - origin=self._server_name, - destination=self._destination, - edu_type="m.presence", - content={ - "push": [ - format_user_presence_state( - presence, self._clock.time_msec() - ) - for presence in pending_presence.values() - ] - }, + if pending_pdus: + logger.debug( + "TX [%s] len(pending_pdus_by_dest[dest]) = %d", + self._destination, + len(pending_pdus), ) - ) - pending_edus.extend( - self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) - ) - while ( - len(pending_edus) < MAX_EDUS_PER_TRANSACTION - and self._pending_edus_keyed - ): - _, val = self._pending_edus_keyed.popitem() - pending_edus.append(val) - - if pending_pdus: - logger.debug( - "TX [%s] len(pending_pdus_by_dest[dest]) = %d", - self._destination, - len(pending_pdus), + await self._transaction_manager.send_new_transaction( + self._destination, pending_pdus, pending_edus ) - if not pending_pdus and not pending_edus: - logger.debug("TX [%s] Nothing to send", self._destination) - self._last_device_stream_id = device_stream_id - return - - # if we've decided to send a transaction anyway, and we have room, we - # may as well send any pending RRs - if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: - pending_edus.extend(self._get_rr_edus(force_flush=True)) - - # END CRITICAL SECTION - - success = await self._transaction_manager.send_new_transaction( - self._destination, pending_pdus, pending_edus - ) - if success: sent_transactions_counter.inc() sent_edus_counter.inc(len(pending_edus)) for edu in pending_edus: sent_edus_by_type.labels(edu.edu_type).inc() - # Remove the acknowledged device messages from the database - # Only bother if we actually sent some device messages - if to_device_edus: - await self._store.delete_device_msgs_for_remote( - self._destination, device_stream_id - ) - # also mark the device updates as sent - if device_update_edus: - logger.info( - "Marking as sent %r %r", self._destination, dev_list_id - ) - await self._store.mark_as_sent_devices_by_remote( - self._destination, dev_list_id - ) - - self._last_device_stream_id = device_stream_id - self._last_device_list_stream_id = dev_list_id - - if pending_pdus: - # we sent some PDUs and it was successful, so update our - # last_successful_stream_ordering in the destinations table. - final_pdu = pending_pdus[-1] - last_successful_stream_ordering = ( - final_pdu.internal_metadata.stream_ordering - ) - assert last_successful_stream_ordering - await self._store.set_destination_last_successful_stream_ordering( - self._destination, last_successful_stream_ordering - ) - else: - break except NotRetryingDestination as e: logger.debug( "TX [%s] not ready for retry yet (next retry at %s) - " @@ -401,7 +326,7 @@ class PerDestinationQueue: self._pending_presence = {} self._pending_rrs = {} - self._start_catching_up() + self._start_catching_up() except FederationDeniedError as e: logger.info(e) except HttpResponseException as e: @@ -412,7 +337,6 @@ class PerDestinationQueue: e, ) - self._start_catching_up() except RequestSendFailed as e: logger.warning( "TX [%s] Failed to send transaction: %s", self._destination, e @@ -422,16 +346,12 @@ class PerDestinationQueue: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) - - self._start_catching_up() except Exception: logger.exception("TX [%s] Failed to send transaction", self._destination) for p in pending_pdus: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) - - self._start_catching_up() finally: # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False @@ -499,13 +419,10 @@ class PerDestinationQueue: rooms = [p.room_id for p in catchup_pdus] logger.info("Catching up rooms to %s: %r", self._destination, rooms) - success = await self._transaction_manager.send_new_transaction( + await self._transaction_manager.send_new_transaction( self._destination, catchup_pdus, [] ) - if not success: - return - sent_transactions_counter.inc() final_pdu = catchup_pdus[-1] self._last_successful_stream_ordering = cast( @@ -584,3 +501,135 @@ class PerDestinationQueue: """ self._catching_up = True self._pending_pdus = [] + + +@attr.s(slots=True) +class _TransactionQueueManager: + """A helper async context manager for pulling stuff off the queues and + tracking what was last successfully sent, etc. + """ + + queue = attr.ib(type=PerDestinationQueue) + + _device_stream_id = attr.ib(type=Optional[int], default=None) + _device_list_id = attr.ib(type=Optional[int], default=None) + _last_stream_ordering = attr.ib(type=Optional[int], default=None) + _pdus = attr.ib(type=List[EventBase], factory=list) + + async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: + # First we calculate the EDUs we want to send, if any. + + # We start by fetching device related EDUs, i.e device updates and to + # device messages. We have to keep 2 free slots for presence and rr_edus. + limit = MAX_EDUS_PER_TRANSACTION - 2 + + device_update_edus, dev_list_id = await self.queue._get_device_update_edus( + limit + ) + + if device_update_edus: + self._device_list_id = dev_list_id + else: + self.queue._last_device_list_stream_id = dev_list_id + + limit -= len(device_update_edus) + + ( + to_device_edus, + device_stream_id, + ) = await self.queue._get_to_device_message_edus(limit) + + if to_device_edus: + self._device_stream_id = device_stream_id + else: + self.queue._last_device_stream_id = device_stream_id + + pending_edus = device_update_edus + to_device_edus + + # Now add the read receipt EDU. + pending_edus.extend(self.queue._get_rr_edus(force_flush=False)) + + # And presence EDU. + if self.queue._pending_presence: + pending_edus.append( + Edu( + origin=self.queue._server_name, + destination=self.queue._destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.queue._clock.time_msec() + ) + for presence in self.queue._pending_presence.values() + ] + }, + ) + ) + self.queue._pending_presence = {} + + # Finally add any other types of EDUs if there is room. + pending_edus.extend( + self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) + ) + while ( + len(pending_edus) < MAX_EDUS_PER_TRANSACTION + and self.queue._pending_edus_keyed + ): + _, val = self.queue._pending_edus_keyed.popitem() + pending_edus.append(val) + + # Now we look for any PDUs to send, by getting up to 50 PDUs from the + # queue + self._pdus = self.queue._pending_pdus[:50] + + if not self._pdus and not pending_edus: + return [], [] + + # if we've decided to send a transaction anyway, and we have room, we + # may as well send any pending RRs + if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: + pending_edus.extend(self.queue._get_rr_edus(force_flush=True)) + + if self._pdus: + self._last_stream_ordering = self._pdus[ + -1 + ].internal_metadata.stream_ordering + assert self._last_stream_ordering + + return self._pdus, pending_edus + + async def __aexit__(self, exc_type, exc, tb): + if exc_type is not None: + # Failed to send transaction, so we bail out. + return + + # Successfully sent transactions, so we remove pending PDUs from the queue + if self._pdus: + self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :] + + # Succeeded to send the transaction so we record where we have sent up + # to in the various streams + + if self._device_stream_id: + await self.queue._store.delete_device_msgs_for_remote( + self.queue._destination, self._device_stream_id + ) + self.queue._last_device_stream_id = self._device_stream_id + + # also mark the device updates as sent + if self._device_list_id: + logger.info( + "Marking as sent %r %r", self.queue._destination, self._device_list_id + ) + await self.queue._store.mark_as_sent_devices_by_remote( + self.queue._destination, self._device_list_id + ) + self.queue._last_device_list_stream_id = self._device_list_id + + if self._last_stream_ordering: + # we sent some PDUs and it was successful, so update our + # last_successful_stream_ordering in the destinations table. + await self.queue._store.set_destination_last_successful_stream_ordering( + self.queue._destination, self._last_stream_ordering + ) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 2a9cd063c4..07b740c2f2 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -69,15 +69,12 @@ class TransactionManager: destination: str, pdus: List[EventBase], edus: List[Edu], - ) -> bool: + ) -> None: """ Args: destination: The destination to send to (e.g. 'example.org') pdus: In-order list of PDUs to send edus: List of EDUs to send - - Returns: - True iff the transaction was successful """ # Make a transaction-sending opentracing span. This span follows on from @@ -96,8 +93,6 @@ class TransactionManager: edu.strip_context() with start_active_span_follows_from("send_transaction", span_contexts): - success = True - logger.debug("TX [%s] _attempt_new_transaction", destination) txn_id = str(self._next_txn_id) @@ -152,44 +147,29 @@ class TransactionManager: response = await self._transport_layer.send_transaction( transaction, json_data_cb ) - code = 200 except HttpResponseException as e: code = e.code response = e.response - if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", destination, txn_id, code - ) - raise e - - logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) - - if code == 200: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warning( - "TX [%s] {%s} Remote returned error for %s: %s", - destination, - txn_id, - e_id, - r, - ) - else: - for p in pdus: + set_tag(tags.ERROR, True) + + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) + raise + + logger.info("TX [%s] {%s} got 200 response", destination, txn_id) + + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: logger.warning( - "TX [%s] {%s} Failed to send event %s", + "TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, - p.event_id, + e_id, + r, ) - success = False - if success and pdus and destination in self._federation_metrics_domains: + if pdus and destination in self._federation_metrics_domains: last_pdu = pdus[-1] last_pdu_ts_metric.labels(server_name=destination).set( last_pdu.origin_server_ts / 1000 ) - - set_tag(tags.ERROR, not success) - return success diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index b921d63d30..0309661841 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -350,11 +350,11 @@ class TransactionStore(TransactionWorkerStore): self.db_pool.simple_upsert_many_txn( txn, - "destination_rooms", - ["destination", "room_id"], - rows, - ["stream_ordering"], - [(stream_ordering,)] * len(rows), + table="destination_rooms", + key_names=("destination", "room_id"), + key_values=rows, + value_names=["stream_ordering"], + value_values=[(stream_ordering,)] * len(rows), ) async def get_destination_last_successful_stream_ordering( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 1a3ccb263d..6f96cd7940 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable from tests.unittest import FederatingHomeserverTestCase, override_config @@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): else: data = json_cb() self.failed_pdus.extend(data["pdus"]) - raise IOError("Failed to connect because this is a test!") + raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination) def get_destination_room(self, room: str, destination: str = "host2") -> dict: """ -- cgit 1.5.1 From d29b71aa50f568c3b951b95f2ea3009292b95630 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 15 Mar 2021 11:14:39 -0400 Subject: Fix remaining mypy issues due to Twisted upgrade. (#9608) --- changelog.d/9608.misc | 1 + stubs/txredisapi.pyi | 2 +- synapse/http/client.py | 12 ++++++++-- synapse/replication/tcp/handler.py | 4 ++-- synapse/replication/tcp/protocol.py | 9 ++++++++ synapse/replication/tcp/redis.py | 2 +- tests/replication/_base.py | 44 ++++++++++++++----------------------- tests/server.py | 2 ++ 8 files changed, 42 insertions(+), 34 deletions(-) create mode 100644 changelog.d/9608.misc diff --git a/changelog.d/9608.misc b/changelog.d/9608.misc new file mode 100644 index 0000000000..14c7b78dd9 --- /dev/null +++ b/changelog.d/9608.misc @@ -0,0 +1 @@ +Fix incorrect type hints. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 34787e0b1e..080ca40287 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union from twisted.internet import protocol -class RedisProtocol: +class RedisProtocol(protocol.Protocol): def publish(self, channel: str, message: bytes): ... async def ping(self) -> None: ... async def set( diff --git a/synapse/http/client.py b/synapse/http/client.py index d4ab3a2732..1e01e0a9f2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,7 +45,9 @@ from twisted.internet.interfaces import ( IHostResolution, IReactorPluggableNameResolver, IResolutionReceiver, + ITCPTransport, ) +from twisted.internet.protocol import connectionDone from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone @@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception): class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which immediately errors upon receiving data.""" + transport = None # type: Optional[ITCPTransport] + def __init__(self, deferred: defer.Deferred): self.deferred = deferred @@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): self.deferred.errback(BodyExceededMaxSize()) # Close the connection (forcefully) since all the data will get # discarded anyway. + assert self.transport is not None self.transport.abortConnection() def dataReceived(self, data: bytes) -> None: self._maybe_fail() - def connectionLost(self, reason: Failure) -> None: + def connectionLost(self, reason: Failure = connectionDone) -> None: self._maybe_fail() class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" + transport = None # type: Optional[ITCPTransport] + def __init__( self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] ): @@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): self.deferred.errback(BodyExceededMaxSize()) # Close the connection (forcefully) since all the data will get # discarded anyway. + assert self.transport is not None self.transport.abortConnection() - def connectionLost(self, reason: Failure) -> None: + def connectionLost(self, reason: Failure = connectionDone) -> None: # If the maximum size was already exceeded, there's nothing to do. if self.deferred.called: return diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index ee909f3fc5..a8894beadf 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -302,7 +302,7 @@ class ReplicationCommandHandler: hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host, + hs.config.redis.redis_host.encode(), hs.config.redis.redis_port, self._factory, ) @@ -311,7 +311,7 @@ class ReplicationCommandHandler: self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self._factory) + hs.get_reactor().connectTCP(host.encode(), port, self._factory) def get_streams(self) -> Dict[str, Stream]: """Get a map from stream name to all streams.""" diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 8e4734b59c..825900f64c 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -56,6 +56,7 @@ from prometheus_client import Counter from zope.interface import Interface, implementer from twisted.internet import task +from twisted.internet.tcp import Connection from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): (if they send a `PING` command) """ + # The transport is going to be an ITCPTransport, but that doesn't have the + # (un)registerProducer methods, those are only on the implementation. + transport = None # type: Connection + delimiter = b"\n" # Valid commands we expect to receive @@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): connected_connections.append(self) # Register connection for metrics + assert self.transport is not None self.transport.registerProducer(self, True) # For the *Producing callbacks self._send_pending_commands() @@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): logger.info( "[%s] Failed to close connection gracefully, aborting", self.id() ) + assert self.transport is not None self.transport.abortConnection() else: if now - self.last_sent_command >= PING_TIME: @@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def close(self): logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() + assert self.transport is not None self.transport.loseConnection() self.on_connection_closed() @@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def connectionLost(self, reason): logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): + assert reason.type is not None connection_close_counter.labels(reason.type.__name__).inc() else: connection_close_counter.labels(reason.__class__.__name__).inc() diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 7cccde097d..2f4d407f94 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -365,6 +365,6 @@ def lazyConnection( factory.continueTrying = reconnect reactor = hs.get_reactor() - reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None) + reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) return factory.handler diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 20940c8107..67b7913666 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Tuple - -import attr +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol @@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Set up client side protocol client_protocol = client_factory.buildProtocol(None) - request_factory = OneShotRequestFactory() - # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor, request_factory, self.site) + channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server_to_client_transport.loseConnection() client_to_server_transport.loseConnection() - return request_factory.request + return channel.request def assert_request_is_get_repl_stream_updates( self, request: SynapseRequest, stream_name: str @@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. self.reactor.add_tcp_client_callback( - "localhost", + b"localhost", 6379, self.connect_any_redis_attempts, ) @@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up client side protocol client_protocol = client_factory.buildProtocol(None) - request_factory = OneShotRequestFactory() - # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs]) + channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs]) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): clients = self.reactor.tcpClients while clients: (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") + self.assertEqual(host, b"localhost") self.assertEqual(port, 6379) client_protocol = client_factory.buildProtocol(None) @@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler): self.received_rdata_rows.append((stream_name, token, r)) -@attr.s() -class OneShotRequestFactory: - """A simple request factory that generates a single `SynapseRequest` and - stores it for future use. Can only be used once. - """ - - request = attr.ib(default=None) - - def __call__(self, *args, **kwargs): - assert self.request is None - - self.request = SynapseRequest(*args, **kwargs) - return self.request - - class _PushHTTPChannel(HTTPChannel): """A HTTPChannel that wraps pull producers to push producers. @@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel): """ def __init__( - self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site + self, reactor: IReactorTime, request_factory: Type[Request], site: Site ): super().__init__() self.reactor = reactor @@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel): request.responseHeaders.setRawHeaders(b"connection", [b"close"]) return False + def requestDone(self, request): + # Store the request for inspection. + self.request = request + super().requestDone(request) + class _PullToPushProducer: """A push producer that wraps a pull producer.""" @@ -597,6 +581,8 @@ class FakeRedisPubSubServer: class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" + transport = None # type: Optional[FakeTransport] + def __init__(self, server: FakeRedisPubSubServer): self._server = server self._reader = hiredis.Reader() @@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol): def send(self, msg): """Send a message back to the client.""" + assert self.transport is not None + raw = self.encode(msg).encode("utf-8") self.transport.write(raw) diff --git a/tests/server.py b/tests/server.py index 863f6da738..2287d20076 100644 --- a/tests/server.py +++ b/tests/server.py @@ -16,6 +16,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IReactorTCP, IResolverSimple, + ITransport, ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock @@ -467,6 +468,7 @@ def get_clock(): return clock, hs_clock +@implementer(ITransport) @attr.s(cmp=False) class FakeTransport: """ -- cgit 1.5.1 From f87dfb94034087f5392670b6e857600102ea3f39 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 15 Mar 2021 12:18:35 -0400 Subject: Revert requiring a specific version of Twisted for mypy checks. (#9618) --- changelog.d/9618.misc | 1 + tox.ini | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) create mode 100644 changelog.d/9618.misc diff --git a/changelog.d/9618.misc b/changelog.d/9618.misc new file mode 100644 index 0000000000..14c7b78dd9 --- /dev/null +++ b/changelog.d/9618.misc @@ -0,0 +1 @@ +Fix incorrect type hints. diff --git a/tox.ini b/tox.ini index a6d10537ae..9ff70fe312 100644 --- a/tox.ini +++ b/tox.ini @@ -189,7 +189,5 @@ commands= [testenv:mypy] deps = {[base]deps} - # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513 - twisted==20.3.0 extras = all,mypy commands = mypy -- cgit 1.5.1 From 1c8a2541dae989f7c17bff8be1e0966df4bc2363 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 16 Mar 2021 10:20:20 +0000 Subject: Fix Internal Server Error on `GET /saml2/authn_response` (#9623) * Fix Internal Server Error on `GET /saml2/authn_response` Seems to have been introduced in #8765 (Synapse 1.24.0) * Fix newsfile --- changelog.d/9623.bugfix | 1 + synapse/rest/synapse/client/saml2/response_resource.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9623.bugfix diff --git a/changelog.d/9623.bugfix b/changelog.d/9623.bugfix new file mode 100644 index 0000000000..ecccb46105 --- /dev/null +++ b/changelog.d/9623.bugfix @@ -0,0 +1 @@ +Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py index f6668fb5e3..4dfadf1bfb 100644 --- a/synapse/rest/synapse/client/saml2/response_resource.py +++ b/synapse/rest/synapse/client/saml2/response_resource.py @@ -14,24 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.http.server import DirectServeHtmlResource +if TYPE_CHECKING: + from synapse.server import HomeServer + class SAML2ResponseResource(DirectServeHtmlResource): """A Twisted web resource which handles the SAML response""" isLeaf = 1 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._saml_handler = hs.get_saml_handler() + self._sso_handler = hs.get_sso_handler() async def _async_render_GET(self, request): # We're not expecting any GET request on that resource if everything goes right, # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # In this case, just tell the user that something went wrong and they should # try to authenticate again. - self._saml_handler._render_error( + self._sso_handler.render_error( request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" ) -- cgit 1.5.1 From 1b0eaed21f38a5e84089df4ba2fb881fc1242a2e Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 16 Mar 2021 10:27:51 +0000 Subject: Prevent bundling aggregations for state events (#9619) There's no need to do aggregation bundling for state events. Doing so can cause performance issues. --- changelog.d/9619.misc | 1 + synapse/rest/admin/rooms.py | 5 ++++- synapse/rest/client/v1/room.py | 5 ++++- 3 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9619.misc diff --git a/changelog.d/9619.misc b/changelog.d/9619.misc new file mode 100644 index 0000000000..50267bfbc4 --- /dev/null +++ b/changelog.d/9619.misc @@ -0,0 +1 @@ +Prevent attempting to bundle aggregations for state events in /context APIs. \ No newline at end of file diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f2c42a0f30..263d8ec076 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -685,7 +685,10 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], time_now + results["state"], + time_now, + # No need to bundle aggregations for state events + bundle_aggregations=False, ) return 200, results diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9a1df30c29..5884daea6d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -671,7 +671,10 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], time_now + results["state"], + time_now, + # No need to bundle aggregations for state events + bundle_aggregations=False, ) return 200, results -- cgit 1.5.1 From 5b5bc188cfc3c31be83058ccbe5a17bfde1b0021 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 16 Mar 2021 10:57:54 +0000 Subject: Clean up config settings for stats (#9604) ... and complain if people try to turn it off. --- changelog.d/9604.doc | 1 + changelog.d/9604.misc | 1 + docs/sample_config.yaml | 25 +++++++++++++------------ synapse/config/stats.py | 45 ++++++++++++++++++++++++++++----------------- 4 files changed, 43 insertions(+), 29 deletions(-) create mode 100644 changelog.d/9604.doc create mode 100644 changelog.d/9604.misc diff --git a/changelog.d/9604.doc b/changelog.d/9604.doc new file mode 100644 index 0000000000..d413e38b72 --- /dev/null +++ b/changelog.d/9604.doc @@ -0,0 +1 @@ +Clarify the sample configuration for `stats` settings. diff --git a/changelog.d/9604.misc b/changelog.d/9604.misc new file mode 100644 index 0000000000..0583988588 --- /dev/null +++ b/changelog.d/9604.misc @@ -0,0 +1 @@ +Remove unused `stats.retention` setting, and emit a warning if stats are disabled. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c32ee4a897..41ab35595b 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2645,19 +2645,20 @@ user_directory: -# Local statistics collection. Used in populating the room directory. +# Settings for local room and user statistics collection. See +# docs/room_and_user_statistics.md. # -# 'bucket_size' controls how large each statistics timeslice is. It can -# be defined in a human readable short form -- e.g. "1d", "1y". -# -# 'retention' controls how long historical statistics will be kept for. -# It can be defined in a human readable short form -- e.g. "1d", "1y". -# -# -#stats: -# enabled: true -# bucket_size: 1d -# retention: 1y +stats: + # Uncomment the following to disable room and user statistics. Note that doing + # so may cause certain features (such as the room directory) not to work + # correctly. + # + #enabled: false + + # The size of each timeslice in the room_stats_historical and + # user_stats_historical tables, as a time period. Defaults to "1d". + # + #bucket_size: 1h # Server Notices room configuration diff --git a/synapse/config/stats.py b/synapse/config/stats.py index b559bfa411..2258329a52 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -13,10 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys +import logging from ._base import Config +ROOM_STATS_DISABLED_WARN = """\ +WARNING: room/user statistics have been disabled via the stats.enabled +configuration setting. This means that certain features (such as the room +directory) will not operate correctly. Future versions of Synapse may ignore +this setting. + +To fix this warning, remove the stats.enabled setting from your configuration +file. +--------------------------------------------------------------------------------""" + +logger = logging.getLogger(__name__) + class StatsConfig(Config): """Stats Configuration @@ -28,30 +40,29 @@ class StatsConfig(Config): def read_config(self, config, **kwargs): self.stats_enabled = True self.stats_bucket_size = 86400 * 1000 - self.stats_retention = sys.maxsize stats_config = config.get("stats", None) if stats_config: self.stats_enabled = stats_config.get("enabled", self.stats_enabled) self.stats_bucket_size = self.parse_duration( stats_config.get("bucket_size", "1d") ) - self.stats_retention = self.parse_duration( - stats_config.get("retention", "%ds" % (sys.maxsize,)) - ) + if not self.stats_enabled: + logger.warning(ROOM_STATS_DISABLED_WARN) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ - # Local statistics collection. Used in populating the room directory. + # Settings for local room and user statistics collection. See + # docs/room_and_user_statistics.md. # - # 'bucket_size' controls how large each statistics timeslice is. It can - # be defined in a human readable short form -- e.g. "1d", "1y". - # - # 'retention' controls how long historical statistics will be kept for. - # It can be defined in a human readable short form -- e.g. "1d", "1y". - # - # - #stats: - # enabled: true - # bucket_size: 1d - # retention: 1y + stats: + # Uncomment the following to disable room and user statistics. Note that doing + # so may cause certain features (such as the room directory) not to work + # correctly. + # + #enabled: false + + # The size of each timeslice in the room_stats_historical and + # user_stats_historical tables, as a time period. Defaults to "1d". + # + #bucket_size: 1h """ -- cgit 1.5.1 From dd69110d9588b5fc8cca10cba9509d80f88b84f4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 16 Mar 2021 11:21:26 +0000 Subject: Add support for stable MSC2858 API (#9617) The stable format uses different brand identifiers, so we need to support two identifiers for each IdP. --- changelog.d/9617.feature | 1 + docs/openid.md | 8 +++---- docs/sample_config.yaml | 2 +- synapse/config/oidc_config.py | 13 ++++++++++-- synapse/handlers/cas_handler.py | 1 + synapse/handlers/oidc_handler.py | 3 +++ synapse/handlers/saml_handler.py | 1 + synapse/handlers/sso.py | 5 +++++ synapse/rest/client/v1/login.py | 39 +++++++++++++++++++++++++++++----- tests/rest/client/v1/test_login.py | 43 ++++++++++++++++++++++++-------------- 10 files changed, 88 insertions(+), 28 deletions(-) create mode 100644 changelog.d/9617.feature diff --git a/changelog.d/9617.feature b/changelog.d/9617.feature new file mode 100644 index 0000000000..b462a32b92 --- /dev/null +++ b/changelog.d/9617.feature @@ -0,0 +1 @@ +Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). diff --git a/docs/openid.md b/docs/openid.md index 01205d1220..cfaafc5015 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -226,7 +226,7 @@ Synapse config: oidc_providers: - idp_id: github idp_name: Github - idp_brand: "org.matrix.github" # optional: styling hint for clients + idp_brand: "github" # optional: styling hint for clients discover: false issuer: "https://github.com/" client_id: "your-client-id" # TO BE FILLED @@ -252,7 +252,7 @@ oidc_providers: oidc_providers: - idp_id: google idp_name: Google - idp_brand: "org.matrix.google" # optional: styling hint for clients + idp_brand: "google" # optional: styling hint for clients issuer: "https://accounts.google.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED @@ -299,7 +299,7 @@ Synapse config: oidc_providers: - idp_id: gitlab idp_name: Gitlab - idp_brand: "org.matrix.gitlab" # optional: styling hint for clients + idp_brand: "gitlab" # optional: styling hint for clients issuer: "https://gitlab.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED @@ -334,7 +334,7 @@ Synapse config: ```yaml - idp_id: facebook idp_name: Facebook - idp_brand: "org.matrix.facebook" # optional: styling hint for clients + idp_brand: "facebook" # optional: styling hint for clients discover: false issuer: "https://facebook.com" client_id: "your-client-id" # TO BE FILLED diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 41ab35595b..7de000f4a4 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1919,7 +1919,7 @@ oidc_providers: # #- idp_id: github # idp_name: Github - # idp_brand: org.matrix.github + # idp_brand: github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 7f5e449eb2..2bfb537c15 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -237,7 +237,7 @@ class OIDCConfig(Config): # #- idp_id: github # idp_name: Github - # idp_brand: org.matrix.github + # idp_brand: github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -272,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "idp_icon": {"type": "string"}, "idp_brand": { "type": "string", - # MSC2758-style namespaced identifier + "minLength": 1, + "maxLength": 255, + "pattern": "^[a-z][a-z0-9_.-]*$", + }, + "idp_unstable_brand": { + "type": "string", "minLength": 1, "maxLength": 255, "pattern": "^[a-z][a-z0-9_.-]*$", @@ -466,6 +471,7 @@ def _parse_oidc_config_dict( idp_name=oidc_config.get("idp_name", "OIDC"), idp_icon=idp_icon, idp_brand=oidc_config.get("idp_brand"), + unstable_idp_brand=oidc_config.get("unstable_idp_brand"), discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -512,6 +518,9 @@ class OidcProviderConfig: # Optional brand identifier for this IdP. idp_brand = attr.ib(type=Optional[str]) + # Optional brand identifier for the unstable API (see MSC2858). + unstable_idp_brand = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 04972f9cf0..cb67589f7d 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -83,6 +83,7 @@ class CasHandler: # the SsoIdentityProvider protocol type. self.idp_icon = None self.idp_brand = None + self.unstable_idp_brand = None self._sso_handler = hs.get_sso_handler() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index f5d1821127..01c91f9d1c 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -330,6 +330,9 @@ class OidcProvider: # optional brand identifier for this auth provider self.idp_brand = provider.idp_brand + # Optional brand identifier for the unstable API (see MSC2858). + self.unstable_idp_brand = provider.unstable_idp_brand + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a9645b77d8..ec2ba11c75 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -81,6 +81,7 @@ class SamlHandler(BaseHandler): # the SsoIdentityProvider protocol type. self.idp_icon = None self.idp_brand = None + self.unstable_idp_brand = None # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 6ef459acff..415b1c2d17 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol): """Optional branding identifier""" return None + @property + def unstable_idp_brand(self) -> Optional[str]: + """Optional brand identifier for the unstable API (see MSC2858).""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 34bc1bd49b..e4c352f572 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.api.urls import CLIENT_API_PREFIX from synapse.appservice import ApplicationService from synapse.handlers.sso import SsoIdentityProvider from synapse.http import get_request_uri @@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + sso_flow = { + "type": LoginRestServlet.SSO_TYPE, + "identity_providers": [ + _get_auth_flow_dict_for_idp( + idp, + ) + for idp in self._sso_handler.get_identity_providers().values() + ], + } # type: JsonDict if self._msc2858_enabled: + # backwards-compatibility support for clients which don't + # support the stable API yet sso_flow["org.matrix.msc2858.identity_providers"] = [ - _get_auth_flow_dict_for_idp(idp) + _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) for idp in self._sso_handler.get_identity_providers().values() ] @@ -331,22 +343,38 @@ class LoginRestServlet(RestServlet): return result -def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: +def _get_auth_flow_dict_for_idp( + idp: SsoIdentityProvider, use_unstable_brands: bool = False +) -> JsonDict: """Return an entry for the login flow dict Returns an entry suitable for inclusion in "identity_providers" in the response to GET /_matrix/client/r0/login + + Args: + idp: the identity provider to describe + use_unstable_brands: whether we should use brand identifiers suitable + for the unstable API """ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict if idp.idp_icon: e["icon"] = idp.idp_icon if idp.idp_brand: e["brand"] = idp.idp_brand + # use the stable brand identifier if the unstable identifier isn't defined. + if use_unstable_brands and idp.unstable_idp_brand: + e["brand"] = idp.unstable_idp_brand return e class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) + PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ + re.compile( + "^" + + CLIENT_API_PREFIX + + "/r0/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$" + ) + ] def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -364,7 +392,8 @@ class SsoRedirectServlet(RestServlet): def register(self, http_server: HttpServer) -> None: super().register(http_server) if self._msc2858_enabled: - # expose additional endpoint for MSC2858 support + # expose additional endpoint for MSC2858 support: backwards-compat support + # for clients which don't yet support the stable endpoints. http_server.register_paths( "GET", client_patterns( diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 20af3285bd..988821b16f 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) - expected_flows = [ - {"type": "m.login.cas"}, - {"type": "m.login.sso"}, - {"type": "m.login.token"}, - {"type": "m.login.password"}, - ] + ADDITIONAL_LOGIN_FLOWS + expected_flow_types = [ + "m.login.cas", + "m.login.sso", + "m.login.token", + "m.login.password", + ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] - self.assertCountEqual(channel.json_body["flows"], expected_flows) + self.assertCountEqual( + [f["type"] for f in channel.json_body["flows"]], expected_flow_types + ) @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_get_msc2858_login_flows(self): @@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_msc2858_disabled(self): - """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_unknown(self): """If the client tries to pick an unknown IdP, return a 404""" - channel = self._make_sso_redirect_request(True, "xxx") + channel = self._make_sso_redirect_request(False, "xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_oidc(self): """If the client pick a known IdP, redirect to it""" + channel = self._make_sso_redirect_request(False, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_msc2858_redirect_to_oidc(self): + """Test the unstable API""" channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 302, channel.result) oidc_uri = channel.headers.getRawHeaders("Location")[0] @@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + def _make_sso_redirect_request( self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None ): -- cgit 1.5.1 From 1383508f2956345fd86de1779dd2f6e723c536c5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 16 Mar 2021 07:29:35 -0400 Subject: Handle an empty cookie as an invalid macaroon. (#9620) * Handle an empty cookie as an invalid macaroon. * Newsfragment --- changelog.d/9620.bugfix | 1 + synapse/handlers/oidc_handler.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9620.bugfix diff --git a/changelog.d/9620.bugfix b/changelog.d/9620.bugfix new file mode 100644 index 0000000000..427580f4ad --- /dev/null +++ b/changelog.d/9620.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 01c91f9d1c..6d8551a6d6 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -29,6 +29,7 @@ from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template from pymacaroons.exceptions import ( MacaroonDeserializationException, + MacaroonInitException, MacaroonInvalidSignatureException, ) from typing_extensions import TypedDict @@ -217,7 +218,7 @@ class OidcHandler: session_data = self._token_generator.verify_oidc_session_token( session, state ) - except (MacaroonDeserializationException, KeyError) as e: + except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e: logger.exception("Invalid session for OIDC callback") self._sso_handler.render_error(request, "invalid_session", str(e)) return -- cgit 1.5.1 From ccf1dc51d7a2f261f91ef8d1442a7ddeccc632b8 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Tue, 16 Mar 2021 12:32:18 +0100 Subject: Install jemalloc in docker image (#8553) Co-authored-by: Will Hunt Co-authored-by: Erik Johnston --- changelog.d/8553.docker | 1 + docker/Dockerfile | 1 + docker/README.md | 5 +++++ docker/start.py | 12 ++++++++++-- 4 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8553.docker diff --git a/changelog.d/8553.docker b/changelog.d/8553.docker new file mode 100644 index 0000000000..f99c4207b8 --- /dev/null +++ b/changelog.d/8553.docker @@ -0,0 +1 @@ +Use jemalloc if available in docker. diff --git a/docker/Dockerfile b/docker/Dockerfile index d619ee08ed..def4501541 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -69,6 +69,7 @@ RUN apt-get update && apt-get install -y \ libpq5 \ libwebp6 \ xmlsec1 \ + libjemalloc2 \ && rm -rf /var/lib/apt/lists/* COPY --from=builder /install /usr/local diff --git a/docker/README.md b/docker/README.md index 7b138df4d3..3a7dc585e7 100644 --- a/docker/README.md +++ b/docker/README.md @@ -204,3 +204,8 @@ healthcheck: timeout: 10s retries: 3 ``` + +## Using jemalloc + +Jemalloc is embedded in the image and will be used instead of the default allocator. +You can read about jemalloc by reading the Synapse [README](../README.md) \ No newline at end of file diff --git a/docker/start.py b/docker/start.py index 0d2c590b88..16d6a8208a 100755 --- a/docker/start.py +++ b/docker/start.py @@ -3,6 +3,7 @@ import codecs import glob import os +import platform import subprocess import sys @@ -213,6 +214,13 @@ def main(args, environ): if "-m" not in args: args = ["-m", synapse_worker] + args + jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),) + + if os.path.isfile(jemallocpath): + environ["LD_PRELOAD"] = jemallocpath + else: + log("Could not find %s, will not use" % (jemallocpath,)) + # if there are no config files passed to synapse, try adding the default file if not any(p.startswith("--config-path") or p.startswith("-c") for p in args): config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data") @@ -248,9 +256,9 @@ running with 'migrate_config'. See the README for more details. args = ["python"] + args if ownership is not None: args = ["gosu", ownership] + args - os.execv("/usr/sbin/gosu", args) + os.execve("/usr/sbin/gosu", args, environ) else: - os.execv("/usr/local/bin/python", args) + os.execve("/usr/local/bin/python", args, environ) if __name__ == "__main__": -- cgit 1.5.1 From 847ecdd8fafe4af3e073449202a6b2fcc47df622 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 16 Mar 2021 12:41:41 +0000 Subject: Pass SSO IdP information to spam checker's registration function (#9626) Fixes https://github.com/matrix-org/synapse/issues/9572 When a SSO user logs in for the first time, we create a local Matrix user for them. This goes through the register_user flow, which ends up triggering the spam checker. Spam checker modules don't currently have any way to differentiate between a user trying to sign up initially, versus an SSO user (whom has presumably already been approved elsewhere) trying to log in for the first time. This PR passes `auth_provider_id` as an argument to the `check_registration_for_spam` function. This argument will contain an ID of an SSO provider (`"saml"`, `"cas"`, etc.) if one was used, else `None`. --- changelog.d/9626.feature | 1 + docs/spam_checker.md | 8 +++++++- synapse/events/spamcheck.py | 29 ++++++++++++++++++++++++++--- synapse/handlers/register.py | 4 ++-- tests/handlers/test_register.py | 31 +++++++++++++++++++++++++++++++ 5 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 changelog.d/9626.feature diff --git a/changelog.d/9626.feature b/changelog.d/9626.feature new file mode 100644 index 0000000000..eacba6201b --- /dev/null +++ b/changelog.d/9626.feature @@ -0,0 +1 @@ +Tell spam checker modules about the SSO IdP a user registered through if one was used. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 2020eb9006..52947f605e 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -69,7 +69,13 @@ class ExampleSpamChecker: async def check_username_for_spam(self, user_profile): return False # allow all usernames - async def check_registration_for_spam(self, email_threepid, username, request_info): + async def check_registration_for_spam( + self, + email_threepid, + username, + request_info, + auth_provider_id, + ): return RegistrationBehaviour.ALLOW # allow all registrations async def check_media_file_for_spam(self, file_wrapper, file_info): diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 8cfc0bb3cb..a9185987a2 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,6 +15,7 @@ # limitations under the License. import inspect +import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from synapse.rest.media.v1._base import FileInfo @@ -27,6 +28,8 @@ if TYPE_CHECKING: import synapse.events import synapse.server +logger = logging.getLogger(__name__) + class SpamChecker: def __init__(self, hs: "synapse.server.HomeServer"): @@ -190,6 +193,7 @@ class SpamChecker: email_threepid: Optional[dict], username: Optional[str], request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str] = None, ) -> RegistrationBehaviour: """Checks if we should allow the given registration request. @@ -198,6 +202,9 @@ class SpamChecker: username: The request user name, if any request_info: List of tuples of user agent and IP that were used during the registration process. + auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml", + "cas". If any. Note this does not include users registered + via a password provider. Returns: Enum for how the request should be handled @@ -208,9 +215,25 @@ class SpamChecker: # spam checker checker = getattr(spam_checker, "check_registration_for_spam", None) if checker: - behaviour = await maybe_awaitable( - checker(email_threepid, username, request_info) - ) + # Provide auth_provider_id if the function supports it + checker_args = inspect.signature(checker) + if len(checker_args.parameters) == 4: + d = checker( + email_threepid, + username, + request_info, + auth_provider_id, + ) + elif len(checker_args.parameters) == 3: + d = checker(email_threepid, username, request_info) + else: + logger.error( + "Invalid signature for %s.check_registration_for_spam. Denying registration", + spam_checker.__module__, + ) + return RegistrationBehaviour.DENY + + behaviour = await maybe_awaitable(d) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cd001e87c7..1abc8875cb 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -198,8 +198,7 @@ class RegistrationHandler(BaseHandler): admin api, otherwise False. user_agent_ips: Tuples of IP addresses and user-agents used during the registration process. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. Returns: The registered user_id. Raises: @@ -211,6 +210,7 @@ class RegistrationHandler(BaseHandler): threepid, localpart, user_agent_ips or [], + auth_provider_id=auth_provider_id, ) if result == RegistrationBehaviour.DENY: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index bdf3d0a8a2..94b6903594 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(requester.shadow_banned) + def test_spam_checker_receives_sso_type(self): + """Test rejecting registration based on SSO type""" + + class BanBadIdPUser: + def check_registration_for_spam( + self, email_threepid, username, request_info, auth_provider_id=None + ): + # Reject any user coming from CAS and whose username contains profanity + if auth_provider_id == "cas" and "flimflob" in username: + return RegistrationBehaviour.DENY + return RegistrationBehaviour.ALLOW + + # Configure a spam checker that denies a certain user on a specific IdP + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanBadIdPUser()] + + f = self.get_failure( + self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), + SynapseError, + ) + exception = f.value + + # We return 429 from the spam checker for denied registrations + self.assertIsInstance(exception, SynapseError) + self.assertEqual(exception.code, 429) + + # Check the same username can register using SAML + self.get_success( + self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml") + ) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): -- cgit 1.5.1