From 1bf2832714abdfc5e10395e8e76aecc591ad265f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 7 Oct 2022 11:39:45 -0500 Subject: Indicate what endpoint came back with a JSON response we were unable to parse (#14097) **Before:** ``` WARNING - POST-11 - Unable to parse JSON: Expecting value: line 1 column 1 (char 0) (b'') ``` **After:** ``` WARNING - POST-11 - Unable to parse JSON from POST /_matrix/client/v3/join/%21ZlmJtelqFroDRJYZaq:hs1?server_name=hs1 response: Expecting value: line 1 column 1 (char 0) (b'') ``` --- It's possible to figure out which endpoint these warnings were coming from before but you had to follow the request ID `POST-11` to the log line that says `Completed request [...]`. Including this key information next to the JSON parsing error makes it much easier to reason whether it matters or not. ``` 2022-09-29T08:23:25.7875506Z synapse_main | 2022-09-29 08:21:10,336 - synapse.http.matrixfederationclient - 299 - INFO - POST-11 - {GET-O-13} [hs1] Completed request: 200 OK in 0.53 secs, got 450 bytes - GET matrix://hs1/_matrix/federation/v1/make_join/%21ohtKoQiXlPePSycXwp%3Ahs1/%40charlie%3Ahs2?ver=1&ver=2&ver=3&ver=4&ver=5&ver=6&ver=org.matrix.msc2176&ver=7&ver=8&ver=9&ver=org.matrix.msc3787&ver=10&ver=org.matrix.msc2716v4 ``` --- As a note, having no `body` is normal for the `/join` endpoint and it can handle it. https://github.com/matrix-org/synapse/blob/0c853e09709d52783efd37060ed9e8f55a4fc704/synapse/rest/client/room.py#L398-L403 Alternatively we could remove these extra logs but they are probably more usually helpful to figure out what went wrong. --- tests/http/test_servlet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index 3cbca0f5a3..46166292fe 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -35,11 +35,13 @@ from tests.http.server._base import test_disconnect def make_request(content): """Make an object that acts enough like a request.""" - request = Mock(spec=["content"]) + request = Mock(spec=["method", "uri", "content"]) if isinstance(content, dict): content = json.dumps(content).encode("utf8") + request.method = bytes("STUB_METHOD", "ascii") + request.uri = bytes("/test_stub_uri", "ascii") request.content = BytesIO(content) return request -- cgit 1.5.1 From a9934d48c193bc963e3d232ed83c5cbfa3e5152d Mon Sep 17 00:00:00 2001 From: Abdullah Osama Date: Tue, 11 Oct 2022 14:42:11 +0200 Subject: Making parse_server_name more consistent (#14007) Fixes #12122 --- changelog.d/14007.misc | 1 + synapse/util/stringutils.py | 4 ++-- tests/http/test_endpoint.py | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14007.misc (limited to 'tests') diff --git a/changelog.d/14007.misc b/changelog.d/14007.misc new file mode 100644 index 0000000000..3f0f3afe1c --- /dev/null +++ b/changelog.d/14007.misc @@ -0,0 +1 @@ +Make `parse_server_name` consistent in handling invalid server names. \ No newline at end of file diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 27a363d7e5..4961fe9313 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -86,7 +86,7 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: ValueError if the server name could not be parsed. """ try: - if server_name[-1] == "]": + if server_name and server_name[-1] == "]": # ipv6 literal, hopefully return server_name, None @@ -123,7 +123,7 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int] # that nobody is sneaking IP literals in that look like hostnames, etc. # look for ipv6 literals - if host[0] == "[": + if host and host[0] == "[": if host[-1] != "]": raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index c8cc21cadd..a801f002a0 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -25,6 +25,8 @@ class ServerNameTestCase(unittest.TestCase): "[0abc:1def::1234]": ("[0abc:1def::1234]", None), "1.2.3.4:1": ("1.2.3.4", 1), "[0abc:1def::1234]:8080": ("[0abc:1def::1234]", 8080), + ":80": ("", 80), + "": ("", None), } for i, o in test_data.items(): @@ -42,6 +44,7 @@ class ServerNameTestCase(unittest.TestCase): "newline.com\n", ".empty-label.com", "1234:5678:80", # too many colons + ":80", ] for i in test_data: try: -- cgit 1.5.1 From a86b2f6837f0a067b0a014fbf5140e8773b8da2e Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 11 Oct 2022 11:18:45 -0700 Subject: Fix a bug where redactions were not being sent over federation if we did not have the original event. (#13813) --- changelog.d/13813.bugfix | 1 + synapse/federation/sender/__init__.py | 29 +++++++++++++++++-------- synapse/handlers/appservice.py | 9 +++++--- synapse/storage/databases/main/events_worker.py | 15 +++++++++---- synapse/storage/databases/main/stream.py | 28 +++++++++++------------- tests/handlers/test_appservice.py | 18 +++++++++------ 6 files changed, 62 insertions(+), 38 deletions(-) create mode 100644 changelog.d/13813.bugfix (limited to 'tests') diff --git a/changelog.d/13813.bugfix b/changelog.d/13813.bugfix new file mode 100644 index 0000000000..23388788ff --- /dev/null +++ b/changelog.d/13813.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where redactions were not being sent over federation if we did not have the original event. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index a6cb3ba58f..774ecd81b6 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -353,21 +353,25 @@ class FederationSender(AbstractFederationSender): last_token = await self.store.get_federation_out_pos("events") ( next_token, - events, event_to_received_ts, - ) = await self.store.get_all_new_events_stream( + ) = await self.store.get_all_new_event_ids_stream( last_token, self._last_poked_id, limit=100 ) + event_ids = event_to_received_ts.keys() + event_entries = await self.store.get_unredacted_events_from_cache_or_db( + event_ids + ) + logger.debug( "Handling %i -> %i: %i events to send (current id %i)", last_token, next_token, - len(events), + len(event_entries), self._last_poked_id, ) - if not events and next_token >= self._last_poked_id: + if not event_entries and next_token >= self._last_poked_id: logger.debug("All events processed") break @@ -508,8 +512,14 @@ class FederationSender(AbstractFederationSender): await handle_event(event) events_by_room: Dict[str, List[EventBase]] = {} - for event in events: - events_by_room.setdefault(event.room_id, []).append(event) + + for event_id in event_ids: + # `event_entries` is unsorted, so we have to iterate over `event_ids` + # to ensure the events are in the right order + event_cache = event_entries.get(event_id) + if event_cache: + event = event_cache.event + events_by_room.setdefault(event.room_id, []).append(event) await make_deferred_yieldable( defer.gatherResults( @@ -524,9 +534,10 @@ class FederationSender(AbstractFederationSender): logger.debug("Successfully handled up to %i", next_token) await self.store.update_federation_out_pos("events", next_token) - if events: + if event_entries: now = self.clock.time_msec() - ts = event_to_received_ts[events[-1].event_id] + last_id = next(reversed(event_ids)) + ts = event_to_received_ts[last_id] assert ts is not None synapse.metrics.event_processing_lag.labels( @@ -536,7 +547,7 @@ class FederationSender(AbstractFederationSender): "federation_sender" ).set(ts) - events_processed_counter.inc(len(events)) + events_processed_counter.inc(len(event_entries)) event_processing_loop_room_count.labels("federation_sender").inc( len(events_by_room) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 203b62e015..66f5b8d108 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -109,10 +109,13 @@ class ApplicationServicesHandler: last_token = await self.store.get_appservice_last_pos() ( upper_bound, - events, event_to_received_ts, - ) = await self.store.get_all_new_events_stream( - last_token, self.current_max, limit=100, get_prev_content=True + ) = await self.store.get_all_new_event_ids_stream( + last_token, self.current_max, limit=100 + ) + + events = await self.store.get_events_as_list( + event_to_received_ts.keys(), get_prev_content=True ) events_by_room: Dict[str, List[EventBase]] = {} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7cdc9fe98f..d4104462b5 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = await self._get_events_from_cache_or_db( + event_entry_map = await self.get_unredacted_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = await self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self.get_unredacted_events_from_cache_or_db( + [redacted_event_id] + ) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore): return events @cancellable - async def _get_events_from_cache_or_db( - self, event_ids: Iterable[str], allow_rejected: bool = False + async def get_unredacted_events_from_cache_or_db( + self, + event_ids: Iterable[str], + allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. + Note that the events pulled by this function will not have any redactions + applied, and no guarantee is made about the ordering of the events returned. + If events are pulled from the database, they will be cached for future lookups. Unknown events are omitted from the response. diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 530f04e149..ffeb2b3683 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1024,28 +1024,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - async def get_all_new_events_stream( - self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False - ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: + async def get_all_new_event_ids_stream( + self, + from_id: int, + current_id: int, + limit: int, + ) -> Tuple[int, Dict[str, Optional[int]]]: """Get all new events - Returns all events with from_id < stream_ordering <= current_id. + Returns all event ids with from_id < stream_ordering <= current_id. Args: from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return - get_prev_content: whether to fetch previous event content Returns: - A tuple of (next_id, events, event_to_received_ts), where `next_id` + A tuple of (next_id, event_to_received_ts), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, the `current_id`). The `event_to_received_ts` is - a dictionary mapping event ID to the event `received_ts`. + a dictionary mapping event ID to the event `received_ts`, sorted by ascending + stream_ordering. """ - def get_all_new_events_stream_txn( + def get_all_new_event_ids_stream_txn( txn: LoggingTransaction, ) -> Tuple[int, Dict[str, Optional[int]]]: sql = ( @@ -1070,15 +1073,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, event_to_received_ts upper_bound, event_to_received_ts = await self.db_pool.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn - ) - - events = await self.get_events_as_list( - event_to_received_ts.keys(), - get_prev_content=get_prev_content, + "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn ) - return upper_bound, events, event_to_received_ts + return upper_bound, event_to_received_ts async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index af24c4984d..7e4570f990 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_all_new_events_stream.side_effect = [ - make_awaitable((0, [], {})), - make_awaitable((1, [event], {event.event_id: 0})), + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, {})), + make_awaitable((1, {event.event_id: 0})), + ] + self.mock_store.get_events_as_list.side_effect = [ + make_awaitable([]), + make_awaitable([event]), ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) @@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_events_stream.side_effect = [ - make_awaitable((0, [event], {event.event_id: 0})), + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, {event.event_id: 0})), ] - + self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])] self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_events_stream.side_effect = [ + self.mock_store.get_all_new_event_ids_stream.side_effect = [ make_awaitable((0, [event], {event.event_id: 0})), ] -- cgit 1.5.1 From 09be8ab5f9d54fa1a577d8b0028abf8acc28f30d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 06:26:39 -0400 Subject: Remove the experimental implementation of MSC3772. (#14094) MSC3772 has been abandoned. --- changelog.d/14094.removal | 1 + rust/src/push/base_rules.rs | 13 ---- rust/src/push/evaluator.rs | 105 +--------------------------- rust/src/push/mod.rs | 44 +++--------- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 2 - synapse/push/bulk_push_rule_evaluator.py | 64 +---------------- synapse/storage/databases/main/cache.py | 3 - synapse/storage/databases/main/events.py | 5 -- synapse/storage/databases/main/push_rule.py | 15 ++-- synapse/storage/databases/main/relations.py | 53 -------------- tests/push/test_push_rule_evaluator.py | 76 +------------------- 12 files changed, 22 insertions(+), 365 deletions(-) create mode 100644 changelog.d/14094.removal (limited to 'tests') diff --git a/changelog.d/14094.removal b/changelog.d/14094.removal new file mode 100644 index 0000000000..6ef03b1a0f --- /dev/null +++ b/changelog.d/14094.removal @@ -0,0 +1 @@ +Remove the experimental implementation of [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772). diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 2a09cf99ae..63240cacfc 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -257,19 +257,6 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, - PushRule { - rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3772.thread_reply"), - priority_class: 1, - conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { - rel_type: Cow::Borrowed("m.thread"), - event_type_pattern: None, - sender: None, - sender_type: Some(Cow::Borrowed("user_id")), - })]), - actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), - default: true, - default_enabled: true, - }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.message"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index efe88ec76e..0365dd01dc 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - borrow::Cow, - collections::{BTreeMap, BTreeSet}, -}; +use std::collections::BTreeMap; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -49,13 +46,6 @@ pub struct PushRuleEvaluator { /// The `notifications` section of the current power levels in the room. notification_power_levels: BTreeMap, - /// The relations related to the event as a mapping from relation type to - /// set of sender/event type 2-tuples. - relations: BTreeMap>, - - /// Is running "relation" conditions enabled? - relation_match_enabled: bool, - /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, @@ -70,8 +60,6 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - relations: BTreeMap>, - relation_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -83,8 +71,6 @@ impl PushRuleEvaluator { body, room_member_count, notification_power_levels, - relations, - relation_match_enabled, sender_power_level, }) } @@ -203,89 +189,11 @@ impl PushRuleEvaluator { false } } - KnownCondition::RelationMatch { - rel_type, - event_type_pattern, - sender, - sender_type, - } => { - self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)? - } }; Ok(result) } - /// Evaluates a relation condition. - fn match_relations( - &self, - rel_type: &str, - sender: &Option>, - sender_type: &Option>, - user_id: Option<&str>, - event_type_pattern: &Option>, - ) -> Result { - // First check if relation matching is enabled... - if !self.relation_match_enabled { - return Ok(false); - } - - // ... and if there are any relations to match against. - let relations = if let Some(relations) = self.relations.get(rel_type) { - relations - } else { - return Ok(false); - }; - - // Extract the sender pattern from the condition - let sender_pattern = if let Some(sender) = sender { - Some(sender.as_ref()) - } else if let Some(sender_type) = sender_type { - if sender_type == "user_id" { - if let Some(user_id) = user_id { - Some(user_id) - } else { - return Ok(false); - } - } else { - warn!("Unrecognized sender_type: {sender_type}"); - return Ok(false); - } - } else { - None - }; - - let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - for (relation_sender, event_type) in relations { - if let Some(pattern) = &mut sender_compiled_pattern { - if !pattern.is_match(relation_sender)? { - continue; - } - } - - if let Some(pattern) = &mut type_compiled_pattern { - if !pattern.is_match(event_type)? { - continue; - } - } - - return Ok(true); - } - - Ok(false) - } - /// Evaluates a `event_match` condition. fn match_event_match( &self, @@ -359,15 +267,8 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = PushRuleEvaluator::py_new( - flattened_keys, - 10, - Some(0), - BTreeMap::new(), - BTreeMap::new(), - true, - ) - .unwrap(); + let evaluator = + PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 208b9c0d73..0dabfab8b8 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -275,16 +275,6 @@ pub enum KnownCondition { SenderNotificationPermission { key: Cow<'static, str>, }, - #[serde(rename = "org.matrix.msc3772.relation_match")] - RelationMatch { - rel_type: Cow<'static, str>, - #[serde(skip_serializing_if = "Option::is_none", rename = "type")] - event_type_pattern: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender_type: Option>, - }, } impl IntoPy for Condition { @@ -401,21 +391,15 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, - msc3772_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new( - push_rules: PushRules, - enabled_map: BTreeMap, - msc3772_enabled: bool, - ) -> Self { + pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { Self { push_rules, enabled_map, - msc3772_enabled, } } @@ -430,25 +414,13 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules - .iter() - .filter(|rule| { - // Ignore disabled experimental push rules - if !self.msc3772_enabled - && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - { - return false; - } - - true - }) - .map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules.iter().map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 5900e61450..f2a61df660 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,9 +25,7 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__( - self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3772_enabled: bool - ): ... + def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -39,8 +37,6 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - relations: Mapping[str, Set[Tuple[str, str]]], - relation_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e00cb7096c..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3772: A push rule for mutual relations. - self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index eced182fd5..8d94aeaa32 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, - Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -38,7 +35,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator +from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -117,9 +114,6 @@ class BulkPushRuleEvaluator: resizable=False, ) - # Whether to support MSC3772 is supported. - self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled - async def _get_rules_for_event( self, event: EventBase, @@ -200,51 +194,6 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _get_mutual_relations( - self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - parent_id: The event ID which is targeted by relations. - rules: The push rules which will be processed for this event. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - - # If the experimental feature is not enabled, skip fetching relations. - if not self._relations_match_enabled: - return {} - - # Pre-filter to figure out which relation types are interesting. - rel_types = set() - for rule, enabled in rules: - if not enabled: - continue - - for condition in rule.conditions: - if condition["kind"] != "org.matrix.msc3772.relation_match": - continue - - # rel_type is required. - rel_type = condition.get("rel_type") - if rel_type: - rel_types.add(rel_type) - - # If no valid rules were found, no mutual relations. - if not rel_types: - return {} - - # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations(parent_id, rel_types) - @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -276,16 +225,11 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + # Find the event's thread ID. relation = relation_from_event(event) - # If the event does not have a relation, then cannot have any mutual - # relations or thread ID. - relations = {} + # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE if relation: - relations = await self._get_mutual_relations( - relation.parent_id, - itertools.chain(*(r.rules() for r in rules_by_user.values())), - ) # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id @@ -306,8 +250,6 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, - relations, - self._relations_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 3b8ed1f7ee..a9f25a5904 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,9 +259,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3e15827986..060fe71454 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2024,11 +2024,6 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) - self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), - ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8295322b0e..51416b2236 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,7 +29,6 @@ from typing import ( ) from synapse.api.errors import StoreError -from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -63,9 +62,7 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], - enabled_map: Dict[str, bool], - experimental_config: ExperimentalConfig, + rawrules: List[JsonDict], enabled_map: Dict[str, bool] ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -83,9 +80,7 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled - ) + filtered_rules = FilteredPushRules(push_rules, enabled_map) return filtered_rules @@ -165,7 +160,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules(rows, enabled_map) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -224,9 +219,7 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 116abef9de..6b7eec4bf2 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -776,59 +776,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) - - sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} - """ - - def _get_event_relations( - txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result - - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations - ) - @cached() async def get_thread_id(self, event_id: str) -> Optional[str]: """ diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 8804f0e0d3..decf619466 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Set, Tuple, Union +from typing import Dict, Optional, Union import frozendict @@ -38,12 +38,7 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator( - self, - content: JsonDict, - relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, - relations_match_enabled: bool = False, - ) -> PushRuleEvaluator: + def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -63,8 +58,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), - relations or {}, - relations_match_enabled, ) def test_display_name(self) -> None: @@ -299,71 +292,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) - def test_relation_match(self) -> None: - """Test the relation_match push rule kind.""" - - # Check if the experimental feature is disabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}} - ) - - # A push rule evaluator with the experimental rule enabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}}, True - ) - - # Check just relation type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and sender. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@user:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@other:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and event type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "type": "m.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check just sender, this fails since rel_type is required. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "sender": "@user:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check sender glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@*:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check event type glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "event_type": "*.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1 From 3bbe532abb7bfc41467597731ac1a18c0331f539 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 08:02:11 -0400 Subject: Add an API for listing threads in a room. (#13394) Implement the /threads endpoint from MSC3856. This is currently unstable and behind an experimental configuration flag. It includes a background update to backfill data, results from the /threads endpoint will be partial until that finishes. --- changelog.d/13394.feature | 1 + synapse/_scripts/synapse_port_db.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/relations.py | 86 ++++++++++- synapse/rest/client/relations.py | 50 ++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events.py | 38 ++++- synapse/storage/databases/main/relations.py | 166 ++++++++++++++++++++- .../schema/main/delta/73/09threads_table.sql | 30 ++++ tests/rest/client/test_relations.py | 151 +++++++++++++++++++ 10 files changed, 522 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13394.feature create mode 100644 synapse/storage/schema/main/delta/73/09threads_table.sql (limited to 'tests') diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature new file mode 100644 index 0000000000..68de079cf3 --- /dev/null +++ b/changelog.d/13394.feature @@ -0,0 +1 @@ +Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 5fa599e70e..d850e54e17 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore @@ -206,6 +207,7 @@ class Store( PusherWorkerStore, PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, + RelationsWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..1860006536 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -101,6 +101,9 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) + # MSC3856: Threads list API + self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) + # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index cc5e45c241..1fdd7a10bc 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple @@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -32,6 +33,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ThreadsListInclude(str, enum.Enum): + """Valid values for the 'include' flag of /threads.""" + + all = "all" + participated = "participated" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: # The latest event in the thread. @@ -482,3 +490,79 @@ class RelationsHandler: results.setdefault(event_id, BundledAggregations()).replace = edit return results + + async def get_threads( + self, + requester: Requester, + room_id: str, + include: ThreadsListInclude, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + Args: + requester: The user requesting the relations. + room_id: The room the event belongs to. + include: One of "all" or "participated" to indicate which threads should + be returned. + limit: Only fetch the most recent `limit` events. + from_token: Fetch rows from the given token, or from the start if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) + + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). + thread_roots, next_batch = await self._main_store.get_threads( + room_id=room_id, limit=limit, from_token=from_token + ) + + events = await self._main_store.get_events_as_list(thread_roots) + + if include == ThreadsListInclude.participated: + # Pre-seed thread participation with whether the requester sent the event. + participated = {event.event_id: event.sender == user_id for event in events} + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [eid for eid, p in participated.items() if not p], + user_id, + ) + ) + + # Limit the returned threads to those the user has participated in. + events = [event for event in events if participated[event.event_id]] + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + + now = self._clock.time_msec() + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value: JsonDict = {"chunk": serialized_events} + + if next_batch: + return_value["next_batch"] = str(next_batch) + + return return_value diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b31ce5a0d3..d1aa1947a5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -13,12 +13,15 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.storage.databases.main.relations import ThreadsNextBatch from synapse.streams.config import PaginationConfig from synapse.types import JsonDict @@ -78,5 +81,50 @@ class RelationPaginationServlet(RestServlet): return 200, result +class ThreadsServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P[^/]*)/threads" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + include = parse_string( + request, + "include", + default=ThreadsListInclude.all.value, + allowed_values=[v.value for v in ThreadsListInclude], + ) + + # Return the relations + from_token = None + if from_token_str: + from_token = ThreadsNextBatch.from_string(from_token_str) + + result = await self._relations_handler.get_threads( + requester=requester, + room_id=room_id, + include=ThreadsListInclude(include), + limit=limit, + from_token=from_token, + ) + + return 200, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + if hs.config.experimental.msc3856_enabled: + ThreadsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a9f25a5904..0ce3156c9c 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 060fe71454..6698cbf664 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -1616,7 +1616,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1866,6 +1866,34 @@ class PersistEventsStore: }, ) + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ + + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) + def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -1989,13 +2017,14 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ @@ -2024,6 +2053,9 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_threads, (room_id,) + ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index e7fbf950e6..ac9b96ab44 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,6 +14,7 @@ import logging from typing import ( + TYPE_CHECKING, Collection, Dict, FrozenSet, @@ -29,17 +30,46 @@ from typing import ( import attr from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadsNextBatch: + topological_ordering: int + stream_ordering: int + + def __str__(self) -> str: + return f"{self.topological_ordering}_{self.stream_ordering}" + + @classmethod + def from_string(cls, string: str) -> "ThreadsNextBatch": + """ + Creates a ThreadsNextBatch from its textual representation. + """ + try: + keys = (int(s) for s in string.split("_")) + return cls(*keys) + except Exception: + raise SynapseError(400, "Invalid threads token") + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _RelatedEvent: """ @@ -56,6 +86,76 @@ class _RelatedEvent: class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "threads_backfill", self._backfill_threads + ) + + async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int: + """Backfill the threads table.""" + + def threads_backfill_txn(txn: LoggingTransaction) -> int: + last_thread_id = progress.get("last_thread_id", "") + + # Get the latest event in each thread by topo ordering / stream ordering. + # + # Note that the MAX(event_id) is needed to abide by the rules of group by, + # but doesn't actually do anything since there should only be a single event + # ID per topo/stream ordering pair. + sql = f""" + SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id > ? AND + relation_type = '{RelationTypes.THREAD}' + GROUP BY room_id, relates_to_id + ORDER BY relates_to_id + LIMIT ? + """ + txn.execute(sql, (last_thread_id, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + return 0 + + # Insert the rows into the threads table. If a matching thread already exists, + # assume it is from a newer event. + sql = """ + INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id) + VALUES %s + ON CONFLICT (room_id, thread_id) + DO NOTHING + """ + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows) + + # Mark the progress. + self.db_pool.updates._background_update_progress_txn( + txn, "threads_backfill", {"last_thread_id": rows[-1][1]} + ) + + return txn.rowcount + + result = await self.db_pool.runInteraction( + "threads_backfill", threads_backfill_txn + ) + + if not result: + await self.db_pool.updates._end_background_update("threads_backfill") + + return result + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, @@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. + + Args: + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from a previous next_batch, or from the start if None. + + Returns: + A tuple of: + A list of thread root event IDs. + + The next_batch, if one exists. + """ + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is equal / before the last + # thread's topo ordering and earlier in stream ordering. + pagination_clause = "" + pagination_args: tuple = () + if from_token: + pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?" + pagination_args = ( + from_token.topological_ordering, + from_token.stream_ordering, + ) + + sql = f""" + SELECT thread_id, topological_ordering, stream_ordering + FROM threads + WHERE + room_id = ? + {pagination_clause} + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT ? + """ + + def _get_threads_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + txn.execute(sql, (room_id, *pagination_args, limit + 1)) + + rows = cast(List[Tuple[str, int, int]], txn.fetchall()) + thread_ids = [r[0] for r in rows] + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) + @cached() async def get_thread_id(self, event_id: str) -> str: """ diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql new file mode 100644 index 0000000000..aa7c5e9a2e --- /dev/null +++ b/synapse/storage/schema/main/delta/73/09threads_table.sql @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE threads ( + room_id TEXT NOT NULL, + -- The event ID of the root event in the thread. + thread_id TEXT NOT NULL, + -- The latest event ID and corresponding topo / stream ordering. + latest_event_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id) +); + +CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7309, 'threads_backfill', '{}'); diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 988cdb746d..d595295e2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1707,3 +1707,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations[RelationTypes.THREAD]["latest_event"]["event_id"], related_event_id, ) + + +class ThreadsTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_threads(self) -> None: + """Create threads and ensure the ordering is due to their latest event.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1]) + + # Update the first thread, the ordering should swap. + self._send_relation(RelationTypes.THREAD, "m.room.test") + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1, thread_2]) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_pagination(self) -> None: + """Create threads and paginate through them.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2]) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + next_batch = channel.json_body.get("next_batch") + self.assertIsInstance(next_batch, str, channel.json_body) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) + + self.assertNotIn("next_batch", channel.json_body, channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_include(self) -> None: + """Filtering threads to all or participated in should work.""" + # Thread 1 has the user as the root event. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 has the user replying. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Thread 3 has the user not participating in. + res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token) + thread_3 = res["event_id"] + self._send_relation( + RelationTypes.THREAD, + "m.room.test", + access_token=self.user2_token, + parent_id=thread_3, + ) + + # All threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual( + thread_roots, [thread_3, thread_2, thread_1], channel.json_body + ) + + # Only participated threads. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_ignored_user(self) -> None: + """Events from ignored users should be ignored.""" + # Thread 1 has a reply from an ignored user. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 is created by an ignored user. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Ignore user2. + self.get_success( + self.store.add_account_data_for_user( + self.user_id, + AccountDataTypes.IGNORED_USER_LIST, + {"ignored_users": {self.user2_id: {}}}, + ) + ) + + # Only thread 1 is returned. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) -- cgit 1.5.1 From c3e4edb4d6ba33383bc056e3ff22b2d034d3e248 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 07:16:50 -0400 Subject: Stabilize the threads API. (#14175) Stabilize the threads API (MSC3856) by supporting (only) the v1 path for the endpoint. This also marks the API as safe for workers since it is a read-only API. --- changelog.d/13394.feature | 2 +- changelog.d/14175.feature | 1 + docker/configure_workers_and_start.py | 1 + docs/workers.md | 1 + synapse/config/experimental.py | 3 --- synapse/rest/client/relations.py | 9 ++----- tests/rest/client/test_relations.py | 47 +++++++++++++++++++++-------------- 7 files changed, 35 insertions(+), 29 deletions(-) create mode 100644 changelog.d/14175.feature (limited to 'tests') diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature index 68de079cf3..df3ce45a76 100644 --- a/changelog.d/13394.feature +++ b/changelog.d/13394.feature @@ -1 +1 @@ -Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. +Support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/changelog.d/14175.feature b/changelog.d/14175.feature new file mode 100644 index 0000000000..df3ce45a76 --- /dev/null +++ b/changelog.d/14175.feature @@ -0,0 +1 @@ +Support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 8e7f605b24..d708237f69 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -118,6 +118,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$", "^/_matrix/client/v1/rooms/.*/hierarchy$", "^/_matrix/client/(v1|unstable)/rooms/.*/relations/", + "^/_matrix/client/v1/rooms/.*/threads$", "^/_matrix/client/(api/v1|r0|v3|unstable)/login$", "^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$", "^/_matrix/client/(api/v1|r0|v3|unstable)/account/whoami$", diff --git a/docs/workers.md b/docs/workers.md index e8d6cbaf8b..c27b3f8bd5 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -204,6 +204,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/v1/rooms/.*/hierarchy$ ^/_matrix/client/(v1|unstable)/rooms/.*/relations/ + ^/_matrix/client/v1/rooms/.*/threads$ ^/_matrix/client/unstable/org.matrix.msc2716/rooms/.*/batch_send$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1860006536..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -101,9 +101,6 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) - # MSC3856: Threads list API - self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) - # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d1aa1947a5..9dd59196d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -82,11 +82,7 @@ class RelationPaginationServlet(RestServlet): class ThreadsServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P[^/]*)/threads" - ), - ) + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P[^/]*)/threads"),) def __init__(self, hs: "HomeServer"): super().__init__() @@ -126,5 +122,4 @@ class ThreadsServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) - if hs.config.experimental.msc3856_enabled: - ThreadsServlet(hs).register(http_server) + ThreadsServlet(hs).register(http_server) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index d595295e2c..f5c1070b2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1710,7 +1710,15 @@ class RelationRedactionTestCase(BaseRelationsTestCase): class ThreadsTestCase(BaseRelationsTestCase): - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def _get_threads(self, body: JsonDict) -> List[Tuple[str, str]]: + return [ + ( + ev["event_id"], + ev["unsigned"]["m.relations"]["m.thread"]["latest_event"]["event_id"], + ) + for ev in body["chunk"] + ] + def test_threads(self) -> None: """Create threads and ensure the ordering is due to their latest event.""" # Create 2 threads. @@ -1718,32 +1726,37 @@ class ThreadsTestCase(BaseRelationsTestCase): res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) thread_2 = res["event_id"] - self._send_relation(RelationTypes.THREAD, "m.room.test") - self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_1 = channel.json_body["event_id"] + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", parent_id=thread_2 + ) + reply_2 = channel.json_body["event_id"] # Request the threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] - self.assertEqual(thread_roots, [thread_2, thread_1]) + threads = self._get_threads(channel.json_body) + self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) # Update the first thread, the ordering should swap. - self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_3 = channel.json_body["event_id"] channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] - self.assertEqual(thread_roots, [thread_1, thread_2]) + # Tuple of (thread ID, latest event ID) for each thread. + threads = self._get_threads(channel.json_body) + self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_pagination(self) -> None: """Create threads and paginate through them.""" # Create 2 threads. @@ -1757,7 +1770,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # Request the threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1771,7 +1784,7 @@ class ThreadsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1780,7 +1793,6 @@ class ThreadsTestCase(BaseRelationsTestCase): self.assertNotIn("next_batch", channel.json_body, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_include(self) -> None: """Filtering threads to all or participated in should work.""" # Thread 1 has the user as the root event. @@ -1807,7 +1819,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # All threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1819,14 +1831,13 @@ class ThreadsTestCase(BaseRelationsTestCase): # Only participated threads. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_ignored_user(self) -> None: """Events from ignored users should be ignored.""" # Thread 1 has a reply from an ignored user. @@ -1852,7 +1863,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # Only thread 1 is returned. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) -- cgit 1.5.1 From 126a15794c95002560709283640ad412636b29b8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 08:30:05 -0400 Subject: Do not allow a None-limit on PaginationConfig. (#14146) The callers either set a default limit or manually handle a None-limit later on (by setting a default value). Update the callers to always instantiate PaginationConfig with a default limit and then assume the limit is non-None. --- changelog.d/14146.removal | 1 + synapse/handlers/account_data.py | 2 +- synapse/handlers/initial_sync.py | 27 ++++----------------------- synapse/handlers/pagination.py | 5 ----- synapse/handlers/presence.py | 4 +++- synapse/handlers/receipts.py | 2 +- synapse/handlers/relations.py | 3 --- synapse/handlers/room.py | 2 +- synapse/handlers/typing.py | 2 +- synapse/rest/client/events.py | 4 +++- synapse/rest/client/initial_sync.py | 4 +++- synapse/rest/client/room.py | 4 +++- synapse/storage/databases/main/stream.py | 2 -- synapse/streams/__init__.py | 2 +- synapse/streams/config.py | 12 +++++------- tests/rest/client/test_typing.py | 3 ++- 16 files changed, 29 insertions(+), 50 deletions(-) create mode 100644 changelog.d/14146.removal (limited to 'tests') diff --git a/changelog.d/14146.removal b/changelog.d/14146.removal new file mode 100644 index 0000000000..08fa752897 --- /dev/null +++ b/changelog.d/14146.removal @@ -0,0 +1 @@ +Remove the unstable identifier for [MSC3715](https://github.com/matrix-org/matrix-doc/pull/3715). diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 0478448b47..fc21d58001 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 860c82c110..9c335e6863 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -57,13 +57,7 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, - Optional[StreamToken], - Optional[StreamToken], - str, - Optional[int], - bool, - bool, + str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() @@ -154,11 +148,6 @@ class InitialSyncHandler: public_room_ids = await self.store.get_public_room_ids() - if pagin_config.limit is not None: - limit = pagin_config.limit - else: - limit = 10 - serializer_options = SerializeEventConfig(as_client_event=as_client_event) async def handle_room(event: RoomsForUser) -> None: @@ -210,7 +199,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, event.room_id, - limit=limit, + limit=pagin_config.limit, end_token=room_end_token, ), deferred_room_state, @@ -360,15 +349,11 @@ class InitialSyncHandler: member_event_id ) - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - leave_position = await self.store.get_position_for_event(member_event_id) stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( - room_id, limit=limit, end_token=stream_token + room_id, limit=pagin_config.limit, end_token=stream_token ) messages = await filter_events_for_client( @@ -420,10 +405,6 @@ class InitialSyncHandler: now_token = self.hs.get_event_sources().get_current_token() - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - room_members = [ m for m in current_state.values() @@ -467,7 +448,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, room_id, - limit=limit, + limit=pagin_config.limit, end_token=now_token.room_key, ), ), diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1f83bab836..a4ca9cb8b4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -458,11 +458,6 @@ class PaginationHandler: # `/messages` should still works with live tokens when manually provided. assert from_token.room_key.topological is not None - if pagin_config.limit is None: - # This shouldn't happen as we've set a default limit before this - # gets called. - raise Exception("limit not set") - room_token = from_token.room_key async with self.pagination_lock.read(room_id): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4e575ffbaa..2670e561d7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1596,7 +1596,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self, user: UserID, from_key: Optional[int], - limit: Optional[int] = None, + # Having a default limit doesn't match the EventSource API, but some + # callers do not provide it. It is unused in this class. + limit: int = 0, room_ids: Optional[Collection[str]] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4a7ec9e426..ac01582442 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -257,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 1fdd7a10bc..0a0c6d938e 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -116,9 +116,6 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") - # TODO Update pagination config to not allow None limits. - assert pagin_config.limit is not None - # Note that ignored users are not passed into get_relations_for_event # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 57ab05ad25..4e1aacb408 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1646,7 +1646,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): self, user: UserID, from_key: RoomStreamToken, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index f953691669..a0ea719430 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 916f5230f1..782e7d14e8 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet): raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") - pagin_config = await PaginationConfig.from_request(self.store, request) + pagin_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in args: try: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index cfadcb8e50..9b1bb8b521 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) args: Dict[bytes, List[bytes]] = request.args # type: ignore as_client_event = b"raw" not in args - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index b6dedbed04..01e5079963 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -729,7 +729,9 @@ class RoomInitialSyncRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ffeb2b3683..5baffbfe55 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1200,8 +1200,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - assert int(limit) >= 0 - # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 806b671305..2dcd43d0a2 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -27,7 +27,7 @@ class EventSource(Generic[K, R]): self, user: UserID, from_key: K, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/streams/config.py b/synapse/streams/config.py index f6f7bf3d8b..6df2de919c 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -35,14 +35,14 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] direction: str - limit: Optional[int] + limit: int @classmethod async def from_request( cls, store: "DataStore", request: SynapseRequest, - default_limit: Optional[int] = None, + default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": direction = parse_string( @@ -69,12 +69,10 @@ class PaginationConfig: raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) + if limit < 0: + raise SynapseError(400, "Limit must be 0 or above") - if limit: - if limit < 0: - raise SynapseError(400, "Limit must be 0 or above") - - limit = min(int(limit), MAX_LIMIT) + limit = min(limit, MAX_LIMIT) try: return PaginationConfig(from_tok, to_tok, direction, limit) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 61b66d7685..fdc433a8b5 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -59,7 +59,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=UserID.from_string(self.user_id), from_key=0, - limit=None, + # Limit is unused. + limit=0, room_ids=[self.room_id], is_guest=False, ) -- cgit 1.5.1 From d1bdeccb50550ef454067aa01dd9d004c4704633 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 14:05:25 -0400 Subject: Accept threaded receipts for events related to the root event. (#14174) The root node of a thread (and events related to it) are considered "part of a thread" when validating receipts. This allows clients which show the root node in both the main timeline and the threaded timeline to easily send receipts in either. Note that threaded notifications are not created for these events, these events created notifications on the main timeline. --- changelog.d/14174.feature | 1 + synapse/rest/client/receipts.py | 44 ++++++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/relations.py | 98 ++++++++++++++++++++++-- tests/storage/test_relations.py | 111 ++++++++++++++++++++++++++++ 5 files changed, 247 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14174.feature create mode 100644 tests/storage/test_relations.py (limited to 'tests') diff --git a/changelog.d/14174.feature b/changelog.d/14174.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14174.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 14dec7ac4e..18a282b22c 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet): ) # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): + if not await self._is_event_in_thread(event_id, thread_id): raise SynapseError( 400, f"event_id {event_id} is not related to thread {thread_id}", @@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet): return 200, {} + async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: + """ + The event must be related to the thread ID (in a vague sense) to ensure + clients aren't sending bogus receipts. + + A thread ID is considered valid for a given event E if: + + 1. E has a thread relation which matches the thread ID; + 2. E has another event which has a thread relation to E matching the + thread ID; or + 3. E is recursively related (via any rel_type) to an event which + satisfies 1 or 2. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + It is valid to send a receipt for thread A on A, B, C, D, or E. + + It is valid to send a receipt for the main timeline on A, D, and E. + + Args: + event_id: The event ID to check. + thread_id: The thread ID the event is potentially part of. + + Returns: + True if the event belongs to the given thread, otherwise False. + """ + + # If the receipt is on the main timeline, it is enough to check whether + # the event is directly related to a thread. + if thread_id == MAIN_TIMELINE: + return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id) + + # Otherwise, check if the event is directly part of a thread, or is the + # root message (or related to the root message) of a thread. + return thread_id == await self._main_store.get_thread_id_for_receipts(event_id) + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index b47fc606c7..ed0be4abe5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -245,6 +245,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7c54ce0b2e..1de62ee9df 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -946,6 +946,20 @@ class RelationsWorkerStore(SQLBaseStore): Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. + + See also get_thread_id_for_receipts. + Args: event_id: The event ID to fetch the thread ID for. @@ -953,22 +967,32 @@ class RelationsWorkerStore(SQLBaseStore): The event ID of the root event in the thread, if this event is part of a thread. "main", otherwise. """ - # Since event relations form a tree, we should only ever find 0 or 1 - # results from the below query. + + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type + SELECT event_id, relates_to_id, relation_type, 0 depth FROM event_relations WHERE event_id = ? - UNION SELECT e.event_id, e.relates_to_id, e.relation_type + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id - ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) - # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] @@ -978,6 +1002,68 @@ class RelationsWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction( + "get_related_thread_id", _get_related_thread_id + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py new file mode 100644 index 0000000000..cd1d00208b --- /dev/null +++ b/tests/storage/test_relations.py @@ -0,0 +1,111 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import MAIN_TIMELINE +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest + + +class RelationsStoreTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + """ + Creates a DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + F <--[m.annotation]-- G + + """ + self._main_store = self.hs.get_datastores().main + + self._create_relation("A", "B", "m.thread") + self._create_relation("B", "C", "m.annotation") + self._create_relation("A", "D", "m.reference") + self._create_relation("D", "E", "m.annotation") + self._create_relation("F", "G", "m.annotation") + + def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None: + self.get_success( + self._main_store.db_pool.simple_insert( + table="event_relations", + values={ + "event_id": event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + }, + ) + ) + + def test_get_thread_id(self) -> None: + """ + Ensure that get_thread_id only searches up the tree for threads. + """ + # The thread itself and children of it return the thread. + thread_id = self.get_success(self._main_store.get_thread_id("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("C")) + self.assertEqual("A", thread_id) + + # But the root and events related to the root do not. + thread_id = self.get_success(self._main_store.get_thread_id("A")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("D")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("E")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + def test_get_thread_id_for_receipts(self) -> None: + """ + Ensure that get_thread_id_for_receipts searches up and down the tree for a thread. + """ + # All of the events are considered related to this thread. + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E")) + self.assertEqual("A", thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) -- cgit 1.5.1 From 40bb37eb27e1841754a297ac1277748de7f6c1cb Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Sat, 15 Oct 2022 00:36:49 -0500 Subject: Stop getting missing `prev_events` after we already know their signature is invalid (#13816) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit While https://github.com/matrix-org/synapse/pull/13635 stops us from doing the slow thing after we've already done it once, this PR stops us from doing one of the slow things in the first place. Related to - https://github.com/matrix-org/synapse/issues/13622 - https://github.com/matrix-org/synapse/pull/13635 - https://github.com/matrix-org/synapse/issues/13676 Part of https://github.com/matrix-org/synapse/issues/13356 Follow-up to https://github.com/matrix-org/synapse/pull/13815 which tracks event signature failures. With this PR, we avoid the call to the costly `_get_state_ids_after_missing_prev_event` because the signature failure will count as an attempt before and we filter events based on the backoff before calling `_get_state_ids_after_missing_prev_event` now. For example, this will save us 156s out of the 185s total that this `matrix.org` `/messages` request. If you want to see the full Jaeger trace of this, you can drag and drop this `trace.json` into your own Jaeger, https://gist.github.com/MadLittleMods/4b12d0d0afe88c2f65ffcc907306b761 To explain this exact scenario around `/messages` -> backfill, we call `/backfill` and first check the signatures of the 100 events. We see bad signature for `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` and `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` (both member events). Then we process the 98 events remaining that have valid signatures but one of the events references `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` as a `prev_event`. So we have to do the whole `_get_state_ids_after_missing_prev_event` rigmarole which pulls in those same events which fail again because the signatures are still invalid. - `backfill` - `outgoing-federation-request` `/backfill` - `_check_sigs_and_hash_and_fetch` - `_check_sigs_and_hash_and_fetch_one` for each event received over backfill - ❗ `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` fails with `Signature on retrieved event was invalid.`: `unable to verify signature for sender domain xxx: 401: Failed to find any key to satisfy: _FetchKeyRequest(...)` - ❗ `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` fails with `Signature on retrieved event was invalid.`: `unable to verify signature for sender domain xxx: 401: Failed to find any key to satisfy: _FetchKeyRequest(...)` - `_process_pulled_events` - `_process_pulled_event` for each validated event - ❗ Event `$Q0iMdqtz3IJYfZQU2Xk2WjB5NDF8Gg8cFSYYyKQgKJ0` references `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` as a `prev_event` which is missing so we try to get it - `_get_state_ids_after_missing_prev_event` - `outgoing-federation-request` `/state_ids` - ❗ `get_pdu` for `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` which fails the signature check again - ❗ `get_pdu` for `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` which fails the signature check --- changelog.d/13816.feature | 1 + synapse/api/errors.py | 21 +++ synapse/handlers/federation.py | 16 ++ synapse/handlers/federation_event.py | 31 ++++ synapse/storage/databases/main/event_federation.py | 54 ++++++ tests/handlers/test_federation_event.py | 201 ++++++++++++++++++++- tests/storage/test_event_federation.py | 64 +++++++ 7 files changed, 386 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13816.feature (limited to 'tests') diff --git a/changelog.d/13816.feature b/changelog.d/13816.feature new file mode 100644 index 0000000000..5eaa936b08 --- /dev/null +++ b/changelog.d/13816.feature @@ -0,0 +1 @@ +Stop fetching missing `prev_events` after we already know their signature is invalid. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c606207569..e0873b1913 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -640,6 +640,27 @@ class FederationError(RuntimeError): } +class FederationPullAttemptBackoffError(RuntimeError): + """ + Raised to indicate that we are are deliberately not attempting to pull the given + event over federation because we've already done so recently and are backing off. + + Attributes: + event_id: The event_id which we are refusing to pull + message: A custom error message that gives more context + """ + + def __init__(self, event_ids: List[str], message: Optional[str]): + self.event_ids = event_ids + + if message: + error_message = message + else: + error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)." + + super().__init__(error_message) + + class HttpResponseException(CodeMessageException): """ Represents an HTTP-level failure of an outbound request diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 44e70c6c3c..5f7e0a1f79 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, LimitExceededError, NotFoundError, @@ -1720,7 +1721,22 @@ class FederationHandler: destination, event ) break + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_sync_partial_state_room: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: + # TODO: We should `record_event_failed_pull_attempt` here, + # see https://github.com/matrix-org/synapse/issues/13700 + if attempt == len(destinations) - 1: # We have tried every remote server for this event. Give up. # TODO(faster_joins) giving up isn't the right thing to do diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f382961099..4300e8dd40 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -44,6 +44,7 @@ from synapse.api.errors import ( AuthError, Codes, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, RequestSendFailed, SynapseError, @@ -567,6 +568,9 @@ class FederationEventHandler: event: partial-state event to be de-partial-stated Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to request state from the remote server. """ logger.info("Updating state for %s", event.event_id) @@ -901,6 +905,18 @@ class FederationEventHandler: context, backfilled=backfilled, ) + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_process_pulled_event: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: await self._store.record_event_failed_pull_attempt( event.room_id, event_id, str(e) @@ -947,6 +963,9 @@ class FederationEventHandler: The event context. Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to get the state from the remote server after any missing `prev_event`s. """ @@ -957,6 +976,18 @@ class FederationEventHandler: seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen + # If we've already recently attempted to pull this missing event, don't + # try it again so soon. Since we have to fetch all of the prev_events, we can + # bail early here if we find any to ignore. + prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff( + room_id, missing_prevs + ) + if len(prevs_to_ignore) > 0: + raise FederationPullAttemptBackoffError( + event_ids=prevs_to_ignore, + message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).", + ) + if not missing_prevs: return await self._state_handler.compute_event_context(event) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6b9a629edd..309a4ba664 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id: The event that failed to be fetched or processed cause: The error message or reason that we failed to pull the event """ + logger.debug( + "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s", + room_id, + event_id, + cause, + ) await self.db_pool.runInteraction( "record_event_failed_pull_attempt", self._record_event_failed_pull_attempt_upsert_txn, @@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + @trace + async def get_event_ids_to_not_pull_from_backoff( + self, + room_id: str, + event_ids: Collection[str], + ) -> List[str]: + """ + Filter down the events to ones that we've failed to pull before recently. Uses + exponential backoff. + + Args: + room_id: The room that the events belong to + event_ids: A list of events to filter down + + Returns: + List of event_ids that should not be attempted to be pulled + """ + event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", + ) + + current_time = self._clock.time_msec() + return [ + event_failed_pull_attempt["event_id"] + for event_failed_pull_attempt in event_failed_pull_attempts + # Exponential back-off (up to the upper bound) so we don't try to + # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + if current_time + < event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) + ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS + ] + async def get_missing_events( self, room_id: str, diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 918010cddb..e448cb1901 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -14,7 +14,7 @@ from typing import Optional from unittest import mock -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, StoreError from synapse.api.room_versions import RoomVersion from synapse.event_auth import ( check_state_dependent_auth_rules, @@ -43,7 +43,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): def make_homeserver(self, reactor, clock): # mock out the federation transport client self.mock_federation_transport_client = mock.Mock( - spec=["get_room_state_ids", "get_room_state", "get_event"] + spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] ) return super().setup_test_homeserver( federation_transport_client=self.mock_federation_transport_client @@ -459,6 +459,203 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) self.assertIsNotNone(persisted, "pulled event was not persisted at all") + def test_backfill_signature_failure_does_not_fetch_same_prev_event_later( + self, + ) -> None: + """ + Test to make sure we backoff and don't try to fetch a missing prev_event when we + already know it has a invalid signature from checking the signatures of all of + the events in the backfill response. + """ + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # Allow the remote user to send state events + self.helper.send_state( + room_id, + "m.room.power_levels", + {"events_default": 0, "state_default": 0}, + tok=tok, + ) + + # Add the remote user to the room + member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + + auth_event_ids = [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + member_event.event_id, + ] + + # We purposely don't run `add_hashes_and_signatures_from_other_server` + # over this because we want the signature check to fail. + pulled_event_without_signatures = make_event_from_dict( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [member_event.event_id], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event_without_signatures"}, + }, + room_version, + ) + + # Create a regular event that should pass except for the + # `pulled_event_without_signatures` in the `prev_event`. + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [ + member_event.event_id, + pulled_event_without_signatures.event_id, + ], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event"}, + } + ), + room_version, + ) + + # We expect an outbound request to /backfill, so stub that out + self.mock_federation_transport_client.backfill.return_value = make_awaitable( + { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + # This is one of the important aspects of this test: we include + # `pulled_event_without_signatures` so it fails the signature check + # when we filter down the backfill response down to events which + # have valid signatures in + # `_check_sigs_and_hash_for_pulled_events_and_fetch` + pulled_event_without_signatures.get_pdu_json(), + # Then later when we process this valid signature event, when we + # fetch the missing `prev_event`s, we want to make sure that we + # backoff and don't try and fetch `pulled_event_without_signatures` + # again since we know it just had an invalid signature. + pulled_event.get_pdu_json(), + ], + } + ) + + # Keep track of the count and make sure we don't make any of these requests + event_endpoint_requested_count = 0 + room_state_ids_endpoint_requested_count = 0 + room_state_endpoint_requested_count = 0 + + async def get_event( + destination: str, event_id: str, timeout: Optional[int] = None + ) -> None: + nonlocal event_endpoint_requested_count + event_endpoint_requested_count += 1 + + async def get_room_state_ids( + destination: str, room_id: str, event_id: str + ) -> None: + nonlocal room_state_ids_endpoint_requested_count + room_state_ids_endpoint_requested_count += 1 + + async def get_room_state( + room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> None: + nonlocal room_state_endpoint_requested_count + room_state_endpoint_requested_count += 1 + + # We don't expect an outbound request to `/event`, `/state_ids`, or `/state` in + # the happy path but if the logic is sneaking around what we expect, stub that + # out so we can detect that failure + self.mock_federation_transport_client.get_event.side_effect = get_event + self.mock_federation_transport_client.get_room_state_ids.side_effect = ( + get_room_state_ids + ) + self.mock_federation_transport_client.get_room_state.side_effect = ( + get_room_state + ) + + # The function under test: try to backfill and process the pulled event + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler().backfill( + self.OTHER_SERVER_NAME, + room_id, + limit=1, + extremities=["$some_extremity"], + ) + ) + + if event_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /event in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + if room_state_ids_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /state_ids in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + if room_state_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /state in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + # Make sure we only recorded a single failure which corresponds to the signature + # failure initially in `_check_sigs_and_hash_for_pulled_events_and_fetch` before + # we process all of the pulled events. + backfill_num_attempts_for_event_without_signatures = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event_without_signatures.event_id}, + retcol="num_attempts", + ) + ) + self.assertEqual(backfill_num_attempts_for_event_without_signatures, 1) + + # And make sure we didn't record a failure for the event that has the missing + # prev_event because we don't want to cause a cascade of failures. Not being + # able to fetch the `prev_events` just means we won't be able to de-outlier the + # pulled event. But we can still use an `outlier` in the state/auth chain for + # another event. So we shouldn't stop a downstream event from trying to pull it. + self.get_failure( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + ), + # StoreError: 404: No row found + StoreError, + ) + def test_process_pulled_event_with_rejected_missing_state(self) -> None: """Ensure that we correctly handle pulled events with missing state containing a rejected state event diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 59b8910907..853db930d6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -27,6 +27,8 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.events import _EventInternalMetadata +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict @@ -43,6 +45,12 @@ class _BackfillSetupInfo: class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main @@ -1122,6 +1130,62 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertEqual(backfill_event_ids, ["insertion_eventA"]) + def test_get_event_ids_to_not_pull_from_backoff( + self, + ): + """ + Test to make sure only event IDs we should backoff from are returned. + """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "$failed_event_id", "fake cause" + ) + ) + + event_ids_to_backoff = self.get_success( + self.store.get_event_ids_to_not_pull_from_backoff( + room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] + ) + ) + + self.assertEqual(event_ids_to_backoff, ["$failed_event_id"]) + + def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( + self, + ): + """ + Test to make sure no event IDs are returned after the backoff duration has + elapsed. + """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "$failed_event_id", "fake cause" + ) + ) + + # Now advance time by 2 hours so we wait long enough for the single failed + # attempt (2^1 hours). + self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) + + event_ids_to_backoff = self.get_success( + self.store.get_event_ids_to_not_pull_from_backoff( + room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] + ) + ) + # Since this function only returns events we should backoff from, time has + # elapsed past the backoff range so there is no events to backoff from. + self.assertEqual(event_ids_to_backoff, []) + @attr.s class FakeEvent: -- cgit 1.5.1 From 4283bd1cf9c3da2157c3642a7c4f105e9fac2636 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Oct 2022 11:32:11 -0400 Subject: Support filtering the /messages API by relation type (MSC3874). (#14148) Gated behind an experimental configuration flag. --- changelog.d/14148.feature | 1 + synapse/api/filtering.py | 27 +++++- synapse/config/experimental.py | 3 + synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/stream.py | 29 ++++++- tests/api/test_filtering.py | 63 +++++++++++++- tests/rest/client/test_relations.py | 1 - tests/rest/client/test_rooms.py | 145 ++----------------------------- tests/storage/test_stream.py | 118 ++++++++++++++++++------- 9 files changed, 212 insertions(+), 177 deletions(-) create mode 100644 changelog.d/14148.feature (limited to 'tests') diff --git a/changelog.d/14148.feature b/changelog.d/14148.feature new file mode 100644 index 0000000000..951d0cac80 --- /dev/null +++ b/changelog.d/14148.feature @@ -0,0 +1 @@ +Experimental support for [MSC3874](https://github.com/matrix-org/matrix-spec-proposals/pull/3874). diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cc31cf8cc7..26be377d03 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -36,7 +36,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -53,6 +53,12 @@ FILTER_SCHEMA = { # check types are valid event types "types": {"type": "array", "items": {"type": "string"}}, "not_types": {"type": "array", "items": {"type": "string"}}, + # MSC3874, filtering /messages. + "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc3874.not_rel_types": { + "type": "array", + "items": {"type": "string"}, + }, }, } @@ -334,8 +340,15 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - self.related_by_senders = self.filter_json.get("related_by_senders", None) - self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + self.related_by_senders = filter_json.get("related_by_senders", None) + self.related_by_rel_types = filter_json.get("related_by_rel_types", None) + + # For compatibility with _check_fields. + self.rel_types = None + self.not_rel_types = [] + if hs.config.experimental.msc3874_enabled: + self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None) + self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -386,11 +399,19 @@ class Filter: # check if there is a string url field in the content for filtering purposes labels = content.get(EventContentFields.LABELS, []) + # Check if the event has a relation. + rel_type = None + if isinstance(event, EventBase): + relation = relation_from_event(event) + if relation: + rel_type = relation.rel_type + field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, + "rel_types": lambda v: rel_type == v, } result = self._check_fields(field_matchers) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..f9a49451d8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -117,3 +117,6 @@ class ExperimentalConfig(Config): self.msc3882_token_timeout = self.parse_duration( experimental.get("msc3882_token_timeout", "5m") ) + + # MSC3874: Filtering /messages with rel_types / not_rel_types. + self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4e1fd2bbe7..4b87ee978a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -114,6 +114,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, + # Adds support for filtering /messages by event relation. + "org.matrix.msc3874": self.config.experimental.msc3874_enabled, }, }, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 5baffbfe55..09ce855aa8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1278,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1294,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index a269c477fb..a82c4eed86 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -35,6 +35,8 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" + if "content" not in kwargs: + kwargs["content"] = {} return make_event_from_dict(kwargs) @@ -357,6 +359,66 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_rel_type(self): + definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_not_rel_type(self): + definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} filter_id = self.get_success( @@ -456,7 +518,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): events = [ # An event without a relation. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f5c1070b2c..ddf315b894 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1677,7 +1677,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ Test that thread replies are still available when the root event is redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 3612ebe7b9..71b1637be8 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -35,7 +35,6 @@ from synapse.api.constants import ( EventTypes, Membership, PublicRoomsFilterFields, - RelationTypes, RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException @@ -50,6 +49,7 @@ from synapse.util.stringutils import random_string from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -2915,149 +2915,20 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id -class RelationsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - self.helper.join( - room=self.room_id, user=self.second_user_id, tok=self.second_tok - ) - - self.third_user_id = self.register_user("third", "test") - self.third_tok = self.login("third", "test") - self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) - - # An initial event with a relation from second user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 1"}, - tok=self.tok, - ) - self.event_id_1 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.ANNOTATION, - "event_id": self.event_id_1, - "key": "👍", - } - }, - tok=self.second_tok, - ) - - # Another event with a relation from third user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 2"}, - tok=self.tok, - ) - self.event_id_2 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.REFERENCE, - "event_id": self.event_id_2, - } - }, - tok=self.third_tok, - ) - - # An event with no relations. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "No relations"}, - tok=self.tok, - ) - - def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: +class RelationsTestCase(PaginationTestCase): + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" + from_token = self.get_success( + self.from_token.to_string(self.hs.get_datastores().main) + ) channel = self.make_request( "GET", - "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}", access_token=self.tok, ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - return channel.json_body["chunk"] - - def test_filter_relation_senders(self) -> None: - # Messages which second user reacted to. - filter = {"related_by_senders": [self.second_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which third user reacted to. - filter = {"related_by_senders": [self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which either user reacted to. - filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_type(self) -> None: - # Messages which have annotations. - filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which have references. - filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which have either annotations or references. - filter = { - "related_by_rel_types": [ - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - ] - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_senders_and_type(self) -> None: - # Messages which second user reacted to. - filter = { - "related_by_senders": [self.second_user_id], - "related_by_rel_types": [RelationTypes.ANNOTATION], - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) + return [ev["event_id"] for ev in channel.json_body["chunk"]] class ContextTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 78663a53fe..34fa810cf6 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -16,7 +16,6 @@ from typing import List from synapse.api.constants import EventTypes, RelationTypes from synapse.api.filtering import Filter -from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.types import JsonDict @@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase): def default_config(self): config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} + config["experimental_features"] = {"msc3874_enabled": True} return config def prepare(self, reactor, clock, homeserver): @@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase): self.third_tok = self.login("third", "test") self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + # Store a token which is after all the room creation events. + self.from_token = self.get_success( + self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) + ) + # An initial event with a relation from second user. res = self.helper.send_event( room_id=self.room_id, @@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_1 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.second_tok, ) + self.event_id_annotation = res["event_id"] # Another event with a relation from third user. res = self.helper.send_event( @@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_2 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.third_tok, ) + self.event_id_reference = res["event_id"] # An event with no relations. - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, content={"msgtype": "m.text", "body": "No relations"}, tok=self.tok, ) + self.event_id_none = res["event_id"] - def _filter_messages(self, filter: JsonDict) -> List[EventBase]: + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" - from_token = self.get_success( - self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) - ) - events, next_key = self.get_success( self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, - from_key=from_token.room_key, + from_key=self.from_token.room_key, to_key=None, - direction="b", + direction="f", limit=10, event_filter=Filter(self.hs, filter), ) ) - return events + return [ev.event_id for ev in events] def test_filter_relation_senders(self): # Messages which second user reacted to. filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which third user reacted to. filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which either user reacted to. filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_type(self): # Messages which have annotations. filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which have references. filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which have either annotations or references. filter = { @@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase): ] } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. @@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase): "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) def test_duplicate_relation(self): """An event should only be returned once if there are multiple relations to it.""" @@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase): filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) + + def test_filter_rel_types(self) -> None: + # Messages which are annotations. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_annotation]) + + # Messages which are references. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_reference]) + + # Messages which are either annotations or references. + filter = { + "org.matrix.msc3874.rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertCountEqual( + chunk, + [self.event_id_annotation, self.event_id_reference], + ) + + def test_filter_not_rel_types(self) -> None: + # Messages which are not annotations. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_2, + self.event_id_reference, + self.event_id_none, + ], + ) + + # Messages which are not references. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_annotation, + self.event_id_2, + self.event_id_none, + ], + ) + + # Messages which are neither annotations or references. + filter = { + "org.matrix.msc3874.not_rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none]) -- cgit 1.5.1 From 828b5502cfdf4f1b20750941714ce95cdb242f0d Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 Oct 2022 10:33:21 +0100 Subject: Remove `_get_events_cache` check optimisation from `_have_seen_events_dict` (#14161) --- changelog.d/14161.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 31 +++++++++------------- tests/storage/databases/main/test_events_worker.py | 12 --------- 3 files changed, 14 insertions(+), 30 deletions(-) create mode 100644 changelog.d/14161.bugfix (limited to 'tests') diff --git a/changelog.d/14161.bugfix b/changelog.d/14161.bugfix new file mode 100644 index 0000000000..aed4d9e386 --- /dev/null +++ b/changelog.d/14161.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.30.0 where purging and rejoining a room without restarting in-between would result in a broken room. \ No newline at end of file diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d4104462b5..cfd4780add 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1502,21 +1502,15 @@ class EventsWorkerStore(SQLBaseStore): Returns: a dict {event_id -> bool} """ - # if the event cache contains the event, obviously we've seen it. - - cache_results = { - event_id - for event_id in event_ids - if await self._get_event_cache.contains((event_id,)) - } - results = dict.fromkeys(cache_results, True) - remaining = [ - event_id for event_id in event_ids if event_id not in cache_results - ] - if not remaining: - return results + # TODO: We used to query the _get_event_cache here as a fast-path before + # hitting the database. For if an event were in the cache, we've presumably + # seen it before. + # + # But this is currently an invalid assumption due to the _get_event_cache + # not being invalidated when purging events from a room. The optimisation can + # be re-added after https://github.com/matrix-org/synapse/issues/13476 - def have_seen_events_txn(txn: LoggingTransaction) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1524,16 +1518,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", remaining + txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update({eid: (eid in found_events) for eid in remaining}) + return {eid: (eid in found_events) for eid in event_ids} - await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) - return results + return await self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn + ) @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 32a798d74b..5773172ab8 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -90,18 +90,6 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) - def test_query_via_event_cache(self): - # fetch an event into the event cache - self.get_success(self.store.get_event(self.event_ids[0])) - - # looking it up should now cause no db hits - with LoggingContext(name="test") as ctx: - res = self.get_success( - self.store.have_seen_events(self.room_id, [self.event_ids[0]]) - ) - self.assertEqual(res, {self.event_ids[0]}) - self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) - def test_persisting_event_invalidates_cache(self): """ Test to make sure that the `have_seen_event` cache -- cgit 1.5.1 From dc02d9f8c54576d4b41ce51a2704fdd43b582d66 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 Oct 2022 10:33:35 +0100 Subject: Avoid checking the event cache when backfilling events (#14164) --- changelog.d/14164.bugfix | 1 + synapse/handlers/federation_event.py | 47 ++++++++--- synapse/storage/databases/main/events_worker.py | 2 +- tests/handlers/test_federation.py | 105 +++++++++++++++++++++++- 4 files changed, 140 insertions(+), 15 deletions(-) create mode 100644 changelog.d/14164.bugfix (limited to 'tests') diff --git a/changelog.d/14164.bugfix b/changelog.d/14164.bugfix new file mode 100644 index 0000000000..aed4d9e386 --- /dev/null +++ b/changelog.d/14164.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.30.0 where purging and rejoining a room without restarting in-between would result in a broken room. \ No newline at end of file diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 4300e8dd40..06e41b5cc0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -798,9 +798,42 @@ class FederationEventHandler: ], ) + # Check if we already any of these have these events. + # Note: we currently make a lookup in the database directly here rather than + # checking the event cache, due to: + # https://github.com/matrix-org/synapse/issues/13476 + existing_events_map = await self._store._get_events_from_db( + [event.event_id for event in events] + ) + + new_events = [] + for event in events: + event_id = event.event_id + + # If we've already seen this event ID... + if event_id in existing_events_map: + existing_event = existing_events_map[event_id] + + # ...and the event itself was not previously stored as an outlier... + if not existing_event.event.internal_metadata.is_outlier(): + # ...then there's no need to persist it. We have it already. + logger.info( + "_process_pulled_event: Ignoring received event %s which we " + "have already seen", + event.event_id, + ) + continue + + # While we have seen this event before, it was stored as an outlier. + # We'll now persist it as a non-outlier. + logger.info("De-outliering event %s", event_id) + + # Continue on with the events that are new to us. + new_events.append(event) + # We want to sort these by depth so we process them and # tell clients about them in order. - sorted_events = sorted(events, key=lambda x: x.depth) + sorted_events = sorted(new_events, key=lambda x: x.depth) for ev in sorted_events: with nested_logging_context(ev.event_id): await self._process_pulled_event(origin, ev, backfilled=backfilled) @@ -852,18 +885,6 @@ class FederationEventHandler: event_id = event.event_id - existing = await self._store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "_process_pulled_event: Ignoring received event %s which we have already seen", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - try: self._sanity_check_event(event) except SynapseError as err: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cfd4780add..7bc7f2f33e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -374,7 +374,7 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - The event, or None if the event was not found. + The event, or None if the event was not found and allow_none is `True`. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 745750b1d7..d00c69c229 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -19,7 +19,13 @@ from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + LimitExceededError, + NotFoundError, + SynapseError, +) from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.federation.federation_base import event_from_pdu_json @@ -28,6 +34,7 @@ from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer +from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.util import Clock from synapse.util.stringutils import random_string @@ -322,6 +329,102 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) self.get_success(d) + def test_backfill_ignores_known_events(self) -> None: + """ + Tests that events that we already know about are ignored when backfilling. + """ + # Set up users + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # Create a room to backfill events into + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + # Build an event to backfill + event = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"body": "hello world", "msgtype": "m.text"}, + "room_id": room_id, + "sender": other_user, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + # Ensure the event is not already in the DB + self.get_failure( + self.store.get_event(event.event_id), + NotFoundError, + ) + + # Backfill the event and check that it has entered the DB. + + # We mock out the FederationClient.backfill method, to pretend that a remote + # server has returned our fake event. + federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) + self.hs.get_federation_client().backfill = federation_client_backfill_mock + + # We also mock the persist method with a side effect of itself. This allows us + # to track when it has been called while preserving its function. + persist_events_and_notify_mock = Mock( + side_effect=self.hs.get_federation_event_handler().persist_events_and_notify + ) + self.hs.get_federation_event_handler().persist_events_and_notify = ( + persist_events_and_notify_mock + ) + + # Small side-tangent. We populate the event cache with the event, even though + # it is not yet in the DB. This is an invalid scenario that can currently occur + # due to not properly invalidating the event cache. + # See https://github.com/matrix-org/synapse/issues/13476. + # + # As a result, backfill should not rely on the event cache to check whether + # we already have an event in the DB. + # TODO: Remove this bit when the event cache is properly invalidated. + cache_entry = EventCacheEntry( + event=event, + redacted_event=None, + ) + self.store._get_event_cache.set_local((event.event_id,), cache_entry) + + # We now call FederationEventHandler.backfill (a separate method) to trigger + # a backfill request. It should receive the fake event. + self.get_success( + self.hs.get_federation_event_handler().backfill( + other_user, + room_id, + limit=10, + extremities=[], + ) + ) + + # Check that our fake event was persisted. + persist_events_and_notify_mock.assert_called_once() + persist_events_and_notify_mock.reset_mock() + + # Now we repeat the backfill, having the homeserver receive the fake event + # again. + self.get_success( + self.hs.get_federation_event_handler().backfill( + other_user, + room_id, + limit=10, + extremities=[], + ), + ) + + # This time, we expect no event persistence to have occurred, as we already + # have this event. + persist_events_and_notify_mock.assert_not_called() + @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) -- cgit 1.5.1 From 4eaf3eb840b8cfa78d970216c74fc128495f08a5 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Tue, 18 Oct 2022 16:52:25 +0100 Subject: Implementation of HTTP 307 response for MSC3886 POST endpoint (#14018) Co-authored-by: reivilibre Co-authored-by: Andrew Morgan --- changelog.d/14018.feature | 1 + synapse/config/experimental.py | 7 +- synapse/config/server.py | 4 ++ synapse/handlers/sso.py | 2 +- synapse/http/server.py | 48 ++++++++++--- synapse/http/site.py | 3 + synapse/rest/__init__.py | 2 + synapse/rest/client/rendezvous.py | 74 +++++++++++++++++++ synapse/rest/client/versions.py | 3 + synapse/rest/key/v2/local_key_resource.py | 4 +- synapse/rest/synapse/client/new_user_consent.py | 3 +- synapse/rest/well_known.py | 3 +- tests/logging/test_terse_json.py | 1 + tests/rest/client/test_rendezvous.py | 45 ++++++++++++ tests/server.py | 8 ++- tests/test_server.py | 94 ++++++++++++++++++------- 16 files changed, 257 insertions(+), 45 deletions(-) create mode 100644 changelog.d/14018.feature create mode 100644 synapse/rest/client/rendezvous.py create mode 100644 tests/rest/client/test_rendezvous.py (limited to 'tests') diff --git a/changelog.d/14018.feature b/changelog.d/14018.feature new file mode 100644 index 0000000000..c8454607eb --- /dev/null +++ b/changelog.d/14018.feature @@ -0,0 +1 @@ +Support for redirecting to an implementation of a [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) HTTP rendezvous service. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f9a49451d8..4009add01d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import attr @@ -120,3 +120,8 @@ class ExperimentalConfig(Config): # MSC3874: Filtering /messages with rel_types / not_rel_types. self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) + + # MSC3886: Simple client rendezvous capability + self.msc3886_endpoint: Optional[str] = experimental.get( + "msc3886_endpoint", None + ) diff --git a/synapse/config/server.py b/synapse/config/server.py index f2353ce5fb..ec46ca63ad 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -207,6 +207,9 @@ class HttpListenerConfig: additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None request_id_header: Optional[str] = None + # If true, the listener will return CORS response headers compatible with MSC3886: + # https://github.com/matrix-org/matrix-spec-proposals/pull/3886 + experimental_cors_msc3886: bool = False @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -935,6 +938,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), request_id_header=listener.get("request_id_header"), + experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False), ) return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e035677b8a..5943f08e91 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -874,7 +874,7 @@ class SsoHandler: ) async def handle_terms_accepted( - self, request: Request, session_id: str, terms_version: str + self, request: SynapseRequest, session_id: str, terms_version: str ) -> None: """Handle a request to the new-user 'consent' endpoint diff --git a/synapse/http/server.py b/synapse/http/server.py index bcbfac2c9f..b26e34bceb 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -19,6 +19,7 @@ import logging import types import urllib from http import HTTPStatus +from http.client import FOUND from inspect import isawaitable from typing import ( TYPE_CHECKING, @@ -339,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - _unrecognised_request_handler(request) + return _unrecognised_request_handler(request) @abc.abstractmethod def _send_response( @@ -598,7 +599,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: Request) -> bytes: + def render_OPTIONS(self, request: SynapseRequest) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -763,7 +764,7 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, + request: SynapseRequest, code: int, json_bytes: bytes, send_cors: bool = False, @@ -859,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request) -> None: +def set_cors_headers(request: SynapseRequest) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -870,10 +871,20 @@ def set_cors_headers(request: Request) -> None: request.setHeader( b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date", - ) + if request.experimental_cors_msc3886: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", + ) + request.setHeader( + b"Access-Control-Expose-Headers", + b"ETag, Location, X-Max-Bytes", + ) + else: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date", + ) def set_corp_headers(request: Request) -> None: @@ -942,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None: request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") -def respond_with_redirect(request: Request, url: bytes) -> None: - """Write a 302 response to the request, if it is still alive.""" +def respond_with_redirect( + request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False +) -> None: + """ + Write a 302 (or other specified status code) response to the request, if it is still alive. + + Args: + request: The http request to respond to. + url: The URL to redirect to. + statusCode: The HTTP status code to use for the redirect (defaults to 302). + cors: Whether to set CORS headers on the response. + """ logger.debug("Redirect to %s", url.decode("utf-8")) - request.redirect(url) + + if cors: + set_cors_headers(request) + + request.setResponseCode(statusCode) + request.setHeader(b"location", url) finish_request(request) diff --git a/synapse/http/site.py b/synapse/http/site.py index 55a6afce35..3dbd541fed 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -82,6 +82,7 @@ class SynapseRequest(Request): self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 + self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -622,6 +623,8 @@ class SynapseSite(Site): request_id_header = config.http_options.request_id_header + self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 9a2ab99ede..28542cd774 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client import ( receipts, register, relations, + rendezvous, report_event, room, room_batch, @@ -132,3 +133,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) login_token_request.register_servlets(hs, client_resource) + rendezvous.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py new file mode 100644 index 0000000000..89176b1ffa --- /dev/null +++ b/synapse/rest/client/rendezvous.py @@ -0,0 +1,74 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from http.client import TEMPORARY_REDIRECT +from typing import TYPE_CHECKING, Optional + +from synapse.http.server import HttpServer, respond_with_redirect +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RendezvousServlet(RestServlet): + """ + This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) + simple client rendezvous capability that is used by the "Sign in with QR" functionality. + + This implementation only serves as a 307 redirect to a configured server rather than being a full implementation. + + A module that implements the full functionality is available at: https://pypi.org/project/matrix-http-rendezvous-synapse/. + + Request: + + POST /rendezvous HTTP/1.1 + Content-Type: ... + + ... + + Response: + + HTTP/1.1 307 + Location: + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3886/rendezvous$", releases=[], v1=False, unstable=True + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + redirection_target: Optional[str] = hs.config.experimental.msc3886_endpoint + assert ( + redirection_target is not None + ), "Servlet is only registered if there is a redirection target" + self.endpoint = redirection_target.encode("utf-8") + + async def on_POST(self, request: SynapseRequest) -> None: + respond_with_redirect( + request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True + ) + + # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target. + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3886_endpoint is not None: + RendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4b87ee978a..9b1b72c68a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -116,6 +116,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3881": self.config.experimental.msc3881_enabled, # Adds support for filtering /messages by event relation. "org.matrix.msc3874": self.config.experimental.msc3874_enabled, + # Adds support for simple HTTP rendezvous as per MSC3886 + "org.matrix.msc3886": self.config.experimental.msc3886_endpoint + is not None, }, }, ) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 0c9f042c84..095993415c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -20,9 +20,9 @@ from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 from twisted.web.resource import Resource -from twisted.web.server import Request from synapse.http.server import respond_with_json_bytes +from synapse.http.site import SynapseRequest from synapse.types import JsonDict if TYPE_CHECKING: @@ -99,7 +99,7 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> Optional[int]: + def render_GET(self, request: SynapseRequest) -> Optional[int]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 1c1c7b3613..22784157e6 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest from synapse.types import UserID from synapse.util.templates import build_jinja_env @@ -88,7 +89,7 @@ class NewUserConsentResource(DirectServeHtmlResource): html = template.render(template_params) respond_with_html(request, 200, html) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 6f7ac54c65..e2174fdfea 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -18,6 +18,7 @@ from twisted.web.resource import Resource from twisted.web.server import Request from synapse.http.server import set_cors_headers +from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.stringutils import parse_server_name @@ -63,7 +64,7 @@ class ClientWellKnownResource(Resource): Resource.__init__(self) self._well_known_builder = WellKnownBuilder(hs) - def render_GET(self, request: Request) -> bytes: + def render_GET(self, request: SynapseRequest) -> bytes: set_cors_headers(request) r = self._well_known_builder.get_well_known() if not r: diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 96f399b7ab..0b0d8737c1 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -153,6 +153,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.site_tag = "test-site" site.server_version_string = "Server v1" site.reactor = Mock() + site.experimental_cors_msc3886 = False request = SynapseRequest(FakeChannel(site, None), site) # Call requestReceived to finish instantiating the object. request.content = BytesIO() diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py new file mode 100644 index 0000000000..ad00a476e1 --- /dev/null +++ b/tests/rest/client/test_rendezvous.py @@ -0,0 +1,45 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.client import rendezvous +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + +endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" + + +class RendezvousServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + rendezvous.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + return self.hs + + def test_disabled(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) + def test_redirect(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 307) + self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"]) diff --git a/tests/server.py b/tests/server.py index c447d5e4c4..8b1d186219 100644 --- a/tests/server.py +++ b/tests/server.py @@ -266,7 +266,12 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") - def __init__(self, resource: IResource, reactor: IReactorTime): + def __init__( + self, + resource: IResource, + reactor: IReactorTime, + experimental_cors_msc3886: bool = False, + ): """ Args: @@ -274,6 +279,7 @@ class FakeSite: """ self._resource = resource self.reactor = reactor + self.experimental_cors_msc3886 = experimental_cors_msc3886 def getResourceFor(self, request): return self._resource diff --git a/tests/test_server.py b/tests/test_server.py index 7c66448245..2d9a0257d4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -222,13 +222,22 @@ class OptionsResourceTests(unittest.TestCase): self.resource = OptionsResource() self.resource.putChild(b"res", DummyResource()) - def _make_request(self, method: bytes, path: bytes) -> FakeChannel: + def _make_request( + self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False + ) -> FakeChannel: """Create a request from the method/path and return a channel with the response.""" # Create a site and query for the resource. site = SynapseSite( "test", "site_tag", - parse_listener_def(0, {"type": "http", "port": 0}), + parse_listener_def( + 0, + { + "type": "http", + "port": 0, + "experimental_cors_msc3886": experimental_cors_msc3886, + }, + ), self.resource, "1.0", max_request_body_size=4096, @@ -239,25 +248,58 @@ class OptionsResourceTests(unittest.TestCase): channel = make_request(self.reactor, site, method, path, shorthand=False) return channel + def _check_cors_standard_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://spec.matrix.org/v1.4/client-server-api/#web-browser-clients + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [b"X-Requested-With, Content-Type, Authorization, Date"], + "has correct CORS Headers header", + ) + + def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://github.com/matrix-org/matrix-spec-proposals/blob/hughns/simple-rendezvous-capability/proposals/3886-simple-rendezvous-capability.md#cors + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [ + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match" + ], + "has correct CORS Headers header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"), + [b"ETag, Location, X-Max-Bytes"], + "has correct CORS Expose Headers header", + ) + def test_unknown_options_request(self) -> None: """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", - ) + self._check_cors_standard_headers(channel) def test_known_options_request(self) -> None: """An OPTIONS requests to an known URL still returns 204 No Content.""" @@ -265,19 +307,17 @@ class OptionsResourceTests(unittest.TestCase): self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", + self._check_cors_standard_headers(channel) + + def test_known_options_request_msc3886(self) -> None: + """An OPTIONS requests to an known URL still returns 204 No Content.""" + channel = self._make_request( + b"OPTIONS", b"/res/", experimental_cors_msc3886=True ) + self.assertEqual(channel.code, 204) + self.assertNotIn("body", channel.result) + + self._check_cors_msc3886_headers(channel) def test_unknown_request(self) -> None: """A non-OPTIONS request to an unknown URL should 404.""" -- cgit 1.5.1 From fa8616e65c82367712a7b75c62682a89541b6330 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 18 Oct 2022 19:46:25 -0500 Subject: Fix MSC3030 `/timestamp_to_event` returning `outliers` that it has no idea whether are near a gap or not (#14215) Fix MSC3030 `/timestamp_to_event` endpoint returning `outliers` that it has no idea whether are near a gap or not (and therefore unable to determine whether it's actually the closest event). The reason Synapse doesn't know whether an `outlier` is next to a gap is because our gap checks rely on entries in the `event_edges`, `event_forward_extremeties`, and `event_backward_extremities` tables which is [not the case for `outliers`](https://github.com/matrix-org/synapse/blob/2c63cdcc3f1aa4625e947de3c23e0a8133c61286/docs/development/room-dag-concepts.md#outliers). Also fixes MSC3030 Complement `can_paginate_after_getting_remote_event_from_timestamp_to_event_endpoint` test flake. Although this acted flakey in Complement, if `sync_partial_state` raced and beat us before `/timestamp_to_event`, then even if we retried the failing `/context` request it wouldn't work until we made this Synapse change. With this PR, Synapse will never return an `outlier` event so that test will always go and ask over federation. Fix https://github.com/matrix-org/synapse/issues/13944 ### Why did this fail before? Why was it flakey? Sleuthing the server logs on the [CI failure](https://github.com/matrix-org/synapse/actions/runs/3149623842/jobs/5121449357#step:5:5805), it looks like `hs2:/timestamp_to_event` found `$NP6-oU7mIFVyhtKfGvfrEQX949hQX-T-gvuauG6eurU` as an `outlier` event locally. Then when we went and asked for it via `/context`, since it's an `outlier`, it was filtered out of the results -> `You don't have permission to access that event.` This is reproducible when `sync_partial_state` races and persists `$NP6-oU7mIFVyhtKfGvfrEQX949hQX-T-gvuauG6eurU` as an `outlier` before we evaluate `get_event_for_timestamp(...)`. To consistently reproduce locally, just add a delay at the [start of `get_event_for_timestamp(...)`](https://github.com/matrix-org/synapse/blob/cb20b885cb4bd1648581dd043a184d86fc8c7a00/synapse/handlers/room.py#L1470-L1496) so it always runs after `sync_partial_state` completes. ```py from twisted.internet import task as twisted_task d = twisted_task.deferLater(self.hs.get_reactor(), 3.5) await d ``` In a run where it passes, on `hs2`, `get_event_for_timestamp(...)` finds a different event locally which is next to a gap and we request from a closer one from `hs1` which gets backfilled. And since the backfilled event is not an `outlier`, it's returned as expected during `/context`. With this PR, Synapse will never return an `outlier` event so that test will always go and ask over federation. --- changelog.d/14215.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 59 ++++++++++++++-------- tests/rest/client/test_rooms.py | 65 +++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 21 deletions(-) create mode 100644 changelog.d/14215.bugfix (limited to 'tests') diff --git a/changelog.d/14215.bugfix b/changelog.d/14215.bugfix new file mode 100644 index 0000000000..31c109f534 --- /dev/null +++ b/changelog.d/14215.bugfix @@ -0,0 +1 @@ +Fix [MSC3030](https://github.com/matrix-org/matrix-spec-proposals/pull/3030) `/timestamp_to_event` endpoint returning potentially inaccurate closest events with `outliers` present. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7bc7f2f33e..69fea452ad 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1971,12 +1971,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_backward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_backward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question has any of its prev_events listed as a # backward extremity, it's next to a gap. @@ -2026,12 +2031,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_forward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question is a forward extremity, we will just # consider any potential forward gap as not a gap since it's one of @@ -2112,13 +2122,33 @@ class EventsWorkerStore(SQLBaseStore): The closest event_id otherwise None if we can't find any event in the given direction. """ + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" - sql_template = """ + sql_template = f""" SELECT event_id FROM events LEFT JOIN rejections USING (event_id) WHERE - origin_server_ts %s ? - AND room_id = ? + room_id = ? + AND origin_server_ts {comparison_operator} ? + /** + * Make sure the event isn't an `outlier` because we have no way + * to later check whether it's next to a gap. `outliers` do not + * have entries in the `event_edges`, `event_forward_extremeties`, + * and `event_backward_extremities` tables to check against + * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`). + */ + AND NOT outlier /* Make sure event is not rejected */ AND rejections.event_id IS NULL /** @@ -2128,27 +2158,14 @@ class EventsWorkerStore(SQLBaseStore): * Finally, we can tie-break based on when it was received on the server * (`stream_ordering`). */ - ORDER BY origin_server_ts %s, depth %s, stream_ordering %s + ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order} LIMIT 1; """ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - txn.execute( - sql_template % (comparison_operator, order, order, order), - (timestamp, room_id), + sql_template, + (room_id, timestamp), ) row = txn.fetchone() if row: diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 71b1637be8..716366eb90 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -39,6 +39,8 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, register, room, sync @@ -51,6 +53,7 @@ from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import create_event PATH_PREFIX = b"/_matrix/client/api/v1" @@ -3486,3 +3489,65 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM") + + +class TimestampLookupTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc3030_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._storage_controllers = self.hs.get_storage_controllers() + + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + def _inject_outlier(self, room_id: str) -> EventBase: + event, _context = self.get_success( + create_event( + self.hs, + room_id=room_id, + type="m.test", + sender="@test_remote_user:remote", + ) + ) + + event.internal_metadata.outlier = True + self.get_success( + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) + ) + ) + return event + + def test_no_outliers(self) -> None: + """ + Test to make sure `/timestamp_to_event` does not return `outlier` events. + We're unable to determine whether an `outlier` is next to a gap so we + don't know whether it's actually the closest event. Instead, let's just + ignore `outliers` with this endpoint. + + This test is really seeing that we choose the non-`outlier` event behind the + `outlier`. Since the gap checking logic considers the latest message in the room + as *not* next to a gap, asking over federation does not come into play here. + """ + room_id = self.helper.create_room_as(self.room_owner, tok=self.room_owner_tok) + + outlier_event = self._inject_outlier(room_id) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", + access_token=self.room_owner_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Make sure the outlier event is not returned + self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id) -- cgit 1.5.1 From 0b7830e457359ce651b293c8748bf636973404a9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 19 Oct 2022 19:38:24 +0000 Subject: Bump flake8-bugbear from 21.3.2 to 22.9.23 (#14042) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Erik Johnston Co-authored-by: David Robertson --- .flake8 | 9 ++++++++- changelog.d/14042.misc | 1 + poetry.lock | 8 ++++---- synapse/storage/databases/main/roommember.py | 4 ++-- synapse/util/caches/deferred_cache.py | 4 ++-- synapse/util/caches/descriptors.py | 2 +- tests/federation/transport/test_client.py | 7 +++---- tests/util/caches/test_descriptors.py | 2 +- 8 files changed, 22 insertions(+), 15 deletions(-) create mode 100644 changelog.d/14042.misc (limited to 'tests') diff --git a/.flake8 b/.flake8 index acb118c86e..4c6a4d5843 100644 --- a/.flake8 +++ b/.flake8 @@ -8,4 +8,11 @@ # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 +# +# flake8-bugbear runs extra checks. Its error codes are described at +# https://github.com/PyCQA/flake8-bugbear#list-of-warnings +# B019: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks +# B023: Functions defined inside a loop must not use variables redefined in the loop +# B024: Abstract base class with no abstract method. + +ignore=W503,W504,E203,E731,E501,B019,B023,B024 diff --git a/changelog.d/14042.misc b/changelog.d/14042.misc new file mode 100644 index 0000000000..868d55e76a --- /dev/null +++ b/changelog.d/14042.misc @@ -0,0 +1 @@ +Bump flake8-bugbear from 21.3.2 to 22.9.23. diff --git a/poetry.lock b/poetry.lock index ed0b59fbe5..0a2f9ab69e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -260,7 +260,7 @@ pyflakes = ">=2.4.0,<2.5.0" [[package]] name = "flake8-bugbear" -version = "21.3.2" +version = "22.9.23" description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." category = "dev" optional = false @@ -271,7 +271,7 @@ attrs = ">=19.2.0" flake8 = ">=3.0.0" [package.extras] -dev = ["black", "coverage", "hypothesis", "hypothesmith"] +dev = ["coverage", "hypothesis", "hypothesmith (>=0.2)", "pre-commit"] [[package]] name = "flake8-comprehensions" @@ -1826,8 +1826,8 @@ flake8 = [ {file = "flake8-4.0.1.tar.gz", hash = "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d"}, ] flake8-bugbear = [ - {file = "flake8-bugbear-21.3.2.tar.gz", hash = "sha256:cadce434ceef96463b45a7c3000f23527c04ea4b531d16c7ac8886051f516ca0"}, - {file = "flake8_bugbear-21.3.2-py36.py37.py38-none-any.whl", hash = "sha256:5d6ccb0c0676c738a6e066b4d50589c408dcc1c5bf1d73b464b18b73cd6c05c2"}, + {file = "flake8-bugbear-22.9.23.tar.gz", hash = "sha256:17b9623325e6e0dcdcc80ed9e4aa811287fcc81d7e03313b8736ea5733759937"}, + {file = "flake8_bugbear-22.9.23-py3-none-any.whl", hash = "sha256:cd2779b2b7ada212d7a322814a1e5651f1868ab0d3f24cc9da66169ab8fda474"}, ] flake8-comprehensions = [ {file = "flake8-comprehensions-3.8.0.tar.gz", hash = "sha256:8e108707637b1d13734f38e03435984f6b7854fa6b5a4e34f93e69534be8e521"}, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2ed6ad754f..32e1e983a5 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -707,8 +707,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): # 250 users is pretty arbitrary but the data can be quite large if users # are in many rooms. - for user_ids in batch_iter(user_ids, 250): - all_user_rooms.update(await self._get_rooms_for_users(user_ids)) + for batch_user_ids in batch_iter(user_ids, 250): + all_user_rooms.update(await self._get_rooms_for_users(batch_user_ids)) return all_user_rooms diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 6425f851ea..bcb1cba362 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -395,8 +395,8 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache.pop should either return a CacheEntry, or, in the # case of a TreeCache, a dict of keys to cache entries. Either way calling # iterate_tree_cache_entry on it will do the right thing. - for entry in iterate_tree_cache_entry(entry): - for cb in entry.get_invalidation_callbacks(key): + for iter_entry in iterate_tree_cache_entry(entry): + for cb in iter_entry.get_invalidation_callbacks(key): cb() def invalidate_all(self) -> None: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 0391966462..b3c748ef44 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -432,7 +432,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): num_args = cached_method.num_args if num_args != self.num_args: - raise Exception( + raise TypeError( "Number of args (%s) does not match underlying cache_method_name=%s (%s)." % (self.num_args, self.cached_method_name, num_args) ) diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index 0926e0583d..dd4d1b56de 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -17,6 +17,7 @@ from unittest.mock import Mock from synapse.api.room_versions import RoomVersions from synapse.federation.transport.client import SendJoinParser +from synapse.util import ExceptionBundle from tests.unittest import TestCase @@ -121,10 +122,8 @@ class SendJoinParserTestCase(TestCase): # Send half of the data to the parser parser.write(serialisation[: len(serialisation) // 2]) - # Close the parser. There should be _some_ kind of exception, but it need not - # be that RuntimeError directly. E.g. we might want to raise a wrapper - # encompassing multiple errors from multiple coroutines. - with self.assertRaises(Exception): + # Close the parser. There should be _some_ kind of exception. + with self.assertRaises(ExceptionBundle): parser.finish() # In any case, we should have tried to close both coros. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 90861fe522..78fd7b6961 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -1037,5 +1037,5 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj = Cls() # Make sure this raises an error about the arg mismatch - with self.assertRaises(Exception): + with self.assertRaises(TypeError): obj.list_fn([("foo", "bar")]) -- cgit 1.5.1 From 755bfeee3a1ac7077045ab9e5a994b6ca89afba3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Oct 2022 11:32:47 -0400 Subject: Use servlets for /key/ endpoints. (#14229) To fix the response for unknown endpoints under that prefix. See MSC3743. --- changelog.d/14229.misc | 1 + synapse/api/urls.py | 2 +- synapse/app/generic_worker.py | 20 +++----- synapse/app/homeserver.py | 26 ++++------ synapse/rest/key/v2/__init__.py | 19 ++++--- synapse/rest/key/v2/local_key_resource.py | 22 ++++---- synapse/rest/key/v2/remote_key_resource.py | 73 +++++++++++++++------------ tests/app/test_openid_listener.py | 2 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- 9 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 changelog.d/14229.misc (limited to 'tests') diff --git a/changelog.d/14229.misc b/changelog.d/14229.misc new file mode 100644 index 0000000000..b9cd9a34d5 --- /dev/null +++ b/changelog.d/14229.misc @@ -0,0 +1 @@ +Refactor `/key/` endpoints to use `RestServlet` classes. diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bd49fa6a5f..a918579f50 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -28,7 +28,7 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" -SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" +SERVER_KEY_PREFIX = "/_matrix/key" MEDIA_R0_PREFIX = "/_matrix/media/r0" MEDIA_V3_PREFIX = "/_matrix/media/v3" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index dc49840f73..2a9f039367 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -28,7 +28,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, ) from synapse.app import _base from synapse.app._base import ( @@ -89,7 +89,7 @@ from synapse.rest.client.register import ( RegistrationTokenValidityRestServlet, ) from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -325,13 +325,13 @@ class GenericWorkerServer(HomeServer): presence.register_servlets(self, resource) - resources.update({CLIENT_API_PREFIX: resource}) + resources[CLIENT_API_PREFIX] = resource resources.update(build_synapse_client_resource_tree(self)) - resources.update({"/.well-known": well_known_resource(self)}) + resources["/.well-known"] = well_known_resource(self) elif name == "federation": - resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) + resources[FEDERATION_PREFIX] = TransportLayerServer(self) elif name == "media": if self.config.media.can_load_media_repo: media_repo = self.get_media_repository_resource() @@ -359,16 +359,12 @@ class GenericWorkerServer(HomeServer): # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "replication": resources[REPLICATION_PREFIX] = ReplicationRestResource(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 883f2fd2ec..de3f08876f 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -31,7 +31,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, STATIC_PREFIX, ) from synapse.app import _base @@ -60,7 +60,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -215,30 +215,22 @@ class SynapseHomeServer(HomeServer): consent_resource: Resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) - resources.update({"/_matrix/consent": consent_resource}) + resources["/_matrix/consent"] = consent_resource if name == "federation": federation_resource: Resource = TransportLayerServer(self) if compress: federation_resource = gz_wrap(federation_resource) - resources.update({FEDERATION_PREFIX: federation_resource}) + resources[FEDERATION_PREFIX] = federation_resource if name == "openid": - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["static", "client"]: - resources.update( - { - STATIC_PREFIX: StaticResource( - os.path.join(os.path.dirname(synapse.__file__), "static") - ) - } + resources[STATIC_PREFIX] = StaticResource( + os.path.join(os.path.dirname(synapse.__file__), "static") ) if name in ["media", "federation", "client"]: @@ -257,7 +249,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "metrics" and self.config.metrics.enable_metrics: metrics_resource: Resource = MetricsResource(RegistryProxy) diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index 7f8c1de1ff..26403facb8 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -14,17 +14,20 @@ from typing import TYPE_CHECKING -from twisted.web.resource import Resource - -from .local_key_resource import LocalKey -from .remote_key_resource import RemoteKey +from synapse.http.server import HttpServer, JsonResource +from synapse.rest.key.v2.local_key_resource import LocalKey +from synapse.rest.key.v2.remote_key_resource import RemoteKey if TYPE_CHECKING: from synapse.server import HomeServer -class KeyApiV2Resource(Resource): +class KeyResource(JsonResource): def __init__(self, hs: "HomeServer"): - Resource.__init__(self) - self.putChild(b"server", LocalKey(hs)) - self.putChild(b"query", RemoteKey(hs)) + super().__init__(hs, canonical_json=True) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None: + LocalKey(hs).register(http_server) + RemoteKey(hs).register(http_server) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 095993415c..d03e728d42 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -13,16 +13,15 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +import re +from typing import TYPE_CHECKING, Optional, Tuple -from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from twisted.web.resource import Resource +from twisted.web.server import Request -from synapse.http.server import respond_with_json_bytes -from synapse.http.site import SynapseRequest +from synapse.http.servlet import RestServlet from synapse.types import JsonDict if TYPE_CHECKING: @@ -31,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LocalKey(Resource): +class LocalKey(RestServlet): """HTTP resource containing encoding the TLS X.509 certificate and NACL signature verification keys for this server:: @@ -61,18 +60,17 @@ class LocalKey(Resource): } """ - isLeaf = True + PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P[^/]*))?$"),) def __init__(self, hs: "HomeServer"): self.config = hs.config self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) - Resource.__init__(self) def update_response_body(self, time_now_msec: int) -> None: refresh_interval = self.config.key.key_refresh_interval self.valid_until_ts = int(time_now_msec + refresh_interval) - self.response_body = encode_canonical_json(self.response_json_object()) + self.response_body = self.response_json_object() def response_json_object(self) -> JsonDict: verify_keys = {} @@ -99,9 +97,11 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: SynapseRequest) -> Optional[int]: + def on_GET( + self, request: Request, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) - return respond_with_json_bytes(request, 200, self.response_body) + return 200, self.response_body diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7f8ad29566..19820886f5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,15 +13,20 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Set +import re +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from signedjson.sign import sign_json -from synapse.api.errors import Codes, SynapseError +from twisted.web.server import Request + from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.servlet import parse_integer, parse_json_object_from_request -from synapse.http.site import SynapseRequest +from synapse.http.server import HttpServer +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, +) from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -32,7 +37,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RemoteKey(DirectServeJsonResource): +class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks @@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource): } """ - isLeaf = True - def __init__(self, hs: "HomeServer"): - super().__init__() - self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -101,36 +102,48 @@ class RemoteKey(DirectServeJsonResource): ) self.config = hs.config - async def _async_render_GET(self, request: SynapseRequest) -> None: - assert request.postpath is not None - if len(request.postpath) == 1: - (server,) = request.postpath - query: dict = {server.decode("ascii"): {}} - elif len(request.postpath) == 2: - server, key_id = request.postpath + def register(self, http_server: HttpServer) -> None: + http_server.register_paths( + "GET", + ( + re.compile( + "^/_matrix/key/v2/query/(?P[^/]*)(/(?P[^/]*))?$" + ), + ), + self.on_GET, + self.__class__.__name__, + ) + http_server.register_paths( + "POST", + (re.compile("^/_matrix/key/v2/query$"),), + self.on_POST, + self.__class__.__name__, + ) + + async def on_GET( + self, request: Request, server: str, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: + if server and key_id: minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} + query = {server: {key_id: arguments}} else: - raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) + query = {server: {}} - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) - async def _async_render_POST(self, request: SynapseRequest) -> None: + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) query = content["server_keys"] - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, - request: SynapseRequest, - query: JsonDict, - query_remote_on_cache_miss: bool = False, - ) -> None: + self, query: JsonDict, query_remote_on_cache_miss: bool = False + ) -> JsonDict: logger.info("Handling query for keys %r", query) store_queries = [] @@ -232,7 +245,7 @@ class RemoteKey(DirectServeJsonResource): for server_name, keys in cache_misses.items() ), ) - await self.query_keys(request, query, query_remote_on_cache_miss=False) + return await self.query_keys(query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json_raw in json_results: @@ -244,6 +257,4 @@ class RemoteKey(DirectServeJsonResource): signed_keys.append(key_json) - response = {"server_keys": signed_keys} - - respond_with_json(request, 200, response, canonical_json=True) + return {"server_keys": signed_keys} diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index c7dae58eb5..8d03da7f96 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): self.assertEqual(channel.code, 401) -@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock()) +@patch("synapse.app.homeserver.KeyResource", new=Mock()) class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index ac0ac06b7e..7f1fba1086 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict @@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): def create_test_resource(self) -> Resource: return create_resource_tree( - {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() + {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource() ) def expect_outgoing_key_request( -- cgit 1.5.1 From 1433b5d5b64c3a6624e6e4ff4fef22127c49df86 Mon Sep 17 00:00:00 2001 From: Tadeusz Sośnierz Date: Fri, 21 Oct 2022 14:52:44 +0200 Subject: Show erasure status when listing users in the Admin API (#14205) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Show erasure status when listing users in the Admin API * Use USING when joining erased_users * Add changelog entry * Revert "Use USING when joining erased_users" This reverts commit 30bd2bf106415caadcfdbdd1b234ef2b106cc394. * Make the erased check work on postgres * Add a testcase for showing erased user status * Appease the style linter * Explicitly convert `erased` to bool to make SQLite consistent with Postgres This also adds us an easy way in to fix the other accidentally integered columns. * Move erasure status test to UsersListTestCase * Include user erased status when fetching user info via the admin API * Document the erase status in user_admin_api * Appease the linter and mypy * Signpost comments in tests Co-authored-by: Tadeusz Sośnierz Co-authored-by: David Robertson --- changelog.d/14205.feature | 1 + docs/admin_api/user_admin_api.md | 4 ++++ synapse/handlers/admin.py | 1 + synapse/storage/databases/main/__init__.py | 13 +++++++++-- tests/rest/admin/test_user.py | 35 +++++++++++++++++++++++++++++- 5 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14205.feature (limited to 'tests') diff --git a/changelog.d/14205.feature b/changelog.d/14205.feature new file mode 100644 index 0000000000..6692063352 --- /dev/null +++ b/changelog.d/14205.feature @@ -0,0 +1 @@ +Show erasure status when listing users in the Admin API. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 3625c7b6c5..c95d6c9b05 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -37,6 +37,7 @@ It returns a JSON body like the following: "is_guest": 0, "admin": 0, "deactivated": 0, + "erased": false, "shadow_banned": 0, "creation_ts": 1560432506, "appservice_id": null, @@ -167,6 +168,7 @@ A response body like the following is returned: "admin": 0, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": null, @@ -177,6 +179,7 @@ A response body like the following is returned: "admin": 1, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": "", @@ -247,6 +250,7 @@ The following fields are returned in the JSON response body: - `user_type` - string - Type of the user. Normal users are type `None`. This allows user type specific behaviour. There are also types `support` and `bot`. - `deactivated` - bool - Status if that user has been marked as deactivated. + - `erased` - bool - Status if that user has been marked as erased. - `shadow_banned` - bool - Status if that user has been marked as shadow banned. - `displayname` - string - The user's display name if they have set one. - `avatar_url` - string - The user's avatar URL if they have set one. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f2989cc4a2..5bf8e86387 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -100,6 +100,7 @@ class AdminHandler: user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["threepids"] = threepids user_info_dict["external_ids"] = external_ids + user_info_dict["erased"] = await self.store.is_user_erased(user.to_string()) return user_info_dict diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index a62b4abd4e..cfaedf5e0c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -201,7 +201,7 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, - order_by: str = UserSortOrder.USER_ID.value, + order_by: str = UserSortOrder.NAME.value, direction: str = "f", approved: bool = True, ) -> Tuple[List[JsonDict], int]: @@ -261,6 +261,7 @@ class DataStore( sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + LEFT JOIN erased_users AS eu ON u.name = eu.user_id {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base @@ -269,7 +270,8 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts, approved + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved, + eu.user_id is not null as erased {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? @@ -277,6 +279,13 @@ class DataStore( args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) + + # some of those boolean values are returned as integers when we're on SQLite + columns_to_boolify = ["erased"] + for user in users: + for column in columns_to_boolify: + user[column] = bool(user[column]) + return users, count return await self.db_pool.runInteraction( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4c1ce33463..63410ffdf1 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -31,7 +31,7 @@ from synapse.api.room_versions import RoomVersions from synapse.rest.client import devices, login, logout, profile, register, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -924,6 +924,36 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def test_erasure_status(self) -> None: + # Create a new user. + user_id = self.register_user("eraseme", "eraseme") + + # They should appear in the list users API, marked as not erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], False) + + # Deactivate that user, requesting erasure. + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account( + user_id, erase_data=True, requester=create_requester(user_id) + ) + ) + + # Repeat the list users query. They should now be marked as erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], True) + def _order_test( self, expected_user_list: List[str], @@ -1195,6 +1225,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) + self.assertFalse(channel.json_body["erased"]) # Deactivate and erase user channel = self.make_request( @@ -1219,6 +1250,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self.assertIsNone(channel.json_body["avatar_url"]) self.assertIsNone(channel.json_body["displayname"]) + self.assertTrue(channel.json_body["erased"]) self._is_erased("@user:test", True) @@ -2757,6 +2789,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("avatar_url", content) self.assertIn("admin", content) self.assertIn("deactivated", content) + self.assertIn("erased", content) self.assertIn("shadow_banned", content) self.assertIn("creation_ts", content) self.assertIn("appservice_id", content) -- cgit 1.5.1 From 4dd7aa371b6bc746fa4b0a9af220b2013b17a45d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Oct 2022 09:11:19 -0400 Subject: Properly update the threads table when thread events are redacted. (#14248) When the last event in a thread is redacted we need to update the threads table: * Find the new latest event in the thread and store it into the table; or * Remove the thread from the table if it is no longer a thread (i.e. all events in the thread were redacted). --- changelog.d/14248.bugfix | 1 + synapse/storage/databases/main/events.py | 61 ++++++++++++++--- tests/rest/client/test_relations.py | 110 +++++++++++++++++++++---------- 3 files changed, 129 insertions(+), 43 deletions(-) create mode 100644 changelog.d/14248.bugfix (limited to 'tests') diff --git a/changelog.d/14248.bugfix b/changelog.d/14248.bugfix new file mode 100644 index 0000000000..203c52c16b --- /dev/null +++ b/changelog.d/14248.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where the information returned from the `/threads` API could be stale when threaded events are redacted. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6698cbf664..00880bb37d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2028,25 +2028,37 @@ class PersistEventsStore: redacted_event_id: The event that was redacted. """ - # Fetch the current relation of the event being redacted. - redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + # Fetch the relation of the event being redacted. + row = self.db_pool.simple_select_one_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id}, - retcol="relates_to_id", + retcols=("relates_to_id", "relation_type"), allow_none=True, ) + # Nothing to do if no relation is found. + if row is None: + return + + redacted_relates_to = row["relates_to_id"] + rel_type = row["relation_type"] + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + # Any relation information for the related event must be cleared. - if redacted_relates_to is not None: - self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) - ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + if rel_type == RelationTypes.ANNOTATION: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) ) + if rel_type == RelationTypes.THREAD: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_summary, (redacted_relates_to,) ) @@ -2057,9 +2069,38 @@ class PersistEventsStore: txn, self.store.get_threads, (room_id,) ) - self.db_pool.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + # Find the new latest event in the thread. + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) + + # If a latest event is found, update the threads table, this might + # be the same current latest event (if an earlier event in the thread + # was redacted). + latest_event_row = txn.fetchone() + if latest_event_row: + self.db_pool.simple_upsert_txn( + txn, + table="threads", + keyvalues={"room_id": room_id, "thread_id": redacted_relates_to}, + values={ + "latest_event_id": latest_event_row[0], + "topological_ordering": latest_event_row[1], + "stream_ordering": latest_event_row[2], + }, + ) + + # Otherwise, delete the thread: it no longer exists. + else: + self.db_pool.simple_delete_one_txn( + txn, table="threads", keyvalues={"thread_id": redacted_relates_to} + ) def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index ddf315b894..e3d801f7a8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1523,6 +1523,26 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) + def _get_threads(self) -> List[Tuple[str, str]]: + """Request the threads in the room and returns a list of thread ID and latest event ID.""" + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + threads = channel.json_body["chunk"] + return [ + ( + t["event_id"], + t["unsigned"]["m.relations"][RelationTypes.THREAD]["latest_event"][ + "event_id" + ], + ) + for t in threads + ] + def test_redact_relation_annotation(self) -> None: """ Test that annotations of an event are properly handled after the @@ -1567,58 +1587,82 @@ class RelationRedactionTestCase(BaseRelationsTestCase): The redacted event should not be included in bundled aggregations or the response to relations. """ - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 1", "msgtype": "m.text"}, - ) - unredacted_event_id = channel.json_body["event_id"] + # Create a thread with a few events in it. + thread_replies = [] + for i in range(3): + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": f"reply {i}", "msgtype": "m.text"}, + ) + thread_replies.append(channel.json_body["event_id"]) - # Note that the *last* event in the thread is redacted, as that gets - # included in the bundled aggregation. - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 2", "msgtype": "m.text"}, + ################################################## + # Check the test data is configured as expected. # + ################################################## + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + relations = self._get_bundled_aggregations() + self.assertDictContainsSubset( + {"count": 3, "current_user_participated": True}, + relations[RelationTypes.THREAD], + ) + # The latest event is the last sent event. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + thread_replies[-1], ) - to_redact_event_id = channel.json_body["event_id"] - # Both relations exist. - event_ids = self._get_related_events() + # There should be one thread, the latest event is the event that will be redacted. + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + ########################## + # Redact the last event. # + ########################## + self._redact(thread_replies.pop()) + + # The thread should still exist, but the latest event should be updated. + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 2, - "current_user_participated": True, - }, + {"count": 2, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event returned is the event that will be redacted. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - to_redact_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) - # Redact one of the reactions. - self._redact(to_redact_event_id) + ########################################### + # Redact the *first* event in the thread. # + ########################################### + self._redact(thread_replies.pop(0)) - # The unredacted relation should still exist. - event_ids = self._get_related_events() + # Nothing should have changed (except the thread count). + self.assertEquals(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 1, - "current_user_participated": True, - }, + {"count": 1, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event is now the unredacted event. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - unredacted_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + #################################### + # Redact the last remaining event. # + #################################### + self._redact(thread_replies.pop(0)) + self.assertEquals(thread_replies, []) + + # The event should no longer be considered a thread. + self.assertEquals(self._get_related_events(), []) + self.assertEquals(self._get_bundled_aggregations(), {}) + self.assertEqual(self._get_threads(), []) def test_redact_parent_edit(self) -> None: """Test that edits of an event are redacted when the original event -- cgit 1.5.1 From b7a7ff6ee39da4981dcfdce61bf8ac4735e3d047 Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 21 Oct 2022 10:46:22 -0700 Subject: Add initial power level event to batch of bulk persisted events when creating a new room. (#14228) --- changelog.d/14228.misc | 1 + synapse/handlers/federation.py | 4 +- synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 14 ++---- synapse/handlers/room.py | 39 ++++----------- synapse/push/bulk_push_rule_evaluator.py | 74 ++++++++++++++++++++++++----- tests/push/test_bulk_push_rule_evaluator.py | 2 +- tests/replication/_base.py | 2 +- 8 files changed, 82 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14228.misc (limited to 'tests') diff --git a/changelog.d/14228.misc b/changelog.d/14228.misc new file mode 100644 index 0000000000..14fe31a8bc --- /dev/null +++ b/changelog.d/14228.misc @@ -0,0 +1 @@ +Add initial power level event to batch of bulk persisted events when creating a new room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 275a37a575..4fbc79a6cb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1017,7 +1017,9 @@ class FederationHandler: context = EventContext.for_outlier(self._storage_controllers) - await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] + ) try: await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 06e41b5cc0..7da6316a82 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2171,8 +2171,8 @@ class FederationEventHandler: min_depth, ) else: - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] ) try: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15b828dd74..468900a07f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1433,17 +1433,9 @@ class EventCreationHandler: a room that has been un-partial stated. """ - for event, context in events_and_context: - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + events_and_context + ) try: # If we're a worker we need to hit out to the master. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..cc1e5c8f97 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1055,9 +1055,6 @@ class RoomCreationHandler: event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 - # the last event sent/persisted to the db - last_sent_event_id: Optional[str] = None - # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1102,26 +1099,6 @@ class RoomCreationHandler: return new_event, new_context - async def send( - event: EventBase, - context: synapse.events.snapshot.EventContext, - creator: Requester, - ) -> int: - nonlocal last_sent_event_id - - ev = await self.event_creation_handler.handle_new_client_event( - requester=creator, - events_and_context=[(event, context)], - ratelimit=False, - ignore_shadow_ban=True, - ) - - last_sent_event_id = ev.event_id - - # we know it was persisted, so must have a stream ordering - assert ev.internal_metadata.stream_ordering - return ev.internal_metadata.stream_ordering - try: config = self._presets_dict[preset_config] except KeyError: @@ -1135,10 +1112,14 @@ class RoomCreationHandler: ) logger.debug("Sending %s in new room", EventTypes.Member) - await send(creation_event, creation_context, creator) + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + events_and_context=[(creation_event, creation_context)], + ratelimit=False, + ignore_shadow_ban=True, + ) + last_sent_event_id = ev.event_id - # Room create event must exist at this point - assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( creator, creator.user, @@ -1157,6 +1138,7 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) @@ -1165,7 +1147,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - await send(power_event, power_context, creator) + events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1214,9 +1196,8 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - await send(pl_event, pl_context, creator) + events_to_send.append((pl_event, pl_context)) - events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a75386f6a0..d7795a9080 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -165,8 +165,21 @@ class BulkPushRuleEvaluator: return rules_by_user async def _get_power_levels_and_sender_level( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> Tuple[dict, Optional[int]]: + """ + Given an event and an event context, get the power level event relevant to the event + and the power level of the sender of the event. + Args: + event: event to check + context: context of event to check + event_id_to_event: a mapping of event_id to event for a set of events being + batch persisted. This is needed as the sought-after power level event may + be in this batch rather than the DB + """ # There are no power levels and sender levels possible to get from outlier if event.internal_metadata.is_outlier(): return {}, None @@ -177,15 +190,26 @@ class BulkPushRuleEvaluator: ) pl_event_id = prev_state_ids.get(POWER_KEY) + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case if pl_event_id: - # fastpath: if there's a power level event, that's all we need, and - # not having a power level event is an extreme edge case - auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} + # Get the power level event from the batch, or fall back to the database. + pl_event = event_id_to_event.get(pl_event_id) + if pl_event: + auth_events = {POWER_KEY: pl_event} + else: + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) + # Some needed auth events might be in the batch, combine them with those + # fetched from the database. + for auth_event_id in auth_events_ids: + auth_event = event_id_to_event.get(auth_event_id) + if auth_event: + auth_events_dict[auth_event_id] = auth_event auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -194,16 +218,38 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - @measure_func("action_for_event_by_user") - async def action_for_event_by_user( - self, event: EventBase, context: EventContext + async def action_for_events_by_user( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Given an event and context, evaluate the push rules, check if the message - should increment the unread count, and insert the results into the - event_push_actions_staging table. + """Given a list of events and their associated contexts, evaluate the push rules + for each event, check if the message should increment the unread count, and + insert the results into the event_push_actions_staging table. """ - if not event.internal_metadata.is_notifiable(): - # Push rules for events that aren't notifiable can't be processed by this + # For batched events the power level events may not have been persisted yet, + # so we pass in the batched events. Thus if the event cannot be found in the + # database we can check in the batch. + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + for event, context in events_and_context: + await self._action_for_event_by_user(event, context, event_id_to_event) + + @measure_func("action_for_event_by_user") + async def _action_for_event_by_user( + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], + ) -> None: + + if ( + not event.internal_metadata.is_notifiable() + or event.internal_metadata.is_historical() + ): + # Push rules for events that aren't notifiable can't be processed by this and + # we want to skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). return # Disable counting as unread unless the experimental configuration is @@ -223,7 +269,9 @@ class BulkPushRuleEvaluator: ( power_levels, sender_power_level, - ) = await self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level( + event, context, event_id_to_event + ) # Find the event's thread ID. relation = relation_from_event(event) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 675d7df2ac..594e7937a8 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -71,4 +71,4 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise - self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ce53f808db..121f3d8d65 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -371,7 +371,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", - max_request_body_size=4096, + max_request_body_size=8192, reactor=self.reactor, ) -- cgit 1.5.1 From c9dffd5b330553c5803784be5bc0e2479fab79b0 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 25 Oct 2022 11:39:25 +0100 Subject: Remove unused `@lru_cache` decorator (#13595) * Remove unused `@lru_cache` decorator Spotted this working on something else. Co-authored-by: David Robertson --- changelog.d/13595.misc | 1 + synapse/util/caches/descriptors.py | 104 ---------------------------------- tests/util/caches/test_descriptors.py | 40 ++----------- 3 files changed, 5 insertions(+), 140 deletions(-) create mode 100644 changelog.d/13595.misc (limited to 'tests') diff --git a/changelog.d/13595.misc b/changelog.d/13595.misc new file mode 100644 index 0000000000..71959a6ee7 --- /dev/null +++ b/changelog.d/13595.misc @@ -0,0 +1 @@ +Remove unused `@lru_cache` decorator. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b3c748ef44..75428d19ba 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import functools import inspect import logging @@ -146,109 +145,6 @@ class _CacheDescriptorBase: ) -class _LruCachedFunction(Generic[F]): - cache: LruCache[CacheKey, Any] - __call__: F - - -def lru_cache( - *, max_entries: int = 1000, cache_context: bool = False -) -> Callable[[F], _LruCachedFunction[F]]: - """A method decorator that applies a memoizing cache around the function. - - This is more-or-less a drop-in equivalent to functools.lru_cache, although note - that the signature is slightly different. - - The main differences with functools.lru_cache are: - (a) the size of the cache can be controlled via the cache_factor mechanism - (b) the wrapped function can request a "cache_context" which provides a - callback mechanism to indicate that the result is no longer valid - (c) prometheus metrics are exposed automatically. - - The function should take zero or more arguments, which are used as the key for the - cache. Single-argument functions use that argument as the cache key; otherwise the - arguments are built into a tuple. - - Cached functions can be "chained" (i.e. a cached function can call other cached - functions and get appropriately invalidated when they called caches are - invalidated) by adding a special "cache_context" argument to the function - and passing that as a kwarg to all caches called. For example: - - @lru_cache(cache_context=True) - def foo(self, key, cache_context): - r1 = self.bar1(key, on_invalidate=cache_context.invalidate) - r2 = self.bar2(key, on_invalidate=cache_context.invalidate) - return r1 + r2 - - The wrapped function also has a 'cache' property which offers direct access to the - underlying LruCache. - """ - - def func(orig: F) -> _LruCachedFunction[F]: - desc = LruCacheDescriptor( - orig, - max_entries=max_entries, - cache_context=cache_context, - ) - return cast(_LruCachedFunction[F], desc) - - return func - - -class LruCacheDescriptor(_CacheDescriptorBase): - """Helper for @lru_cache""" - - class _Sentinel(enum.Enum): - sentinel = object() - - def __init__( - self, - orig: Callable[..., Any], - max_entries: int = 1000, - cache_context: bool = False, - ): - super().__init__( - orig, num_args=None, uncached_args=None, cache_context=cache_context - ) - self.max_entries = max_entries - - def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: - cache: LruCache[CacheKey, Any] = LruCache( - cache_name=self.name, - max_size=self.max_entries, - ) - - get_cache_key = self.cache_key_builder - sentinel = LruCacheDescriptor._Sentinel.sentinel - - @functools.wraps(self.orig) - def _wrapped(*args: Any, **kwargs: Any) -> Any: - invalidate_callback = kwargs.pop("on_invalidate", None) - callbacks = (invalidate_callback,) if invalidate_callback else () - - cache_key = get_cache_key(args, kwargs) - - ret = cache.get(cache_key, default=sentinel, callbacks=callbacks) - if ret != sentinel: - return ret - - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - if self.add_cache_context: - kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) - - ret2 = self.orig(obj, *args, **kwargs) - cache.set(cache_key, ret2, callbacks=callbacks) - - return ret2 - - wrapped = cast(CachedFunction, _wrapped) - wrapped.cache = cache - obj.__dict__[self.name] = wrapped - - return wrapped - - class DeferredCacheDescriptor(_CacheDescriptorBase): """A method decorator that applies a memoizing cache around the function. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 78fd7b6961..43475a307f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -28,7 +28,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, cachedList, lru_cache +from synapse.util.caches.descriptors import cached, cachedList from tests import unittest from tests.test_utils import get_awaitable_result @@ -36,38 +36,6 @@ from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) -class LruCacheDecoratorTestCase(unittest.TestCase): - def test_base(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @lru_cache() - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) - - obj = Cls() - obj.mock.return_value = "fish" - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(1, 3) - obj.mock.reset_mock() - - # the two values should now be cached - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) @@ -478,10 +446,10 @@ class DescriptorTestCase(unittest.TestCase): @cached(cache_context=True) async def func2(self, key, cache_context): - return self.func3(key, on_invalidate=cache_context.invalidate) + return await self.func3(key, on_invalidate=cache_context.invalidate) - @lru_cache(cache_context=True) - def func3(self, key, cache_context): + @cached(cache_context=True) + async def func3(self, key, cache_context): self.invalidate = cache_context.invalidate return 42 -- cgit 1.5.1 From 2d0ba3f89aaf9545d81c4027500e543ec70b68a6 Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Tue, 25 Oct 2022 13:38:01 +0000 Subject: Implementation for MSC3664: Pushrules for relations (#11804) --- changelog.d/11804.feature | 1 + rust/src/push/base_rules.rs | 17 +++ rust/src/push/evaluator.rs | 99 ++++++++++++- rust/src/push/mod.rs | 61 ++++++-- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 3 + synapse/push/bulk_push_rule_evaluator.py | 49 ++++++- synapse/rest/client/capabilities.py | 5 + synapse/storage/databases/main/push_rule.py | 15 +- tests/push/test_push_rule_evaluator.py | 215 +++++++++++++++++++++++++++- 10 files changed, 454 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11804.feature (limited to 'tests') diff --git a/changelog.d/11804.feature b/changelog.d/11804.feature new file mode 100644 index 0000000000..6420393541 --- /dev/null +++ b/changelog.d/11804.feature @@ -0,0 +1 @@ +Implement [MSC3664](https://github.com/matrix-org/matrix-doc/pull/3664). Contributed by Nico. diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 63240cacfc..49802fa4eb 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -25,6 +25,7 @@ use crate::push::Action; use crate::push::Condition; use crate::push::EventMatchCondition; use crate::push::PushRule; +use crate::push::RelatedEventMatchCondition; use crate::push::SetTweak; use crate::push::TweakValue; @@ -114,6 +115,22 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/override/.im.nheko.msc3664.reply"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatch( + RelatedEventMatchCondition { + key: Some(Cow::Borrowed("sender")), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_id")), + rel_type: Cow::Borrowed("m.in_reply_to"), + include_fallbacks: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 0365dd01dc..cedd42c54d 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -23,6 +23,7 @@ use regex::Regex; use super::{ utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, + RelatedEventMatchCondition, }; lazy_static! { @@ -49,6 +50,13 @@ pub struct PushRuleEvaluator { /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, + + /// The related events, indexed by relation type. Flattened in the same manner as + /// `flattened_keys`. + related_events_flattened: BTreeMap>, + + /// If msc3664, push rules for related events, is enabled. + related_event_match_enabled: bool, } #[pymethods] @@ -60,6 +68,8 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, + related_events_flattened: BTreeMap>, + related_event_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -72,6 +82,8 @@ impl PushRuleEvaluator { room_member_count, notification_power_levels, sender_power_level, + related_events_flattened, + related_event_match_enabled, }) } @@ -156,6 +168,9 @@ impl PushRuleEvaluator { KnownCondition::EventMatch(event_match) => { self.match_event_match(event_match, user_id)? } + KnownCondition::RelatedEventMatch(event_match) => { + self.match_related_event_match(event_match, user_id)? + } KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -239,6 +254,79 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `related_event_match` condition. (MSC3664) + fn match_related_event_match( + &self, + event_match: &RelatedEventMatchCondition, + user_id: Option<&str>, + ) -> Result { + // First check if related event matching is enabled... + if !self.related_event_match_enabled { + return Ok(false); + } + + // get the related event, fail if there is none. + let event = if let Some(event) = self.related_events_flattened.get(&*event_match.rel_type) { + event + } else { + return Ok(false); + }; + + // If we are not matching fallbacks, don't match if our special key indicating this is a + // fallback relation is not present. + if !event_match.include_fallbacks.unwrap_or(false) + && event.contains_key("im.vector.is_falling_back") + { + return Ok(false); + } + + // if we have no key, accept the event as matching, if it existed without matching any + // fields. + let key = if let Some(key) = &event_match.key { + key + } else { + return Ok(true); + }; + + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + let haystack = if let Some(haystack) = event.get(&**key) { + haystack + } else { + return Ok(false); + }; + + // For the content.body we match against "words", but for everything + // else we match against the entire value. + let match_type = if key == "content.body" { + GlobMatchType::Word + } else { + GlobMatchType::Whole + }; + + let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; + compiled_pattern.is_match(haystack) + } + /// Match the member count against an 'is' condition /// The `is` condition can be things like '>2', '==3' or even just '4'. fn match_member_count(&self, is: &str) -> Result { @@ -267,8 +355,15 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = - PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 0dabfab8b8..d57800aa4a 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -267,6 +267,8 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), + #[serde(rename = "im.nheko.msc3664.related_event_match")] + RelatedEventMatch(RelatedEventMatchCondition), ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -299,6 +301,20 @@ pub struct EventMatchCondition { pub pattern_type: Option>, } +/// The body of a [`Condition::RelatedEventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RelatedEventMatchCondition { + #[serde(skip_serializing_if = "Option::is_none")] + pub key: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern_type: Option>, + pub rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_fallbacks: Option, +} + /// The collection of push rules for a user. #[derive(Debug, Clone, Default)] #[pyclass(frozen)] @@ -391,15 +407,21 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, + msc3664_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { + pub fn py_new( + push_rules: PushRules, + enabled_map: BTreeMap, + msc3664_enabled: bool, + ) -> Self { Self { push_rules, enabled_map, + msc3664_enabled, } } @@ -414,13 +436,25 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules.iter().map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules + .iter() + .filter(|rule| { + // Ignore disabled experimental push rules + if !self.msc3664_enabled + && rule.rule_id == "global/override/.im.nheko.msc3664.reply" + { + return false; + } + + true + }) + .map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } @@ -446,6 +480,17 @@ fn test_deserialize_condition() { let _: Condition = serde_json::from_str(json).unwrap(); } +#[test] +fn test_deserialize_unstable_msc3664_condition() { + let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::RelatedEventMatch(_)) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index f2a61df660..f3b6d6c933 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,7 +25,9 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... + def __init__( + self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool + ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -37,6 +39,8 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], + related_events_flattened: Mapping[str, Mapping[str, str]], + related_event_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4009add01d..d9bdd66d55 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -98,6 +98,9 @@ class ExperimentalConfig(Config): # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) + # MSC3664: Pushrules to match on related events + self.msc3664_enabled: bool = experimental.get("msc3664_enabled", False) + # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d7795a9080..75b7e126ca 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -45,7 +45,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - push_rules_invalidation_counter = Counter( "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" ) @@ -107,6 +106,8 @@ class BulkPushRuleEvaluator: self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() + self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled + self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", @@ -218,6 +219,48 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level + async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation + + Returns: + Mapping of relation type to flattened events. + """ + related_events: Dict[str, Dict[str, str]] = {} + if self._related_event_match_enabled: + related_event_id = event.content.get("m.relates_to", {}).get("event_id") + relation_type = event.content.get("m.relates_to", {}).get("rel_type") + if related_event_id is not None and relation_type is not None: + related_event = await self.store.get_event( + related_event_id, allow_none=True + ) + if related_event is not None: + related_events[relation_type] = _flatten_dict(related_event) + + reply_event_id = ( + event.content.get("m.relates_to", {}) + .get("m.in_reply_to", {}) + .get("event_id") + ) + + # convert replies to pseudo relations + if reply_event_id is not None: + related_event = await self.store.get_event( + reply_event_id, allow_none=True + ) + + if related_event is not None: + related_events["m.in_reply_to"] = _flatten_dict(related_event) + + # indicate that this is from a fallback relation. + if relation_type == "m.thread" and event.content.get( + "m.relates_to", {} + ).get("is_falling_back", False): + related_events["m.in_reply_to"][ + "im.vector.is_falling_back" + ] = "" + + return related_events + async def action_for_events_by_user( self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: @@ -286,6 +329,8 @@ class BulkPushRuleEvaluator: # the parent is part of a thread. thread_id = await self.store.get_thread_id(relation.parent_id) + related_events = await self._related_events(event) + # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. notification_levels = power_levels.get("notifications", {}) @@ -298,6 +343,8 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, + related_events, + self._related_event_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 4237071c61..e84dde31b1 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -77,6 +77,11 @@ class CapabilitiesRestServlet(RestServlet): "enabled": True, } + if self.config.experimental.msc3664_enabled: + response["capabilities"]["im.nheko.msc3664.related_event_match"] = { + "enabled": self.config.experimental.msc3664_enabled, + } + return HTTPStatus.OK, response diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 51416b2236..b6c15f29f8 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,6 +29,7 @@ from typing import ( ) from synapse.api.errors import StoreError +from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -62,7 +63,9 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], enabled_map: Dict[str, bool] + rawrules: List[JsonDict], + enabled_map: Dict[str, bool], + experimental_config: ExperimentalConfig, ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -80,7 +83,9 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules(push_rules, enabled_map) + filtered_rules = FilteredPushRules( + push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled + ) return filtered_rules @@ -160,7 +165,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map) + return _load_rules(rows, enabled_map, self.hs.config.experimental) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -219,7 +224,9 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental + ) return results diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index decf619466..fe7c145840 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -38,7 +38,9 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: + def _get_evaluator( + self, content: JsonDict, related_events=None + ) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -58,6 +60,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), + {} if related_events is None else related_events, + True, ) def test_display_name(self) -> None: @@ -292,6 +296,215 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) + def test_related_event_match(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.annotation", + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + "m.annotation": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.annotation", + "pattern": "@other_user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.replace", + }, + "@other_user:test", + "display_name", + ) + ) + + def test_related_event_match_with_fallback(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.thread", + "is_falling_back": True, + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + "im.vector.is_falling_back": "", + }, + "m.thread": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": True, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": False, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + + def test_related_event_match_no_related_event(self): + evaluator = self._get_evaluator( + {"msgtype": "m.text", "body": "Message without related event"} + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1 From 9192d74b0bf2f87b00d3e106a18baa9ce27acda1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Oct 2022 16:25:02 +0200 Subject: Refactor OIDC tests to better mimic an actual OIDC provider. (#13910) This implements a fake OIDC server, which intercepts calls to the HTTP client. Improves accuracy of tests by covering more internal methods. One particular example was the ID token validation, which previously mocked. This uncovered an incorrect dependency: Synapse actually requires at least authlib 0.15.1, not 0.14.0. --- changelog.d/13910.misc | 1 + pyproject.toml | 2 +- synapse/handlers/oidc.py | 15 +- tests/federation/test_federation_client.py | 36 +- tests/handlers/test_oidc.py | 580 +++++++++++++---------------- tests/rest/client/test_auth.py | 32 +- tests/rest/client/test_login.py | 40 +- tests/rest/client/utils.py | 136 +++---- tests/test_utils/__init__.py | 40 +- tests/test_utils/oidc.py | 325 ++++++++++++++++ 10 files changed, 747 insertions(+), 460 deletions(-) create mode 100644 changelog.d/13910.misc create mode 100644 tests/test_utils/oidc.py (limited to 'tests') diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc new file mode 100644 index 0000000000..e906952aab --- /dev/null +++ b/changelog.d/13910.misc @@ -0,0 +1 @@ +Refactor OIDC tests to better mimic an actual OIDC provider. diff --git a/pyproject.toml b/pyproject.toml index 6ebac41ed1..7e0feb75aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true } psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true } pysaml2 = { version = ">=4.5.0", optional = true } -authlib = { version = ">=0.14.0", optional = true } +authlib = { version = ">=0.15.1", optional = true } # systemd-python is necessary for logging to the systemd journal via # `systemd.journal.JournalHandler`, as is documented in # `contrib/systemd/log_config.yaml`. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d7a8226900..9759daf043 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -275,6 +275,7 @@ class OidcProvider: provider: OidcProviderConfig, ): self._store = hs.get_datastores().main + self._clock = hs.get_clock() self._macaroon_generaton = macaroon_generator @@ -673,6 +674,13 @@ class OidcProvider: Returns: The decoded claims in the ID token. """ + id_token = token.get("id_token") + logger.debug("Attempting to decode JWT id_token %r", id_token) + + # That has been theoritically been checked by the caller, so even though + # assertion are not enabled in production, it is mainly here to appease mypy + assert id_token is not None + metadata = await self.load_metadata() claims_params = { "nonce": nonce, @@ -688,9 +696,6 @@ class OidcProvider: claim_options = {"iss": {"values": [metadata["issuer"]]}} - id_token = token["id_token"] - logger.debug("Attempting to decode JWT id_token %r", id_token) - # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() @@ -715,7 +720,9 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) - claims.validate(leeway=120) # allows 2 min of clock skew + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew return claims diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index a538215931..51d3bb8fff 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from unittest import mock import twisted.web.client from twisted.internet import defer -from twisted.internet.protocol import Protocol -from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions @@ -26,10 +23,9 @@ from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import event_injection +from tests.test_utils import FakeResponse, event_injection from tests.unittest import FederatingHomeserverTestCase @@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "pdus": [ create_event_dict, member_event_dict, @@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, "pdus": [ @@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # We expect an outbound request to /backfill, so stub that out self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, # Mimic the other server returning our new `pulled_event` @@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase): # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the # other from "yet.another.server" self.assertEqual(backfill_num_attempts, 2) - - -def _mock_response(resp: JsonDict): - body = json.dumps(resp).encode("utf-8") - - def deliver_body(p: Protocol): - p.dataReceived(body) - p.connectionLost(Failure(twisted.web.client.ResponseDone())) - - response = mock.Mock( - code=200, - phrase=b"OK", - headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), - length=len(body), - deliverBody=deliver_body, - ) - mock.seal(response) - return response diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e6cd3af7b7..5955410524 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os -from typing import Any, Dict +from typing import Any, Dict, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -22,12 +21,15 @@ import pymacaroons from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException +from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util import Clock -from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon +from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.stringutils import random_string from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config try: @@ -46,12 +48,6 @@ BASE_URL = "https://synapse/" CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] -AUTHORIZATION_ENDPOINT = ISSUER + "authorize" -TOKEN_ENDPOINT = ISSUER + "token" -USERINFO_ENDPOINT = ISSUER + "userinfo" -WELL_KNOWN = ISSUER + ".well-known/openid-configuration" -JWKS_URI = ISSUER + ".well-known/jwks.json" - # config for common cases DEFAULT_CONFIG = { "enabled": True, @@ -66,9 +62,9 @@ DEFAULT_CONFIG = { EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, + "authorization_endpoint": ISSUER + "authorize", + "token_endpoint": ISSUER + "token", + "jwks_uri": ISSUER + "jwks", } @@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url: str) -> JsonDict: - # Mock get_json calls to handle jwks & oidc discovery endpoints - if url == WELL_KNOWN: - # Minimal discovery document, as defined in OpenID.Discovery - # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata - return { - "issuer": ISSUER, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, - "userinfo_endpoint": USERINFO_ENDPOINT, - "response_types_supported": ["code"], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"], - } - elif url == JWKS_URI: - return {"keys": []} - - return {} - - def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = b"Synapse Test" + self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER) - hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + hs = self.setup_test_homeserver() + self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) + self.hs_patcher.start() self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase): # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 + auth_handler = hs.get_auth_handler() + # Mock the complete SSO login method. + self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + return hs + def tearDown(self) -> None: + self.hs_patcher.stop() + return super().tearDown() + + def reset_mocks(self): + """Reset all the Mocks.""" + self.fake_server.reset_mocks() + self.render_error.reset_mock() + self.complete_sso_login.reset_mock() + def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - async def patched_get_json(uri): - res = await get_json(uri) - if uri == WELL_KNOWN: - res.update(values) - return res + metadata = self.fake_server.get_metadata() + metadata.update(values) + return patch.object(self.fake_server, "get_metadata", return_value=metadata) - return patch.object(self.http_client, "get_json", patched_get_json) + def start_authorization( + self, + userinfo: dict, + client_redirect_url: str = "http://client/redirect", + scope: str = "openid", + with_sid: bool = False, + ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]: + """Start an authorization request, and get the callback request back.""" + nonce = random_string(10) + state = random_string(10) + + code, grant = self.fake_server.start_authorization( + userinfo=userinfo, + scope=scope, + client_id=self.provider._client_auth.client_id, + redirect_uri=self.provider._callback_url, + nonce=nonce, + with_sid=with_sid, + ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + return _build_callback_request(code, state, session), grant def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.fake_server.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, ISSUER) - self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) - self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) - self.assertEqual(metadata.jwks_uri, JWKS_URI) - # FIXME: it seems like authlib does not have that defined in its metadata models - # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + self.assertEqual(metadata.issuer, self.fake_server.issuer) + self.assertEqual( + metadata.authorization_endpoint, + self.fake_server.authorization_endpoint, + ) + self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_server.userinfo_endpoint, + ) # subsequent calls should be cached - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() - @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - self.assertEqual(jwks, {"keys": []}) + self.fake_server.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_server.get_jwks()) # subsequent calls should be cached… - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_jwks_handler.assert_not_called() # …unless forced - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.fake_server.get_jwks_handler.assert_called_once() - # Throw if the JWKS uri is missing - original = self.provider.load_metadata - - async def patched_load_metadata(): - m = (await original()).copy() - m.update({"jwks_uri": None}) - return m - - with patch.object(self.provider, "load_metadata", patched_load_metadata): + with self.metadata_edit({"jwks_uri": None}): + # If we don't do this, the load_metadata call will throw because of the + # missing jwks_uri + self.provider._user_profile_method = "userinfo_endpoint" + self.get_success(self.provider.load_metadata(force=True)) self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + auth_endpoint = urlparse(self.fake_server.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase): with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } username = "bar" userinfo = { "sub": "foo", "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - code = "code" - state = "state" - nonce = "nonce" client_redirect_url = "http://client/redirect" - ip_address = "10.0.0.1" - session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request(code, state, session, ip_address=ip_address) - + request, _ = self.start_authorization( + userinfo, client_redirect_url=client_redirect_url + ) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, client_redirect_url, None, new_user=True, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_not_called() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors + request, _ = self.start_authorization(userinfo) with patch.object( self.provider, "_remote_id_from_userinfo", @@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - auth_handler.complete_sso_login.reset_mock() - self.provider._exchange_code.reset_mock() - self.provider._parse_id_token.reset_mock() - self.provider._fetch_userinfo.reset_mock() + self.reset_mocks() # With userinfo fetching self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + # Without the "openid" scope, the FakeProvider does not generate an id_token + request, _ = self.start_authorization(userinfo, scope="") self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_not_called() - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() + self.reset_mocks() + # With an ID token, userinfo fetching and sid in the ID token self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() + request, grant = self.start_authorization(userinfo, with_sid=True) + self.assertIsNotNone(grant.sid) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, - auth_provider_session_id=id_token["sid"], + auth_provider_session_id=grant.sid, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(userinfo=True): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") - # Handle code exchange failure - from synapse.handlers.oidc import OidcError - - self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] - raises=OidcError("invalid_request") - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("invalid_request") + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(token=True): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("server_error") @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self) -> None: @@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" - token = {"type": "bearer"} - token_json = json.dumps(token).encode("utf-8") - self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad Request", - body=b'{"error": "foo", "error_description": "bar"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b"Not JSON", - ) + self.fake_server.post_token_handler.return_value = FakeResponse( + code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b'{"error": "internal_server_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=500, payload={"error": "internal_server_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad request", - body=b"{}", - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, - phrase=b"OK", - body=b'{"error": "some_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase): """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" @@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) @@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) @@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase): """ Login while using a mapping provider that implements get_extra_attributes. """ - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } userinfo = { "sub": "foo", "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - - state = "state" - client_redirect_url = "http://client/redirect" - session = self._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ) - request = _build_callback_request("code", state, session) - + request, _ = self.start_authorization(userinfo) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@foo:test", - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, {"phone": "1234567"}, new_user=True, auth_provider_session_id=None, @@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - userinfo: dict = { "sub": "test_user", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Test if the mxid is already taken store = self.hs.get_datastores().main @@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Mapping provider does not support de-duplicating Matrix IDs", @@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Subsequent calls should map to the same mxid. - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( args[2].startswith( @@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, @@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" - self.get_success( - _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) - ) + userinfo = {"sub": "test2", "username": "föö"} + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@test_user1:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register all of the potential mxids for a particular OIDC username. self.get_success( @@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) @@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # userinfo lacking "test": "foobar" attribute should fail. userinfo = { "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": "foobar" attribute should succeed. userinfo = { @@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": "foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. userinfo = { "sub": "tester", "username": "tester", "test": ["foobar", "foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase): Test that auth fails if attributes exist but don't match, or are non-string values. """ - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": ["foo", "bar"] attribute should fail userinfo = { @@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": ["foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": False attribute should fail # this is largely just to ensure we don't crash here @@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": False, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": None attribute should fail # a value of None breaks the OIDC spec, but it's important to not crash here @@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 1 attribute should fail # this is largely just to ensure we don't crash here @@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 1, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 3.14 attribute should fail # this is largely just to ensure we don't crash here @@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 3.14, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() def _generate_oidc_session_token( self, @@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return self.handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", + idp_id=self.provider.idp_id, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -async def _make_callback_with_userinfo( - hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" -) -> None: - """Mock up an OIDC callback with the given userinfo dict - - We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, - and poke in the userinfo dict as if it were the response to an OIDC userinfo call. - - Args: - hs: the HomeServer impl to send the callback to. - userinfo: the OIDC userinfo dict - client_redirect_url: the URL to redirect to on success. - """ - - handler = hs.get_oidc_handler() - provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] - provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - - state = "state" - session = handler._macaroon_generator.generate_oidc_session_token( - state=state, - session_data=OidcSessionData( - idp_id="oidc", - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id="", - ), - ) - request = _build_callback_request("code", state, session) - - await handler.handle_oidc_callback(request) - - def _build_callback_request( code: str, state: str, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 090cef5216..ebf653d018 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase): * checking that the original operation succeeds """ + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase): # run the UIA-via-SSO flow session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - login_resp = self.helper.login_via_oidc("username") + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) channel = self.delete_device( @@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" + + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device @@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message @@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase): """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - login_resp = self.helper.login_via_oidc("username", expected_status=403) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, "username", expected_status=403 + ) self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) self.assertEqual( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e801ba8c8b..ff5baa9f0a 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -36,7 +36,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 -from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): TEST_CLIENT_REDIRECT_URL, ) - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + channel, _ = self.helper.complete_oidc_auth( + fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} + ) # that should serve a confirmation page self.assertEqual(channel.code, 200, channel.result) @@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request("oidc") + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" + fake_oidc_server = self.helper.fake_oidc_server() + # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, + {"sub": "tester", "displayname": "Jonny"}, + TEST_CLIENT_REDIRECT_URL, ) # that should redirect to the username picker diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index c249a42bb6..967d229223 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,7 +31,6 @@ from typing import ( Tuple, overload, ) -from unittest.mock import patch from urllib.parse import urlencode import attr @@ -46,8 +45,19 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_ISSUER = "https://issuer.test/" +TEST_OIDC_CONFIG = { + "enabled": True, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} @attr.s(auto_attribs=True) @@ -543,12 +553,28 @@ class RestHelper: return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: + """Create a ``FakeOidcServer``. + + This can be used in conjuction with ``login_via_oidc``:: + + fake_oidc_server = self.helper.fake_oidc_server() + login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user") + """ + + return FakeOidcServer( + clock=self.hs.get_clock(), + issuer=issuer, + ) + def login_via_oidc( self, + fake_server: FakeOidcServer, remote_user_id: str, + with_sid: bool = False, expected_status: int = 200, - ) -> JsonDict: - """Log in via OIDC + ) -> Tuple[JsonDict, FakeAuthorizationGrant]: + """Log in (as a new user) via OIDC Returns the result of the final token login. @@ -560,7 +586,10 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" - channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) + userinfo = {"sub": remote_user_id} + channel, grant = self.auth_via_oidc( + fake_server, userinfo, client_redirect_url, with_sid=with_sid + ) # expect a confirmation page assert channel.code == HTTPStatus.OK, channel.result @@ -585,14 +614,16 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body + return channel.json_body, grant def auth_via_oidc( self, + fake_server: FakeOidcServer, user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. This can be used for either login or user-interactive auth. @@ -616,6 +647,7 @@ class RestHelper: the login redirect endpoint ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -625,14 +657,15 @@ class RestHelper: cookies: Dict[str, str] = {} - # if we're doing a ui auth, hit the ui auth redirect endpoint - if ui_auth_session_id: - # can't set the client redirect url for UI Auth - assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) - else: - # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + with fake_server.patch_homeserver(hs=self.hs): + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -640,17 +673,21 @@ class RestHelper: # that synapse passes to the client. oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + assert oauth_uri_path == fake_server.authorization_endpoint, ( "unexpected SSO URI " + oauth_uri_path ) - return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + return self.complete_oidc_auth( + fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid + ) def complete_oidc_auth( self, + fake_serer: FakeOidcServer, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Mock out an OIDC authentication flow Assumes that an OIDC auth has been initiated by one of initiate_sso_login or @@ -661,50 +698,37 @@ class RestHelper: Requires the OIDC callback resource to be mounted at the normal place. Args: + fake_server: the fake OIDC server with which the auth should be done oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, from initiate_sso_login or initiate_sso_ui_auth). cookies: the cookies set by synapse's redirect endpoint, which will be sent back to the callback endpoint. user_info_dict: the remote userinfo that the OIDC provider should present. Typically this should be '{"sub": ""}'. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. """ _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) + + code, grant = fake_serer.start_authorization( + scope=params["scope"][0], + userinfo=user_info_dict, + client_id=params["client_id"][0], + redirect_uri=params["redirect_uri"][0], + nonce=params["nonce"][0], + with_sid=with_sid, + ) + state = params["state"][0] + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, - urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + urllib.parse.urlencode({"state": state, "code": code}), ) - # before we hit the callback uri, stub out some methods in the http client so - # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. - expected_requests = [ - # first we get a hit to the token endpoint, which we tell to return - # a dummy OIDC access token - (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), - # and then one to the user_info endpoint, which returns our remote user id. - (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), - ] - - async def mock_req( - method: str, - uri: str, - data: Optional[dict] = None, - headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): - (expected_uri, resp_obj) = expected_requests.pop(0) - assert uri == expected_uri - resp = FakeResponse( - code=HTTPStatus.OK, - phrase=b"OK", - body=json.dumps(resp_obj).encode("utf-8"), - ) - return resp - - with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( self.hs.get_reactor(), @@ -715,7 +739,7 @@ class RestHelper: ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) - return channel + return channel, grant def initiate_sso_login( self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] @@ -806,21 +830,3 @@ class RestHelper: assert len(p.links) == 1, "not exactly one link in confirmation page" oauth_uri = p.links[0] return oauth_uri - - -# an 'oidc_config' suitable for login_via_oidc. -TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 0d0d6faf0d..e62ebcc6a5 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -15,17 +15,24 @@ """ Utilities for running the unit tests """ +import json import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, Tuple, TypeVar from unittest.mock import Mock import attr +import zope.interface from twisted.python.failure import Failure from twisted.web.client import ResponseDone +from twisted.web.http import RESPONSES +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.types import JsonDict TV = TypeVar("TV") @@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock: return Mock(side_effect=cb) -@attr.s -class FakeResponse: +# Type ignore: it does not fully implement IResponse, but is good enough for tests +@zope.interface.implementer(IResponse) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeResponse: # type: ignore[misc] """A fake twisted.web.IResponse object there is a similar class at treq.test.test_response, but it lacks a `phrase` attribute, and didn't support deliverBody until recently. """ - # HTTP response code - code = attr.ib(type=int) + version: Tuple[bytes, int, int] = (b"HTTP", 1, 1) - # HTTP response phrase (eg b'OK' for a 200) - phrase = attr.ib(type=bytes) + # HTTP response code + code: int = 200 # body of the response - body = attr.ib(type=bytes) + body: bytes = b"" + + headers: Headers = attr.Factory(Headers) + + @property + def phrase(self): + return RESPONSES.get(self.code, b"Unknown Status") + + @property + def length(self): + return len(self.body) def deliverBody(self, protocol): protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) + @classmethod + def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + headers = Headers({"Content-Type": ["application/json"]}) + body = json.dumps(payload).encode("utf-8") + return cls(code=code, body=body, headers=headers) + # A small image used in some tests. # diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py new file mode 100644 index 0000000000..de134bbc89 --- /dev/null +++ b/tests/test_utils/oidc.py @@ -0,0 +1,325 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import Mock, patch +from urllib.parse import parse_qs + +import attr + +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.server import HomeServer +from synapse.util import Clock +from synapse.util.stringutils import random_string + +from tests.test_utils import FakeResponse + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeAuthorizationGrant: + userinfo: dict + client_id: str + redirect_uri: str + scope: str + nonce: Optional[str] + sid: Optional[str] + + +class FakeOidcServer: + """A fake OpenID Connect Provider.""" + + # All methods here are mocks, so we can track when they are called, and override + # their values + request: Mock + get_jwks_handler: Mock + get_metadata_handler: Mock + get_userinfo_handler: Mock + post_token_handler: Mock + + def __init__(self, clock: Clock, issuer: str): + from authlib.jose import ECKey, KeySet + + self._clock = clock + self.issuer = issuer + + self.request = Mock(side_effect=self._request) + self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler) + self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler) + self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler) + self.post_token_handler = Mock(side_effect=self._post_token_handler) + + # A code -> grant mapping + self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + # An access token -> grant mapping + self._sessions: Dict[str, FakeAuthorizationGrant] = {} + + # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for + # signing JWTs. ECDSA keys are really quick to generate compared to RSA. + self._key = ECKey.generate_key(crv="P-256", is_private=True) + self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))]) + + self._id_token_overrides: Dict[str, Any] = {} + + def reset_mocks(self): + self.request.reset_mock() + self.get_jwks_handler.reset_mock() + self.get_metadata_handler.reset_mock() + self.get_userinfo_handler.reset_mock() + self.post_token_handler.reset_mock() + + def patch_homeserver(self, hs: HomeServer): + """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. + + This patch should be used whenever the HS is expected to perform request to the + OIDC provider, e.g.:: + + fake_oidc_server = self.helper.fake_oidc_server() + with fake_oidc_server.patch_homeserver(hs): + self.make_request("GET", "/_matrix/client/r0/login/sso/redirect") + """ + return patch.object(hs.get_proxied_http_client(), "request", self.request) + + @property + def authorization_endpoint(self) -> str: + return self.issuer + "authorize" + + @property + def token_endpoint(self) -> str: + return self.issuer + "token" + + @property + def userinfo_endpoint(self) -> str: + return self.issuer + "userinfo" + + @property + def metadata_endpoint(self) -> str: + return self.issuer + ".well-known/openid-configuration" + + @property + def jwks_uri(self) -> str: + return self.issuer + "jwks" + + def get_metadata(self) -> dict: + return { + "issuer": self.issuer, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "jwks_uri": self.jwks_uri, + "userinfo_endpoint": self.userinfo_endpoint, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["ES256"], + } + + def get_jwks(self) -> dict: + return self._jwks.as_dict() + + def get_userinfo(self, access_token: str) -> Optional[dict]: + """Given an access token, get the userinfo of the associated session.""" + session = self._sessions.get(access_token, None) + if session is None: + return None + return session.userinfo + + def _sign(self, payload: dict) -> str: + from authlib.jose import JsonWebSignature + + jws = JsonWebSignature() + kid = self.get_jwks()["keys"][0]["kid"] + protected = {"alg": "ES256", "kid": kid} + json_payload = json.dumps(payload) + return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") + + def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: + now = self._clock.time() + id_token = { + **grant.userinfo, + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "nbf": now, + "exp": now + 600, + } + + if grant.nonce is not None: + id_token["nonce"] = grant.nonce + + if grant.sid is not None: + id_token["sid"] = grant.sid + + id_token.update(self._id_token_overrides) + + return self._sign(id_token) + + def id_token_override(self, overrides: dict): + """Temporarily patch the ID token generated by the token endpoint.""" + return patch.object(self, "_id_token_overrides", overrides) + + def start_authorization( + self, + client_id: str, + scope: str, + redirect_uri: str, + userinfo: dict, + nonce: Optional[str] = None, + with_sid: bool = False, + ) -> Tuple[str, FakeAuthorizationGrant]: + """Start an authorization request, and get back the code to use on the authorization endpoint.""" + code = random_string(10) + sid = None + if with_sid: + sid = random_string(10) + + grant = FakeAuthorizationGrant( + userinfo=userinfo, + scope=scope, + redirect_uri=redirect_uri, + nonce=nonce, + client_id=client_id, + sid=sid, + ) + self._authorization_grants[code] = grant + + return code, grant + + def exchange_code(self, code: str) -> Optional[Dict[str, Any]]: + grant = self._authorization_grants.pop(code, None) + if grant is None: + return None + + access_token = random_string(10) + self._sessions[access_token] = grant + + token = { + "token_type": "Bearer", + "access_token": access_token, + "expires_in": 3600, + "scope": grant.scope, + } + + if "openid" in grant.scope: + token["id_token"] = self.generate_id_token(grant) + + return dict(token) + + def buggy_endpoint( + self, + *, + jwks: bool = False, + metadata: bool = False, + token: bool = False, + userinfo: bool = False, + ): + """A context which makes a set of endpoints return a 500 error. + + Args: + jwks: If True, makes the JWKS endpoint return a 500 error. + metadata: If True, makes the OIDC Discovery endpoint return a 500 error. + token: If True, makes the token endpoint return a 500 error. + userinfo: If True, makes the userinfo endpoint return a 500 error. + """ + buggy = FakeResponse(code=500, body=b"Internal server error") + + patches = {} + if jwks: + patches["get_jwks_handler"] = Mock(return_value=buggy) + if metadata: + patches["get_metadata_handler"] = Mock(return_value=buggy) + if token: + patches["post_token_handler"] = Mock(return_value=buggy) + if userinfo: + patches["get_userinfo_handler"] = Mock(return_value=buggy) + + return patch.multiple(self, **patches) + + async def _request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """The override of the SimpleHttpClient#request() method""" + access_token: Optional[str] = None + + if headers is None: + headers = Headers() + + # Try to find the access token in the headers if any + auth_headers = headers.getRawHeaders(b"Authorization") + if auth_headers: + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + access_token = parts[1].decode("ascii") + + if method == "POST": + # If the method is POST, assume it has an url-encoded body + if data is None or headers.getRawHeaders(b"Content-Type") != [ + b"application/x-www-form-urlencoded" + ]: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + params = parse_qs(data.decode("utf-8")) + + if uri == self.token_endpoint: + # Even though this endpoint should be protected, this does not check + # for client authentication. We're not checking it for simplicity, + # and because client authentication is tested in other standalone tests. + return self.post_token_handler(params) + + elif method == "GET": + if uri == self.jwks_uri: + return self.get_jwks_handler() + elif uri == self.metadata_endpoint: + return self.get_metadata_handler() + elif uri == self.userinfo_endpoint: + return self.get_userinfo_handler(access_token=access_token) + + return FakeResponse(code=404, body=b"404 not found") + + # Request handlers + def _get_jwks_handler(self) -> IResponse: + """Handles requests to the JWKS URI.""" + return FakeResponse.json(payload=self.get_jwks()) + + def _get_metadata_handler(self) -> IResponse: + """Handles requests to the OIDC well-known document.""" + return FakeResponse.json(payload=self.get_metadata()) + + def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse: + """Handles requests to the userinfo endpoint.""" + if access_token is None: + return FakeResponse(code=401) + user_info = self.get_userinfo(access_token) + if user_info is None: + return FakeResponse(code=401) + + return FakeResponse.json(payload=user_info) + + def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse: + """Handles requests to the token endpoint.""" + code = params.get("code", []) + + if len(code) != 1: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + grant = self.exchange_code(code=code[0]) + if grant is None: + return FakeResponse.json(code=400, payload={"error": "invalid_grant"}) + + return FakeResponse.json(payload=grant) -- cgit 1.5.1 From d902181de98399d90c46c4e4e2cf631064757941 Mon Sep 17 00:00:00 2001 From: James Salter Date: Tue, 25 Oct 2022 19:05:22 +0100 Subject: Unified search query syntax using the full-text search capabilities of the underlying DB. (#11635) Support a unified search query syntax which leverages more of the full-text search of each database supported by Synapse. Supports, with the same syntax across Postgresql 11+ and Sqlite: - quoted "search terms" - `AND`, `OR`, `-` (negation) operators - Matching words based on their stem, e.g. searches for "dog" matches documents containing "dogs". This is achieved by - If on postgresql 11+, pass the user input to `websearch_to_tsquery` - If on sqlite, manually parse the query and transform it into the sqlite-specific query syntax. Note that postgresql 10, which is close to end-of-life, falls back to using `phraseto_tsquery`, which only supports a subset of the features. Multiple terms separated by a space are implicitly ANDed. Note that: 1. There is no escaping of full-text syntax that might be supported by the database; e.g. `NOT`, `NEAR`, `*` in sqlite. This runs the risk that people might discover this as accidental functionality and depend on something we don't guarantee. 2. English text is assumed for stemming. To support other languages, either the target language needs to be known at the time of indexing the message (via room metadata, or otherwise), or a separate index for each language supported could be created. Sqlite docs: https://www.sqlite.org/fts3.html#full_text_index_queries Postgres docs: https://www.postgresql.org/docs/11/textsearch-controls.html --- changelog.d/11635.feature | 1 + synapse/storage/databases/main/search.py | 197 +++++++++++++++---- synapse/storage/engines/postgres.py | 16 ++ .../delta/73/10_update_sqlite_fts4_tokenizer.py | 62 ++++++ tests/storage/test_room_search.py | 213 +++++++++++++++++++++ 5 files changed, 454 insertions(+), 35 deletions(-) create mode 100644 changelog.d/11635.feature create mode 100644 synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py (limited to 'tests') diff --git a/changelog.d/11635.feature b/changelog.d/11635.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/11635.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 1b79acf955..a89fc54c2c 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -11,10 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import re -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple +from collections import deque +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr @@ -27,7 +39,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -421,8 +433,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -444,20 +454,24 @@ class SearchStore(SearchBackgroundUpdateStore): count_clauses = clauses if isinstance(self.database_engine, PostgresEngine): + search_query = search_term + tsquery_func = self.database_engine.tsquery_func sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," + f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank," " room_id, event_id" " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" + f" WHERE vector @@ {tsquery_func}('english', ?)" ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" + f" WHERE vector @@ {tsquery_func}('english', ?)" ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_for_sqlite(search_term) + sql = ( "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" " FROM event_search" @@ -469,7 +483,7 @@ class SearchStore(SearchBackgroundUpdateStore): "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ?" ) - count_args = [search_term] + count_args + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -501,7 +515,9 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres( + search_query, events, tsquery_func + ) count_sql += " GROUP BY room_id" @@ -510,7 +526,6 @@ class SearchStore(SearchBackgroundUpdateStore): ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - return { "results": [ {"event": event_map[r["event_id"]], "rank": r["rank"]} @@ -542,9 +557,6 @@ class SearchStore(SearchBackgroundUpdateStore): Each match as a dictionary. """ clauses = [] - - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -582,20 +594,23 @@ class SearchStore(SearchBackgroundUpdateStore): args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): + search_query = search_term + tsquery_func = self.database_engine.tsquery_func sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," + f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank," " origin_server_ts, stream_ordering, room_id, event_id" " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " + f" WHERE vector @@ {tsquery_func}('english', ?) AND " ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " + f" WHERE vector @@ {tsquery_func}('english', ?) AND " ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # @@ -614,13 +629,14 @@ class SearchStore(SearchBackgroundUpdateStore): " CROSS JOIN events USING (event_id)" " WHERE " ) + search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ? AND " ) - count_args = [search_term] + count_args + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -660,7 +676,9 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres( + search_query, events, tsquery_func + ) count_sql += " GROUP BY room_id" @@ -686,7 +704,7 @@ class SearchStore(SearchBackgroundUpdateStore): } async def _find_highlights_in_postgres( - self, search_query: str, events: List[EventBase] + self, search_query: str, events: List[EventBase], tsquery_func: str ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -697,6 +715,7 @@ class SearchStore(SearchBackgroundUpdateStore): Args: search_query events: A list of events + tsquery_func: The tsquery_* function to use when making queries Returns: A set of strings. @@ -729,7 +748,7 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( + query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % ( _to_postgres_options( { "StartSel": start_sel, @@ -760,20 +779,128 @@ def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. +@dataclass +class Phrase: + phrase: List[str] + + +class SearchToken(enum.Enum): + Not = enum.auto() + Or = enum.auto() + And = enum.auto() + + +Token = Union[str, Phrase, SearchToken] +TokenList = List[Token] + + +def _is_stop_word(word: str) -> bool: + # TODO Pull these out of the dictionary: + # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop + return word in {"the", "a", "you", "me", "and", "but"} + + +def _tokenize_query(query: str) -> TokenList: + """ + Convert the user-supplied `query` into a TokenList, which can be translated into + some DB-specific syntax. + + The following constructs are supported: + + - phrase queries using "double quotes" + - case-insensitive `or` and `and` operators + - negation of a keyword via unary `-` + - unary hyphen to denote NOT e.g. 'include -exclude' + + The following differs from websearch_to_tsquery: + + - Stop words are not removed. + - Unclosed phrases are treated differently. + + """ + tokens: TokenList = [] + + # Find phrases. + in_phrase = False + parts = deque(query.split('"')) + for i, part in enumerate(parts): + # The contents inside double quotes is treated as a phrase, a trailing + # double quote is not implied. + in_phrase = bool(i % 2) and i != (len(parts) - 1) + + # Pull out the individual words, discarding any non-word characters. + words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) + + # Phrases have simplified handling of words. + if in_phrase: + # Skip stop words. + phrase = [word for word in words if not _is_stop_word(word)] + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the phrase. + tokens.append(Phrase(phrase)) + continue + + # Otherwise, not in a phrase. + while words: + word = words.popleft() + + if word.startswith("-"): + tokens.append(SearchToken.Not) + + # If there's more word, put it back to be processed again. + word = word[1:] + if word: + words.appendleft(word) + elif word.lower() == "or": + tokens.append(SearchToken.Or) + else: + # Skip stop words. + if _is_stop_word(word): + continue + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the search term. + tokens.append(word) + + return tokens + + +def _tokens_to_sqlite_match_query(tokens: TokenList) -> str: + """ + Convert the list of tokens to a string suitable for passing to sqlite's MATCH. + Assume sqlite was compiled with enhanced query syntax. + + Ref: https://www.sqlite.org/fts3.html#full_text_index_queries """ + match_query = [] + for token in tokens: + if isinstance(token, str): + match_query.append(token) + elif isinstance(token, Phrase): + match_query.append('"' + " ".join(token.phrase) + '"') + elif token == SearchToken.Not: + # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search + # term has already been added before this. + match_query.append(" NOT ") + elif token == SearchToken.Or: + match_query.append(" OR ") + elif token == SearchToken.And: + match_query.append(" AND ") + else: + raise ValueError(f"unknown token {token}") + + return "".join(match_query) - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - if isinstance(database_engine, PostgresEngine): - return " & ".join(result + ":*" for result in results) - elif isinstance(database_engine, Sqlite3Engine): - return " & ".join(result + "*" for result in results) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") +def _parse_query_for_sqlite(search_term: str) -> str: + """Takes a plain unicode string from the user and converts it into a form + that can be passed to sqllite's matchinfo(). + """ + return _tokens_to_sqlite_match_query(_tokenize_query(search_term)) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index d8c0f64d9a..9bf74bbf59 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -170,6 +170,22 @@ class PostgresEngine( """Do we support the `RETURNING` clause in insert/update/delete?""" return True + @property + def tsquery_func(self) -> str: + """ + Selects a tsquery_* func to use. + + Ref: https://www.postgresql.org/docs/current/textsearch-controls.html + + Returns: + The function name. + """ + # Postgres 11 added support for websearch_to_tsquery. + assert self._version is not None + if self._version >= 110000: + return "websearch_to_tsquery" + return "plainto_tsquery" + def is_deadlock(self, error: Exception) -> bool: if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py new file mode 100644 index 0000000000..3de0a709eb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py @@ -0,0 +1,62 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None: + """ + Upgrade the event_search table to use the porter tokenizer if it isn't already + + Applies only for sqlite. + """ + if not isinstance(database_engine, Sqlite3Engine): + return + + # Rebuild the table event_search table with tokenize=porter configured. + cur.execute("DROP TABLE event_search") + cur.execute( + """ + CREATE VIRTUAL TABLE event_search + USING fts4 (tokenize=porter, event_id, room_id, sender, key, value ) + """ + ) + + # Re-run the background job to re-populate the event_search table. + cur.execute("SELECT MIN(stream_ordering) FROM events") + row = cur.fetchone() + min_stream_id = row[0] + + # If there are not any events, nothing to do. + if min_stream_id is None: + return + + cur.execute("SELECT MAX(stream_ordering) FROM events") + row = cur.fetchone() + max_stream_id = row[0] + + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + } + progress_json = json.dumps(progress) + + sql = """ + INSERT into background_updates (ordering, update_name, progress_json) + VALUES (?, ?, ?) + """ + + cur.execute(sql, (7310, "event_search", progress_json)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index e747c6b50e..9ddc19900a 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple, Union +from unittest.case import SkipTest +from unittest.mock import PropertyMock, patch + +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.databases.main import DataStore +from synapse.storage.databases.main.search import Phrase, SearchToken, _tokenize_query from synapse.storage.engines import PostgresEngine +from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.util import Clock from tests.unittest import HomeserverTestCase, skip_unless from tests.utils import USE_POSTGRES_FOR_TESTS @@ -187,3 +198,205 @@ class EventSearchInsertionTest(HomeserverTestCase): ), ) self.assertCountEqual(values, ["hi", "2"]) + + +class MessageSearchTest(HomeserverTestCase): + """ + Check message search. + + A powerful way to check the behaviour is to run the following in Postgres >= 11: + + # SELECT websearch_to_tsquery('english', ); + + The result can be compared to the tokenized version for SQLite and Postgres < 11. + + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + PHRASE = "the quick brown fox jumps over the lazy dog" + + # Each entry is a search query, followed by either a boolean of whether it is + # in the phrase OR a tuple of booleans: whether it matches using websearch + # and using plain search. + COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + ("nope", False), + ("brown", True), + ("quick brown", True), + ("brown quick", True), + ("quick \t brown", True), + ("jump", True), + ("brown nope", False), + ('"brown quick"', (False, True)), + ('"jumps over"', True), + ('"quick fox"', (False, True)), + ("nope OR doublenope", False), + ("furphy OR fox", (True, False)), + ("fox -nope", (True, False)), + ("fox -brown", (False, True)), + ('"fox" quick', True), + ('"fox quick', True), + ('"quick brown', True), + ('" quick "', True), + ('" nope"', False), + ] + # TODO Test non-ASCII cases. + + # Case that fail on SQLite. + POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + # SQLite treats NOT as a binary operator. + ("- fox", (False, True)), + ("- nope", (True, False)), + ('"-fox quick', (False, True)), + # PostgreSQL skips stop words. + ('"the quick brown"', True), + ('"over lazy"', True), + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Register a user and create a room, create some messages + self.register_user("alice", "password") + self.access_token = self.login("alice", "password") + self.room_id = self.helper.create_room_as("alice", tok=self.access_token) + + # Send the phrase as a message and check it was created + response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token) + self.assertIn("event_id", response) + + def test_tokenize_query(self) -> None: + """Test the custom logic to tokenize a user's query.""" + cases = ( + ("brown", ["brown"]), + ("quick brown", ["quick", SearchToken.And, "brown"]), + ("quick \t brown", ["quick", SearchToken.And, "brown"]), + ('"brown quick"', [Phrase(["brown", "quick"])]), + ("furphy OR fox", ["furphy", SearchToken.Or, "fox"]), + ("fox -brown", ["fox", SearchToken.Not, "brown"]), + ("- fox", [SearchToken.Not, "fox"]), + ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]), + # No trailing double quoe. + ('"fox quick', ["fox", SearchToken.And, "quick"]), + ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]), + ('" quick "', [Phrase(["quick"])]), + ( + 'q"uick brow"n', + [ + "q", + SearchToken.And, + Phrase(["uick", "brow"]), + SearchToken.And, + "n", + ], + ), + ( + '-"quick brown"', + [SearchToken.Not, Phrase(["quick", "brown"])], + ), + ) + + for query, expected in cases: + tokenized = _tokenize_query(query) + self.assertEqual( + tokenized, expected, f"{tokenized} != {expected} for {query}" + ) + + def _check_test_cases( + self, + store: DataStore, + cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]], + index=0, + ) -> None: + # Run all the test cases versus search_msgs + for query, expect_to_contain in cases: + if isinstance(expect_to_contain, tuple): + expect_to_contain = expect_to_contain[index] + + result = self.get_success( + store.search_msgs([self.room_id], query, ["content.body"]) + ) + self.assertEquals( + result["count"], + 1 if expect_to_contain else 0, + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ) + self.assertEquals( + len(result["results"]), + 1 if expect_to_contain else 0, + "results array length should match count", + ) + + # Run them again versus search_rooms + for query, expect_to_contain in cases: + if isinstance(expect_to_contain, tuple): + expect_to_contain = expect_to_contain[index] + + result = self.get_success( + store.search_rooms([self.room_id], query, ["content.body"], 10) + ) + self.assertEquals( + result["count"], + 1 if expect_to_contain else 0, + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ) + self.assertEquals( + len(result["results"]), + 1 if expect_to_contain else 0, + "results array length should match count", + ) + + def test_postgres_web_search_for_phrase(self): + """ + Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. + This test is skipped unless the postgres instance supports websearch_to_tsquery. + """ + + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, PostgresEngine): + raise SkipTest("Test only applies when postgres is used as the database") + + if store.database_engine.tsquery_func != "websearch_to_tsquery": + raise SkipTest( + "Test only applies when postgres supporting websearch_to_tsquery is used as the database" + ) + + self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0) + + def test_postgres_non_web_search_for_phrase(self): + """ + Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't + supported by the current postgres version. + """ + + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, PostgresEngine): + raise SkipTest("Test only applies when postgres is used as the database") + + # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path. + with patch( + "synapse.storage.engines.postgres.PostgresEngine.tsquery_func", + new_callable=PropertyMock, + ) as supports_websearch_to_tsquery: + supports_websearch_to_tsquery.return_value = "plainto_tsquery" + self._check_test_cases( + store, self.COMMON_CASES + self.POSTGRES_CASES, index=1 + ) + + def test_sqlite_search(self): + """ + Test sqlite searching for phrases. + """ + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, Sqlite3Engine): + raise SkipTest("Test only applies when sqlite is used as the database") + + self._check_test_cases(store, self.COMMON_CASES, index=0) -- cgit 1.5.1 From 8756d5c87efc5637da55c9e21d2a4eb2369ba693 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 26 Oct 2022 12:45:41 +0200 Subject: Save login tokens in database (#13844) * Save login tokens in database Signed-off-by: Quentin Gliech * Add upgrade notes * Track login token reuse in a Prometheus metric Signed-off-by: Quentin Gliech --- changelog.d/13844.misc | 1 + docs/upgrade.md | 9 ++ synapse/handlers/auth.py | 64 +++++++-- synapse/module_api/__init__.py | 41 +----- synapse/rest/client/login.py | 3 +- synapse/rest/client/login_token_request.py | 5 +- synapse/storage/databases/main/registration.py | 156 ++++++++++++++++++++- .../schema/main/delta/73/10login_tokens.sql | 35 +++++ synapse/util/macaroons.py | 87 +----------- tests/handlers/test_auth.py | 135 ++++++++++-------- tests/util/test_macaroons.py | 28 ---- 11 files changed, 337 insertions(+), 227 deletions(-) create mode 100644 changelog.d/13844.misc create mode 100644 synapse/storage/schema/main/delta/73/10login_tokens.sql (limited to 'tests') diff --git a/changelog.d/13844.misc b/changelog.d/13844.misc new file mode 100644 index 0000000000..66f4414df7 --- /dev/null +++ b/changelog.d/13844.misc @@ -0,0 +1 @@ +Save login tokens in database and prevent login token reuse. diff --git a/docs/upgrade.md b/docs/upgrade.md index b81385b191..78c34d0c15 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.71.0 + +## Removal of the `generate_short_term_login_token` module API method + +As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_short_term_login_token-module-api-method), the deprecated `generate_short_term_login_token` module method has been removed. + +Modules relying on it can instead use the `create_login_token` method. + + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f5f0e0e7a7..8b9ef25d29 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -38,6 +38,7 @@ from typing import ( import attr import bcrypt import unpaddedbase64 +from prometheus_client import Counter from twisted.internet.defer import CancelledError from twisted.web.server import Request @@ -48,6 +49,7 @@ from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, LoginError, + NotFoundError, StoreError, SynapseError, UserDeactivatedError, @@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.registration import ( + LoginTokenExpired, + LoginTokenLookupResult, + LoginTokenReused, +) from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import delay_cancellation, maybe_awaitable -from synapse.util.macaroons import LoginTokenAttributes from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email @@ -80,6 +86,12 @@ logger = logging.getLogger(__name__) INVALID_USERNAME_OR_PASSWORD = "Invalid username or password" +invalid_login_token_counter = Counter( + "synapse_user_login_invalid_login_tokens", + "Counts the number of rejected m.login.token on /login", + ["reason"], +) + def convert_client_dict_legacy_fields_to_identifier( submission: JsonDict, @@ -883,6 +895,25 @@ class AuthHandler: return True + async def create_login_token_for_user_id( + self, + user_id: str, + duration_ms: int = (2 * 60 * 1000), + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, + ) -> str: + login_token = self.generate_login_token() + now = self._clock.time_msec() + expiry_ts = now + duration_ms + await self.store.add_login_token_to_user( + user_id=user_id, + token=login_token, + expiry_ts=expiry_ts, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + return login_token + async def create_refresh_token_for_user_id( self, user_id: str, @@ -1401,6 +1432,18 @@ class AuthHandler: return None return user_id + def generate_login_token(self) -> str: + """Generates an opaque string, for use as an short-term login token""" + + # we use the following format for access tokens: + # syl__ + + random_string = stringutils.random_string(20) + base = f"syl_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + def generate_access_token(self, for_user: UserID) -> str: """Generates an opaque string, for use as an access token""" @@ -1427,16 +1470,17 @@ class AuthHandler: crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) return f"{base}_{crc}" - async def validate_short_term_login_token( - self, login_token: str - ) -> LoginTokenAttributes: + async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult: try: - res = self.macaroon_gen.verify_short_term_login_token(login_token) - except Exception: - raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) + return await self.store.consume_login_token(login_token) + except LoginTokenExpired: + invalid_login_token_counter.labels("expired").inc() + except LoginTokenReused: + invalid_login_token_counter.labels("reused").inc() + except NotFoundError: + invalid_login_token_counter.labels("not found").inc() - await self.auth_blocking.check_auth_blocking(res.user_id) - return res + raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) async def delete_access_token(self, access_token: str) -> None: """Invalidate a single access token @@ -1711,7 +1755,7 @@ class AuthHandler: ) # Create a login token - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.create_login_token_for_user_id( registered_user_id, auth_provider_id=auth_provider_id, auth_provider_session_id=auth_provider_session_id, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6a6ae208d1..30e689d00d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -771,50 +771,11 @@ class ModuleApi: auth_provider_session_id: The session ID got during login from the SSO IdP, if any. """ - # The deprecated `generate_short_term_login_token` method defaulted to an empty - # string for the `auth_provider_id` because of how the underlying macaroon was - # generated. This will change to a proper NULL-able field when the tokens get - # moved to the database. - return self._hs.get_macaroon_generator().generate_short_term_login_token( + return await self._hs.get_auth_handler().create_login_token_for_user_id( user_id, - auth_provider_id or "", - auth_provider_session_id, duration_in_ms, - ) - - def generate_short_term_login_token( - self, - user_id: str, - duration_in_ms: int = (2 * 60 * 1000), - auth_provider_id: str = "", - auth_provider_session_id: Optional[str] = None, - ) -> str: - """Generate a login token suitable for m.login.token authentication - - Added in Synapse v1.9.0. - - This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will - be removed in Synapse 1.71.0. - - Args: - user_id: gives the ID of the user that the token is for - - duration_in_ms: the time that the token will be valid for - - auth_provider_id: the ID of the SSO IdP that the user used to authenticate - to get this token, if any. This is encoded in the token so that - /login can report stats on number of successful logins by IdP. - """ - logger.warn( - "A module configured on this server uses ModuleApi.generate_short_term_login_token(), " - "which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in " - "Synapse 1.71.0", - ) - return self._hs.get_macaroon_generator().generate_short_term_login_token( - user_id, auth_provider_id, auth_provider_session_id, - duration_in_ms, ) @defer.inlineCallbacks diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f554586ac3..7774f1967d 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -436,8 +436,7 @@ class LoginRestServlet(RestServlet): The body of the JSON response. """ token = login_submission["token"] - auth_handler = self.auth_handler - res = await auth_handler.validate_short_term_login_token(token) + res = await self.auth_handler.consume_login_token(token) return await self._complete_login( res.user_id, diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index 277b20fb63..43ea21d5e6 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -57,7 +57,6 @@ class LoginTokenRequestServlet(RestServlet): self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name - self.macaroon_gen = hs.get_macaroon_generator() self.auth_handler = hs.get_auth_handler() self.token_timeout = hs.config.experimental.msc3882_token_timeout self.ui_auth = hs.config.experimental.msc3882_ui_auth @@ -76,10 +75,10 @@ class LoginTokenRequestServlet(RestServlet): can_skip_ui_auth=False, # Don't allow skipping of UI auth ) - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.auth_handler.create_login_token_for_user_id( user_id=requester.user.to_string(), auth_provider_id="org.matrix.msc3882.login_token_request", - duration_in_ms=self.token_timeout, + duration_ms=self.token_timeout, ) return ( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 2996d6bb4d..0255295317 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + NotFoundError, + StoreError, + SynapseError, + ThreepidValidationError, +) from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception): because this external id is given to an other user.""" +class LoginTokenExpired(Exception): + """Exception if the login token sent expired""" + + +class LoginTokenReused(Exception): + """Exception if the login token sent was already used""" + + @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: """Result of looking up an access token. @@ -115,6 +129,20 @@ class RefreshTokenLookupResult: If None, the session can be refreshed indefinitely.""" +@attr.s(auto_attribs=True, frozen=True, slots=True) +class LoginTokenLookupResult: + """Result of looking up a login token.""" + + user_id: str + """The user this token belongs to.""" + + auth_provider_id: Optional[str] + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id: Optional[str] + """The session ID advertised by the SSO Identity Provider.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "replace_refresh_token", _replace_refresh_token_txn ) + async def add_login_token_to_user( + self, + user_id: str, + token: str, + expiry_ts: int, + auth_provider_id: Optional[str], + auth_provider_session_id: Optional[str], + ) -> None: + """Adds a short-term login token for the given user. + + Args: + user_id: The user ID. + token: The new login token to add. + expiry_ts (milliseconds since the epoch): Time after which the login token + cannot be used. + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token, if any + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider, if any. + """ + await self.db_pool.simple_insert( + "login_tokens", + { + "token": token, + "user_id": user_id, + "expiry_ts": expiry_ts, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="add_login_token_to_user", + ) + + def _consume_login_token( + self, + txn: LoggingTransaction, + token: str, + ts: int, + ) -> LoginTokenLookupResult: + values = self.db_pool.simple_select_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + retcols=( + "user_id", + "expiry_ts", + "used_ts", + "auth_provider_id", + "auth_provider_session_id", + ), + allow_none=True, + ) + + if values is None: + raise NotFoundError() + + self.db_pool.simple_update_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + updatevalues={"used_ts": ts}, + ) + user_id = values["user_id"] + expiry_ts = values["expiry_ts"] + used_ts = values["used_ts"] + auth_provider_id = values["auth_provider_id"] + auth_provider_session_id = values["auth_provider_session_id"] + + # Token was already used + if used_ts is not None: + raise LoginTokenReused() + + # Token expired + if ts > int(expiry_ts): + raise LoginTokenExpired() + + return LoginTokenLookupResult( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + async def consume_login_token(self, token: str) -> LoginTokenLookupResult: + """Lookup a login token and consume it. + + Args: + token: The login token. + + Returns: + The data stored with that token, including the `user_id`. Returns `None` if + the token does not exist or if it expired. + + Raises: + NotFound if the login token was not found in database + LoginTokenExpired if the login token expired + LoginTokenReused if the login token was already used + """ + return await self.db_pool.runInteraction( + "consume_login_token", + self._consume_login_token, + token, + self._clock.time_msec(), + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( @@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): and hs.config.experimental.msc3866.require_approval_for_new_accounts ) + # Create a background job for removing expired login tokens + if hs.config.worker.run_background_tasks: + self._clock.looping_call( + self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS + ) + async def add_access_token_to_user( self, user_id: str, @@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): approved, ) + @wrap_as_background_process("delete_expired_login_tokens") + async def _delete_expired_login_tokens(self) -> None: + """Remove login tokens with expiry dates that have passed.""" + + def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None: + sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?" + txn.execute(sql, (ts,)) + + # We keep the expired tokens for an extra 5 minutes so we can measure how many + # times a token is being used after its expiry + now = self._clock.time_msec() + await self.db_pool.runInteraction( + "delete_expired_login_tokens", + _delete_expired_login_tokens_txn, + now - (5 * 60 * 1000), + ) + def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/schema/main/delta/73/10login_tokens.sql b/synapse/storage/schema/main/delta/73/10login_tokens.sql new file mode 100644 index 0000000000..a39b7bcece --- /dev/null +++ b/synapse/storage/schema/main/delta/73/10login_tokens.sql @@ -0,0 +1,35 @@ +/* + * Copyright 2022 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Login tokens are short-lived tokens that are used for the m.login.token +-- login method, mainly during SSO logins +CREATE TABLE login_tokens ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + expiry_ts BIGINT NOT NULL, + used_ts BIGINT, + auth_provider_id TEXT, + auth_provider_session_id TEXT +); + +-- We're sometimes querying them by their session ID we got from their IDP +CREATE INDEX login_tokens_auth_provider_idx + ON login_tokens (auth_provider_id, auth_provider_session_id); + +-- We're deleting them by their expiration time +CREATE INDEX login_tokens_expiry_time_idx + ON login_tokens (expiry_ts); + diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index df77edcce2..5df03d3ddc 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -24,7 +24,7 @@ from typing_extensions import Literal from synapse.util import Clock, stringutils -MacaroonType = Literal["access", "delete_pusher", "session", "login"] +MacaroonType = Literal["access", "delete_pusher", "session"] def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: @@ -111,19 +111,6 @@ class OidcSessionData: """The session ID of the ongoing UI Auth ("" if this is a login)""" -@attr.s(slots=True, frozen=True, auto_attribs=True) -class LoginTokenAttributes: - """Data we store in a short-term login token""" - - user_id: str - - auth_provider_id: str - """The SSO Identity Provider that the user authenticated with, to get this token.""" - - auth_provider_session_id: Optional[str] - """The session ID advertised by the SSO Identity Provider.""" - - class MacaroonGenerator: def __init__(self, clock: Clock, location: str, secret_key: bytes): self._clock = clock @@ -165,35 +152,6 @@ class MacaroonGenerator: macaroon.add_first_party_caveat(f"pushkey = {pushkey}") return macaroon.serialize() - def generate_short_term_login_token( - self, - user_id: str, - auth_provider_id: str, - auth_provider_session_id: Optional[str] = None, - duration_in_ms: int = (2 * 60 * 1000), - ) -> str: - """Generate a short-term login token used during SSO logins - - Args: - user_id: The user for which the token is valid. - auth_provider_id: The SSO IdP the user used. - auth_provider_session_id: The session ID got during login from the SSO IdP. - - Returns: - A signed token valid for using as a ``m.login.token`` token. - """ - now = self._clock.time_msec() - expiry = now + duration_in_ms - macaroon = self._generate_base_macaroon("login") - macaroon.add_first_party_caveat(f"user_id = {user_id}") - macaroon.add_first_party_caveat(f"time < {expiry}") - macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}") - if auth_provider_session_id is not None: - macaroon.add_first_party_caveat( - f"auth_provider_session_id = {auth_provider_session_id}" - ) - return macaroon.serialize() - def generate_oidc_session_token( self, state: str, @@ -233,49 +191,6 @@ class MacaroonGenerator: return macaroon.serialize() - def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: - """Verify a short-term-login macaroon - - Checks that the given token is a valid, unexpired short-term-login token - minted by this server. - - Args: - token: The login token to verify. - - Returns: - A set of attributes carried by this token, including the - ``user_id`` and informations about the SSO IDP used during that - login. - - Raises: - MacaroonVerificationFailedException if the verification failed - """ - macaroon = pymacaroons.Macaroon.deserialize(token) - - v = self._base_verifier("login") - v.satisfy_general(lambda c: c.startswith("user_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) - satisfy_expiry(v, self._clock.time_msec) - v.verify(macaroon, self._secret_key) - - user_id = get_value_from_macaroon(macaroon, "user_id") - auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") - - auth_provider_session_id: Optional[str] = None - try: - auth_provider_session_id = get_value_from_macaroon( - macaroon, "auth_provider_session_id" - ) - except MacaroonVerificationFailedException: - pass - - return LoginTokenAttributes( - user_id=user_id, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, - ) - def verify_guest_token(self, token: str) -> str: """Verify a guest access token macaroon diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 7106799d44..036dbbc45b 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from unittest.mock import Mock import pymacaroons @@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import AuthError, ResourceLimitError from synapse.rest import admin +from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable class AuthTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, + login.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1 = self.register_user("a_user", "pass") + def token_login(self, token: str) -> Optional[str]: + body = { + "type": "m.login.token", + "token": token, + } + + channel = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + + if channel.code == 200: + return channel.json_body["user_id"] + + return None + def test_macaroon_caveats(self) -> None: token = self.macaroon_generator.generate_guest_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase): v.satisfy_general(verify_guest) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - def test_short_term_login_token_gives_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + def test_login_token_gives_user_id(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) + + res = self.get_success(self.auth_handler.consume_login_token(token)) self.assertEqual(self.user1, res.user_id) - self.assertEqual("", res.auth_provider_id) + self.assertEqual(None, res.auth_provider_id) - # when we advance the clock, the token should be rejected - self.reactor.advance(6) - self.get_failure( - self.auth_handler.validate_short_term_login_token(token), - AuthError, + def test_login_token_reuse_fails(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - def test_short_term_login_token_gives_auth_provider(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, auth_provider_id="my_idp" - ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) - self.assertEqual(self.user1, res.user_id) - self.assertEqual("my_idp", res.auth_provider_id) + self.get_success(self.auth_handler.consume_login_token(token)) - def test_short_term_login_token_cannot_replace_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.get_failure( + self.auth_handler.consume_login_token(token), + AuthError, ) - macaroon = pymacaroons.Macaroon.deserialize(token) - res = self.get_success( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()) + def test_login_token_expires(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - self.assertEqual(self.user1, res.user_id) - - # add another "user_id" caveat, which might allow us to override the - # user_id. - macaroon.add_first_party_caveat("user_id = b_user") + # when we advance the clock, the token should be rejected + self.reactor.advance(6) self.get_failure( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()), + self.auth_handler.consume_login_token(token), AuthError, ) + def test_login_token_gives_auth_provider(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + auth_provider_id="my_idp", + auth_provider_session_id="11-22-33-44", + duration_ms=(5 * 1000), + ) + ) + res = self.get_success(self.auth_handler.consume_login_token(token)) + self.assertEqual(self.user1, res.user_id) + self.assertEqual("my_idp", res.auth_provider_id) + self.assertEqual("11-22-33-44", res.auth_provider_session_id) + def test_mau_limits_disabled(self) -> None: self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception @@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) + def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( @@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) def test_mau_limits_parity(self) -> None: # Ensure we're not at the unix epoch. @@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase): ), ResourceLimitError, ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) # If in monthly active cohort self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( @@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1, device_id=None, valid_until_ms=None ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True @@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) - ) - - def _get_macaroon(self) -> pymacaroons.Macaroon: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) - return pymacaroons.Macaroon.deserialize(token) + self.assertIsNotNone(self.token_login(token)) diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py index 32125f7bb7..40754a4711 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py @@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase): ) self.assertEqual(user_id, "@user:tesths") - def test_short_term_login_token(self): - """Test the generation and verification of short-term login tokens""" - token = self.macaroon_generator.generate_short_term_login_token( - user_id="@user:tesths", - auth_provider_id="oidc", - auth_provider_session_id="sid", - duration_in_ms=2 * 60 * 1000, - ) - - info = self.macaroon_generator.verify_short_term_login_token(token) - self.assertEqual(info.user_id, "@user:tesths") - self.assertEqual(info.auth_provider_id, "oidc") - self.assertEqual(info.auth_provider_session_id, "sid") - - # Raises with another secret key - with self.assertRaises(MacaroonVerificationFailedException): - self.other_macaroon_generator.verify_short_term_login_token(token) - - # Wait a minute - self.reactor.pump([60]) - # Shouldn't raise - self.macaroon_generator.verify_short_term_login_token(token) - # Wait another minute - self.reactor.pump([60]) - # Should raise since it expired - with self.assertRaises(MacaroonVerificationFailedException): - self.macaroon_generator.verify_short_term_login_token(token) - def test_oidc_session_token(self): """Test the generation and verification of OIDC session cookies""" state = "arandomstate" -- cgit 1.5.1 From 04fd6221de026a74e8a3e896796d39dcf5ac6e3b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 26 Oct 2022 14:00:01 +0100 Subject: Fix incorrectly sending authentication tokens to application service as headers (#14301) --- changelog.d/14301.bugfix | 1 + synapse/appservice/api.py | 12 +++++++----- tests/appservice/test_api.py | 8 +++++--- 3 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14301.bugfix (limited to 'tests') diff --git a/changelog.d/14301.bugfix b/changelog.d/14301.bugfix new file mode 100644 index 0000000000..668c1f3b9c --- /dev/null +++ b/changelog.d/14301.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where access tokens would be incorrectly sent to application services as headers. Application services which were obtaining access tokens from query parameters were not affected. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index fbac4375b0..60774b240d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -123,7 +123,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if response is not None: # just an empty json object return True @@ -147,7 +147,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if response is not None: # just an empty json object return True @@ -190,7 +190,9 @@ class ApplicationServiceApi(SimpleHttpClient): b"access_token": service.hs_token, } response = await self.get_json( - uri, args=args, headers={"Authorization": f"Bearer {service.hs_token}"} + uri, + args=args, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if not isinstance(response, list): logger.warning( @@ -230,7 +232,7 @@ class ApplicationServiceApi(SimpleHttpClient): info = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if not _is_valid_3pe_metadata(info): @@ -327,7 +329,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri=uri, json_body=body, args={"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if logger.isEnabledFor(logging.DEBUG): logger.debug( diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 11008ac1fb..89ee79396f 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Mapping +from typing import Any, List, Mapping, Sequence, Union from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -70,13 +70,15 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): self.request_url = None async def get_json( - url: str, args: Mapping[Any, Any], headers: Mapping[Any, Any] + url: str, + args: Mapping[Any, Any], + headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], ) -> List[JsonDict]: # Ensure the access token is passed as both a header and query arg. if not headers.get("Authorization") or not args.get(b"access_token"): raise RuntimeError("Access token not provided") - self.assertEqual(headers.get("Authorization"), f"Bearer {TOKEN}") + self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) self.assertEqual(args.get(b"access_token"), TOKEN) self.request_url = url if url == URL_USER: -- cgit 1.5.1 From 0cfbb3513152b8360155c2d75df50e06ea861fa4 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Wed, 26 Oct 2022 18:51:23 +0400 Subject: fix broken avatar checks when server_name contains a port (#13927) Fixes check_avatar_size_and_mime_type() to successfully update avatars on homeservers running on non-default ports which it would mistakenly treat as remote homeserver while validating the avatar's size and mime type. Signed-off-by: Ashish Kumar ashfame@users.noreply.github.com --- changelog.d/13927.bugfix | 1 + synapse/handlers/profile.py | 6 +++++- tests/handlers/test_profile.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13927.bugfix (limited to 'tests') diff --git a/changelog.d/13927.bugfix b/changelog.d/13927.bugfix new file mode 100644 index 0000000000..119cd128e7 --- /dev/null +++ b/changelog.d/13927.bugfix @@ -0,0 +1 @@ +Fix a bug which prevented setting an avatar on homeservers which have an explicit port in their `server_name` and have `max_avatar_size` and/or `allowed_avatar_mimetypes` configuration. Contributed by @ashfame. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index d8ff5289b5..4bf9a047a3 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -307,7 +307,11 @@ class ProfileHandler: if not self.max_avatar_size and not self.allowed_avatar_mimetypes: return True - server_name, _, media_id = parse_and_validate_mxc_uri(mxc) + host, port, media_id = parse_and_validate_mxc_uri(mxc) + if port is not None: + server_name = host + ":" + str(port) + else: + server_name = host if server_name == self.server_name: media_info = await self.store.get_local_media(media_id) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f88c725a42..675aa023ac 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,6 +14,8 @@ from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor import synapse.types @@ -327,6 +329,53 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertFalse(res) + @unittest.override_config( + {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]} + ) + def test_avatar_constraint_on_local_server_with_port(self): + """Test that avatar metadata is correctly fetched when the media is on a local + server and the server has an explicit port. + + (This was previously a bug) + """ + local_server_name = self.hs.config.server.server_name + media_id = "local" + local_mxc = f"mxc://{local_server_name}/{media_id}" + + # mock up the existence of the avatar file + self._setup_local_files({media_id: {"mimetype": "image/png"}}) + + # and now check that check_avatar_size_and_mime_type is happy + self.assertTrue( + self.get_success(self.handler.check_avatar_size_and_mime_type(local_mxc)) + ) + + @parameterized.expand([("remote",), ("remote:1234",)]) + @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) + def test_check_avatar_on_remote_server(self, remote_server_name: str) -> None: + """Test that avatar metadata is correctly fetched from a remote server""" + media_id = "remote" + remote_mxc = f"mxc://{remote_server_name}/{media_id}" + + # if the media is remote, check_avatar_size_and_mime_type just checks the + # media cache, so we don't need to instantiate a real remote server. It is + # sufficient to poke an entry into the db. + self.get_success( + self.hs.get_datastores().main.store_cached_remote_media( + media_id=media_id, + media_type="image/png", + media_length=50, + origin=remote_server_name, + time_now_ms=self.clock.time_msec(), + upload_name=None, + filesystem_id="xyz", + ) + ) + + self.assertTrue( + self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc)) + ) + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): """Stores metadata about files in the database. -- cgit 1.5.1 From 40fa8294e3096132819287dd0c6d6bd71a408902 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 26 Oct 2022 16:10:55 -0500 Subject: Refactor MSC3030 `/timestamp_to_event` to move away from our snowflake pull from `destination` pattern (#14096) 1. `federation_client.timestamp_to_event(...)` now handles all `destination` looping and uses our generic `_try_destination_list(...)` helper. 2. Consistently handling `NotRetryingDestination` and `FederationDeniedError` across `get_pdu` , backfill, and the generic `_try_destination_list` which is used for many places we use this pattern. 3. `get_pdu(...)` now returns `PulledPduInfo` so we know which `destination` we ended up pulling the PDU from --- changelog.d/14096.misc | 1 + synapse/federation/federation_client.py | 130 ++++++++++++++++++++++++----- synapse/handlers/federation.py | 15 ++-- synapse/handlers/federation_event.py | 31 ++++--- synapse/handlers/room.py | 126 +++++++++++----------------- synapse/util/retryutils.py | 2 +- tests/federation/test_federation_client.py | 12 ++- 7 files changed, 191 insertions(+), 126 deletions(-) create mode 100644 changelog.d/14096.misc (limited to 'tests') diff --git a/changelog.d/14096.misc b/changelog.d/14096.misc new file mode 100644 index 0000000000..2c07dc673b --- /dev/null +++ b/changelog.d/14096.misc @@ -0,0 +1 @@ +Refactor [MSC3030](https://github.com/matrix-org/matrix-spec-proposals/pull/3030) `/timestamp_to_event` endpoint to loop over federation destinations with standard pattern and error handling. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b220ab43fc..fa225182be 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -80,6 +80,18 @@ PDU_RETRY_TIME_MS = 1 * 60 * 1000 T = TypeVar("T") +@attr.s(frozen=True, slots=True, auto_attribs=True) +class PulledPduInfo: + """ + A result object that stores the PDU and info about it like which homeserver we + pulled it from (`pull_origin`) + """ + + pdu: EventBase + # Which homeserver we pulled the PDU from + pull_origin: str + + class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response we couldn't parse @@ -114,7 +126,9 @@ class FederationClient(FederationBase): self.hostname = hs.hostname self.signing_key = hs.signing_key - self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache( + # Cache mapping `event_id` to a tuple of the event itself and the `pull_origin` + # (which server we pulled the event from) + self._get_pdu_cache: ExpiringCache[str, Tuple[EventBase, str]] = ExpiringCache( cache_name="get_pdu_cache", clock=self._clock, max_len=1000, @@ -352,11 +366,11 @@ class FederationClient(FederationBase): @tag_args async def get_pdu( self, - destinations: Iterable[str], + destinations: Collection[str], event_id: str, room_version: RoomVersion, timeout: Optional[int] = None, - ) -> Optional[EventBase]: + ) -> Optional[PulledPduInfo]: """Requests the PDU with given origin and ID from the remote home servers. @@ -371,11 +385,11 @@ class FederationClient(FederationBase): moving to the next destination. None indicates no timeout. Returns: - The requested PDU, or None if we were unable to find it. + The requested PDU wrapped in `PulledPduInfo`, or None if we were unable to find it. """ logger.debug( - "get_pdu: event_id=%s from destinations=%s", event_id, destinations + "get_pdu(event_id=%s): from destinations=%s", event_id, destinations ) # TODO: Rate limit the number of times we try and get the same event. @@ -384,19 +398,25 @@ class FederationClient(FederationBase): # it gets persisted to the database), so we cache the results of the lookup. # Note that this is separate to the regular get_event cache which caches # events once they have been persisted. - event = self._get_pdu_cache.get(event_id) + get_pdu_cache_entry = self._get_pdu_cache.get(event_id) + event = None + pull_origin = None + if get_pdu_cache_entry: + event, pull_origin = get_pdu_cache_entry # If we don't see the event in the cache, go try to fetch it from the # provided remote federated destinations - if not event: + else: pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + # TODO: We can probably refactor this to use `_try_destination_list` for destination in destinations: now = self._clock.time_msec() last_attempt = pdu_attempts.get(destination, 0) if last_attempt + PDU_RETRY_TIME_MS > now: logger.debug( - "get_pdu: skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)", + "get_pdu(event_id=%s): skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)", + event_id, destination, last_attempt, PDU_RETRY_TIME_MS, @@ -411,43 +431,48 @@ class FederationClient(FederationBase): room_version=room_version, timeout=timeout, ) + pull_origin = destination pdu_attempts[destination] = now if event: # Prime the cache - self._get_pdu_cache[event.event_id] = event + self._get_pdu_cache[event.event_id] = (event, pull_origin) # Now that we have an event, we can break out of this # loop and stop asking other destinations. break + except NotRetryingDestination as e: + logger.info("get_pdu(event_id=%s): %s", event_id, e) + continue + except FederationDeniedError: + logger.info( + "get_pdu(event_id=%s): Not attempting to fetch PDU from %s because the homeserver is not on our federation whitelist", + event_id, + destination, + ) + continue except SynapseError as e: logger.info( - "Failed to get PDU %s from %s because %s", + "get_pdu(event_id=%s): Failed to get PDU from %s because %s", event_id, destination, e, ) continue - except NotRetryingDestination as e: - logger.info(str(e)) - continue - except FederationDeniedError as e: - logger.info(str(e)) - continue except Exception as e: pdu_attempts[destination] = now logger.info( - "Failed to get PDU %s from %s because %s", + "get_pdu(event_id=): Failed to get PDU from %s because %s", event_id, destination, e, ) continue - if not event: + if not event or not pull_origin: return None # `event` now refers to an object stored in `get_pdu_cache`. Our @@ -459,7 +484,7 @@ class FederationClient(FederationBase): event.room_version, ) - return event_copy + return PulledPduInfo(event_copy, pull_origin) @trace @tag_args @@ -699,12 +724,14 @@ class FederationClient(FederationBase): pdu_origin = get_domain_from_id(pdu.sender) if not res and pdu_origin != origin: try: - res = await self.get_pdu( + pulled_pdu_info = await self.get_pdu( destinations=[pdu_origin], event_id=pdu.event_id, room_version=room_version, timeout=10000, ) + if pulled_pdu_info is not None: + res = pulled_pdu_info.pdu except SynapseError: pass @@ -806,6 +833,7 @@ class FederationClient(FederationBase): ) for destination in destinations: + # We don't want to ask our own server for information we don't have if destination == self.server_name: continue @@ -814,9 +842,21 @@ class FederationClient(FederationBase): except ( RequestSendFailed, InvalidResponseError, - NotRetryingDestination, ) as e: logger.warning("Failed to %s via %s: %s", description, destination, e) + # Skip to the next homeserver in the list to try. + continue + except NotRetryingDestination as e: + logger.info("%s: %s", description, e) + continue + except FederationDeniedError: + logger.info( + "%s: Not attempting to %s from %s because the homeserver is not on our federation whitelist", + description, + description, + destination, + ) + continue except UnsupportedRoomVersionError: raise except HttpResponseException as e: @@ -1609,6 +1649,54 @@ class FederationClient(FederationBase): return result async def timestamp_to_event( + self, *, destinations: List[str], room_id: str, timestamp: int, direction: str + ) -> Optional["TimestampToEventResponse"]: + """ + Calls each remote federating server from `destinations` asking for their closest + event to the given timestamp in the given direction until we get a response. + Also validates the response to always return the expected keys or raises an + error. + + Args: + destinations: The domains of homeservers to try fetching from + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + A parsed TimestampToEventResponse including the closest event_id + and origin_server_ts or None if no destination has a response. + """ + + async def _timestamp_to_event_from_destination( + destination: str, + ) -> TimestampToEventResponse: + return await self._timestamp_to_event_from_destination( + destination, room_id, timestamp, direction + ) + + try: + # Loop through each homeserver candidate until we get a succesful response + timestamp_to_event_response = await self._try_destination_list( + "timestamp_to_event", + destinations, + # TODO: The requested timestamp may lie in a part of the + # event graph that the remote server *also* didn't have, + # in which case they will have returned another event + # which may be nowhere near the requested timestamp. In + # the future, we may need to reconcile that gap and ask + # other homeservers, and/or extend `/timestamp_to_event` + # to return events on *both* sides of the timestamp to + # help reconcile the gap faster. + _timestamp_to_event_from_destination, + ) + return timestamp_to_event_response + except SynapseError: + return None + + async def _timestamp_to_event_from_destination( self, destination: str, room_id: str, timestamp: int, direction: str ) -> "TimestampToEventResponse": """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4fbc79a6cb..5fc3b8bc8c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -442,6 +442,15 @@ class FederationHandler: # appropriate stuff. # TODO: We can probably do something more intelligent here. return True + except NotRetryingDestination as e: + logger.info("_maybe_backfill_inner: %s", e) + continue + except FederationDeniedError: + logger.info( + "_maybe_backfill_inner: Not attempting to backfill from %s because the homeserver is not on our federation whitelist", + dom, + ) + continue except (SynapseError, InvalidResponseError) as e: logger.info("Failed to backfill from %s because %s", dom, e) continue @@ -477,15 +486,9 @@ class FederationHandler: logger.info("Failed to backfill from %s because %s", dom, e) continue - except NotRetryingDestination as e: - logger.info(str(e)) - continue except RequestSendFailed as e: logger.info("Failed to get backfill from %s because %s", dom, e) continue - except FederationDeniedError as e: - logger.info(e) - continue except Exception as e: logger.exception("Failed to backfill from %s because %s", dom, e) continue diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7da6316a82..9ca5df7c78 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -58,7 +58,7 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.federation.federation_client import InvalidResponseError +from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import ( SynapseTags, @@ -1517,8 +1517,8 @@ class FederationEventHandler: ) async def backfill_event_id( - self, destination: str, room_id: str, event_id: str - ) -> EventBase: + self, destinations: List[str], room_id: str, event_id: str + ) -> PulledPduInfo: """Backfill a single event and persist it as a non-outlier which means we also pull in all of the state and auth events necessary for it. @@ -1530,24 +1530,21 @@ class FederationEventHandler: Raises: FederationError if we are unable to find the event from the destination """ - logger.info( - "backfill_event_id: event_id=%s from destination=%s", event_id, destination - ) + logger.info("backfill_event_id: event_id=%s", event_id) room_version = await self._store.get_room_version(room_id) - event_from_response = await self._federation_client.get_pdu( - [destination], + pulled_pdu_info = await self._federation_client.get_pdu( + destinations, event_id, room_version, ) - if not event_from_response: + if not pulled_pdu_info: raise FederationError( "ERROR", 404, - "Unable to find event_id=%s from destination=%s to backfill." - % (event_id, destination), + f"Unable to find event_id={event_id} from remote servers to backfill.", affected=event_id, ) @@ -1555,13 +1552,13 @@ class FederationEventHandler: # and auth events to de-outlier it. This also sets up the necessary # `state_groups` for the event. await self._process_pulled_events( - destination, - [event_from_response], + pulled_pdu_info.pull_origin, + [pulled_pdu_info.pdu], # Prevent notifications going to clients backfilled=True, ) - return event_from_response + return pulled_pdu_info @trace @tag_args @@ -1584,19 +1581,19 @@ class FederationEventHandler: async def get_event(event_id: str) -> None: with nested_logging_context(event_id): try: - event = await self._federation_client.get_pdu( + pulled_pdu_info = await self._federation_client.get_pdu( [destination], event_id, room_version, ) - if event is None: + if pulled_pdu_info is None: logger.warning( "Server %s didn't return event %s", destination, event_id, ) return - events.append(event) + events.append(pulled_pdu_info.pdu) except Exception as e: logger.warning( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index cc1e5c8f97..de97886ea9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -49,7 +49,6 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, - HttpResponseException, LimitExceededError, NotFoundError, StoreError, @@ -60,7 +59,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_and_fixup_power_levels_contents -from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin @@ -1472,7 +1470,12 @@ class TimestampLookupHandler: Raises: SynapseError if unable to find any event locally in the given direction """ - + logger.debug( + "get_event_for_timestamp(room_id=%s, timestamp=%s, direction=%s) Finding closest event...", + room_id, + timestamp, + direction, + ) local_event_id = await self.store.get_event_id_for_timestamp( room_id, timestamp, direction ) @@ -1524,85 +1527,54 @@ class TimestampLookupHandler: ) ) - # Loop through each homeserver candidate until we get a succesful response - for domain in likely_domains: - # We don't want to ask our own server for information we don't have - if domain == self.server_name: - continue + remote_response = await self.federation_client.timestamp_to_event( + destinations=likely_domains, + room_id=room_id, + timestamp=timestamp, + direction=direction, + ) + if remote_response is not None: + logger.debug( + "get_event_for_timestamp: remote_response=%s", + remote_response, + ) - try: - remote_response = await self.federation_client.timestamp_to_event( - domain, room_id, timestamp, direction - ) - logger.debug( - "get_event_for_timestamp: response from domain(%s)=%s", - domain, - remote_response, - ) + remote_event_id = remote_response.event_id + remote_origin_server_ts = remote_response.origin_server_ts - remote_event_id = remote_response.event_id - remote_origin_server_ts = remote_response.origin_server_ts - - # Backfill this event so we can get a pagination token for - # it with `/context` and paginate `/messages` from this - # point. - # - # TODO: The requested timestamp may lie in a part of the - # event graph that the remote server *also* didn't have, - # in which case they will have returned another event - # which may be nowhere near the requested timestamp. In - # the future, we may need to reconcile that gap and ask - # other homeservers, and/or extend `/timestamp_to_event` - # to return events on *both* sides of the timestamp to - # help reconcile the gap faster. - remote_event = ( - await self.federation_event_handler.backfill_event_id( - domain, room_id, remote_event_id - ) - ) + # Backfill this event so we can get a pagination token for + # it with `/context` and paginate `/messages` from this + # point. + pulled_pdu_info = await self.federation_event_handler.backfill_event_id( + likely_domains, room_id, remote_event_id + ) + remote_event = pulled_pdu_info.pdu - # XXX: When we see that the remote server is not trustworthy, - # maybe we should not ask them first in the future. - if remote_origin_server_ts != remote_event.origin_server_ts: - logger.info( - "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.", - domain, - remote_event_id, - remote_origin_server_ts, - remote_event.origin_server_ts, - ) - - # Only return the remote event if it's closer than the local event - if not local_event or ( - abs(remote_event.origin_server_ts - timestamp) - < abs(local_event.origin_server_ts - timestamp) - ): - logger.info( - "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)", - remote_event_id, - remote_event.origin_server_ts, - timestamp, - local_event.event_id if local_event else None, - local_event.origin_server_ts if local_event else None, - ) - return remote_event_id, remote_origin_server_ts - except (HttpResponseException, InvalidResponseError) as ex: - # Let's not put a high priority on some other homeserver - # failing to respond or giving a random response - logger.debug( - "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", - domain, - type(ex).__name__, - ex, - ex.args, + # XXX: When we see that the remote server is not trustworthy, + # maybe we should not ask them first in the future. + if remote_origin_server_ts != remote_event.origin_server_ts: + logger.info( + "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.", + pulled_pdu_info.pull_origin, + remote_event_id, + remote_origin_server_ts, + remote_event.origin_server_ts, ) - except Exception: - # But we do want to see some exceptions in our code - logger.warning( - "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception", - domain, - exc_info=True, + + # Only return the remote event if it's closer than the local event + if not local_event or ( + abs(remote_event.origin_server_ts - timestamp) + < abs(local_event.origin_server_ts - timestamp) + ): + logger.info( + "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)", + remote_event_id, + remote_event.origin_server_ts, + timestamp, + local_event.event_id if local_event else None, + local_event.origin_server_ts if local_event else None, ) + return remote_event_id, remote_origin_server_ts # To appease mypy, we have to add both of these conditions to check for # `None`. We only expect `local_event` to be `None` when diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index d0a69ff843..dcc037b982 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -51,7 +51,7 @@ class NotRetryingDestination(Exception): destination: the domain in question """ - msg = "Not retrying server %s." % (destination,) + msg = f"Not retrying server {destination} because we tried it recently retry_last_ts={retry_last_ts} and we won't check for another retry_interval={retry_interval}ms." super().__init__(msg) self.retry_last_ts = retry_last_ts diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 51d3bb8fff..e67f405826 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -142,14 +142,14 @@ class FederationClientTest(FederatingHomeserverTestCase): def test_get_pdu_returns_nothing_when_event_does_not_exist(self): """No event should be returned when the event does not exist""" - remote_pdu = self.get_success( + pulled_pdu_info = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], "event_should_not_exist", RoomVersions.V9, ) ) - self.assertEqual(remote_pdu, None) + self.assertEqual(pulled_pdu_info, None) def test_get_pdu(self): """Test to make sure an event is returned by `get_pdu()`""" @@ -169,13 +169,15 @@ class FederationClientTest(FederatingHomeserverTestCase): remote_pdu.internal_metadata.outlier = True # Get the event again. This time it should read it from cache. - remote_pdu2 = self.get_success( + pulled_pdu_info2 = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], remote_pdu.event_id, RoomVersions.V9, ) ) + self.assertIsNotNone(pulled_pdu_info2) + remote_pdu2 = pulled_pdu_info2.pdu # Sanity check that we are working against the same event self.assertEqual(remote_pdu.event_id, remote_pdu2.event_id) @@ -215,13 +217,15 @@ class FederationClientTest(FederatingHomeserverTestCase): ) ) - remote_pdu = self.get_success( + pulled_pdu_info = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], "event_id", RoomVersions.V9, ) ) + self.assertIsNotNone(pulled_pdu_info) + remote_pdu = pulled_pdu_info.pdu # check the right call got made to the agent self._mock_agent.request.assert_called_once_with( -- cgit 1.5.1 From 67583281e3f8ea923eedbc56a4c85c7ba75d1582 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Oct 2022 09:58:12 -0400 Subject: Fix tests for change in PostgreSQL 14 behavior change. (#14310) PostgreSQL 14 changed the behavior of `websearch_to_tsquery` to improve some behaviour. The tests were hitting those edge-cases about handling of hanging double quotes. This fixes the tests to take into account the PostgreSQL version. --- changelog.d/14310.feature | 1 + synapse/storage/databases/main/search.py | 5 ++--- tests/storage/test_room_search.py | 16 ++++++++++++---- 3 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14310.feature (limited to 'tests') diff --git a/changelog.d/14310.feature b/changelog.d/14310.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/14310.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index a89fc54c2c..594b935614 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -824,9 +824,8 @@ def _tokenize_query(query: str) -> TokenList: in_phrase = False parts = deque(query.split('"')) for i, part in enumerate(parts): - # The contents inside double quotes is treated as a phrase, a trailing - # double quote is not implied. - in_phrase = bool(i % 2) and i != (len(parts) - 1) + # The contents inside double quotes is treated as a phrase. + in_phrase = bool(i % 2) # Pull out the individual words, discarding any non-word characters. words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 9ddc19900a..868b5bee84 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -239,7 +239,6 @@ class MessageSearchTest(HomeserverTestCase): ("fox -nope", (True, False)), ("fox -brown", (False, True)), ('"fox" quick', True), - ('"fox quick', True), ('"quick brown', True), ('" quick "', True), ('" nope"', False), @@ -269,6 +268,15 @@ class MessageSearchTest(HomeserverTestCase): response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token) self.assertIn("event_id", response) + # The behaviour of a missing trailing double quote changed in PostgreSQL 14 + # from ignoring the initial double quote to treating it as a phrase. + main_store = homeserver.get_datastores().main + found = False + if isinstance(main_store.database_engine, PostgresEngine): + assert main_store.database_engine._version is not None + found = main_store.database_engine._version < 140000 + self.COMMON_CASES.append(('"fox quick', (found, True))) + def test_tokenize_query(self) -> None: """Test the custom logic to tokenize a user's query.""" cases = ( @@ -280,9 +288,9 @@ class MessageSearchTest(HomeserverTestCase): ("fox -brown", ["fox", SearchToken.Not, "brown"]), ("- fox", [SearchToken.Not, "fox"]), ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]), - # No trailing double quoe. - ('"fox quick', ["fox", SearchToken.And, "quick"]), - ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]), + # No trailing double quote. + ('"fox quick', [Phrase(["fox", "quick"])]), + ('"-fox quick', [Phrase(["-fox", "quick"])]), ('" quick "', [Phrase(["quick"])]), ( 'q"uick brow"n', -- cgit 1.5.1 From aa70556699e649f46f51a198fb104eecdc0d311b Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 27 Oct 2022 13:29:23 -0500 Subject: Check appservice user interest against the local users instead of all users (`get_users_in_room` mis-use) (#13958) --- changelog.d/13958.bugfix | 1 + docs/upgrade.md | 19 ++++ synapse/appservice/__init__.py | 16 ++- synapse/storage/databases/main/appservice.py | 17 ++- synapse/storage/databases/main/roommember.py | 3 + tests/appservice/test_appservice.py | 10 +- tests/handlers/test_appservice.py | 162 ++++++++++++++++++++++++++- 7 files changed, 214 insertions(+), 14 deletions(-) create mode 100644 changelog.d/13958.bugfix (limited to 'tests') diff --git a/changelog.d/13958.bugfix b/changelog.d/13958.bugfix new file mode 100644 index 0000000000..f9f651bfdc --- /dev/null +++ b/changelog.d/13958.bugfix @@ -0,0 +1 @@ +Check appservice user interest against the local users instead of all users in the room to align with [MSC3905](https://github.com/matrix-org/matrix-spec-proposals/pull/3905). diff --git a/docs/upgrade.md b/docs/upgrade.md index 78c34d0c15..f095bbc3a6 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -97,6 +97,25 @@ As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_s Modules relying on it can instead use the `create_login_token` method. +## Changes to the events received by application services (interest) + +To align with spec (changed in +[MSC3905](https://github.com/matrix-org/matrix-spec-proposals/pull/3905)), Synapse now +only considers local users to be interesting. In other words, the `users` namespace +regex is only be applied against local users of the homeserver. + +Please note, this probably doesn't affect the expected behavior of your application +service, since an interesting local user in a room still means all messages in the room +(from local or remote users) will still be considered interesting. And matching a room +with the `rooms` or `aliases` namespace regex will still consider all events sent in the +room to be interesting to the application service. + +If one of your application service's `users` regex was intending to match a remote user, +this will no longer match as you expect. The behavioral mismatch between matching all +local users and some remote users is why the spec was changed/clarified and this +caveat is no longer supported. + + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 0dfa00df44..500bdde3a9 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -172,12 +172,24 @@ class ApplicationService: Returns: True if this service would like to know about this room. """ - member_list = await store.get_users_in_room( + # We can use `get_local_users_in_room(...)` here because an application service + # can only be interested in local users of the server it's on (ignore any remote + # users that might match the user namespace regex). + # + # In the future, we can consider re-using + # `store.get_app_service_users_in_room` which is very similar to this + # function but has a slightly worse performance than this because we + # have an early escape-hatch if we find a single user that the + # appservice is interested in. The juice would be worth the squeeze if + # `store.get_app_service_users_in_room` was used in more places besides + # an experimental MSC. But for now we can avoid doing more work and + # barely using it later. + local_user_ids = await store.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) # check joined member events - for user_id in member_list: + for user_id in local_user_ids: if self.is_interested_in_user(user_id): return True return False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 64b70a7b28..63046c0527 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -157,10 +157,23 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): app_service: "ApplicationService", cache_context: _CacheContext, ) -> List[str]: - users_in_room = await self.get_users_in_room( + """ + Get all users in a room that the appservice controls. + + Args: + room_id: The room to check in. + app_service: The application service to check interest/control against + + Returns: + List of user IDs that the appservice controls. + """ + # We can use `get_local_users_in_room(...)` here because an application service + # can only be interested in local users of the server it's on (ignore any remote + # users that might match the user namespace regex). + local_users_in_room = await self.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) - return list(filter(app_service.is_interested_in_user, users_in_room)) + return list(filter(app_service.is_interested_in_user, local_users_in_room)) class ApplicationServiceStore(ApplicationServiceWorkerStore): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ab708b0ba5..e56a13f21e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -152,6 +152,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): the forward extremities of those rooms will exclude most members. We may also calculate room state incorrectly for such rooms and believe that a member is or is not in the room when the opposite is true. + + Note: If you only care about users in the room local to the homeserver, use + `get_local_users_in_room(...)` instead which will be more performant. """ return await self.db_pool.simple_select_onecol( table="current_state_events", diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 3018d3fc6f..d4dccfc2f0 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store = Mock() self.store.get_aliases_for_room = simple_async_mock([]) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): @@ -129,7 +129,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room = simple_async_mock( ["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -184,7 +184,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room = simple_async_mock( ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertFalse( ( yield defer.ensureDeferred( @@ -203,7 +203,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -236,7 +236,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. - self.store.get_users_in_room = simple_async_mock( + self.store.get_local_users_in_room = simple_async_mock( ["@alice:here", "@irc_fo:here", "@bob:here"] ) self.store.get_aliases_for_room = simple_async_mock([]) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 7e4570f990..144e49d0fd 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -22,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.api.constants import EduTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.appservice import ( ApplicationService, TransactionOneTimeKeyCounts, @@ -36,7 +36,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import make_awaitable, simple_async_mock +from tests.test_utils import event_injection, make_awaitable, simple_async_mock from tests.unittest import override_config from tests.utils import MockClock @@ -390,15 +390,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events self.send_mock = simple_async_mock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] return_value=self._services ) @@ -416,6 +417,157 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) + def _notify_interested_services(self): + # This is normally set in `notify_interested_services` but we need to call the + # internal async version so the reactor gets pushed to completion. + self.hs.get_application_service_handler().current_max += 1 + self.get_success( + self.hs.get_application_service_handler()._notify_interested_services( + RoomStreamToken( + None, self.hs.get_application_service_handler().current_max + ) + ) + ) + + @parameterized.expand( + [ + ("@local_as_user:test", True), + # Defining remote users in an application service user namespace regex is a + # footgun since the appservice might assume that it'll receive all events + # sent by that remote user, but it will only receive events in rooms that + # are shared with a local user. So we just remove this footgun possibility + # entirely and we won't notify the application service based on remote + # users. + ("@remote_as_user:remote", False), + ] + ) + def test_match_interesting_room_members( + self, interesting_user: str, should_notify: bool + ): + """ + Test to make sure that a interesting user (local or remote) in the room is + notified as expected when someone else in the room sends a message. + """ + # Register an application service that's interested in the `interesting_user` + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": interesting_user, + "exclusive": False, + }, + ], + }, + ) + + # Create a room + alice = self.register_user("alice", "pass") + alice_access_token = self.login("alice", "pass") + room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token) + + # Join the interesting user to the room + self.get_success( + event_injection.inject_member_event( + self.hs, room_id, interesting_user, "join" + ) + ) + # Kick the appservice into checking this membership event to get the event out + # of the way + self._notify_interested_services() + # We don't care about the interesting user join event (this test is making sure + # the next thing works) + self.send_mock.reset_mock() + + # Send a message from an uninteresting user + self.helper.send_event( + room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message from uninteresting user", + }, + tok=alice_access_token, + ) + # Kick the appservice into checking this new event + self._notify_interested_services() + + if should_notify: + self.send_mock.assert_called_once() + ( + service, + events, + _ephemeral, + _to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Even though the message came from an uninteresting user, it should still + # notify us because the interesting user is joined to the room where the + # message was sent. + self.assertEqual(service, interested_appservice) + self.assertEqual(events[0]["type"], "m.room.message") + self.assertEqual(events[0]["sender"], alice) + else: + self.send_mock.assert_not_called() + + def test_application_services_receive_events_sent_by_interesting_local_user(self): + """ + Test to make sure that a messages sent from a local user can be interesting and + picked up by the appservice. + """ + # Register an application service that's interested in all local users + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": ".*", + "exclusive": False, + }, + ], + }, + ) + + # Create a room + alice = self.register_user("alice", "pass") + alice_access_token = self.login("alice", "pass") + room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token) + + # We don't care about interesting events before this (this test is making sure + # the next thing works) + self.send_mock.reset_mock() + + # Send a message from the interesting local user + self.helper.send_event( + room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message from interesting local user", + }, + tok=alice_access_token, + ) + # Kick the appservice into checking this new event + self._notify_interested_services() + + self.send_mock.assert_called_once() + ( + service, + events, + _ephemeral, + _to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Events sent from an interesting local user should also be picked up as + # interesting to the appservice. + self.assertEqual(service, interested_appservice) + self.assertEqual(events[0]["type"], "m.room.message") + self.assertEqual(events[0]["sender"], alice) + def test_sending_read_receipt_batches_to_application_services(self): """Tests that a large batch of read receipts are sent correctly to interested application services. -- cgit 1.5.1 From 6a6e1e8c0711939338f25d8d41d1e4d33d984949 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 28 Oct 2022 10:53:34 +0000 Subject: Fix room creation being rate limited too aggressively since Synapse v1.69.0. (#14314) * Introduce a test for the old behaviour which we want to restore * Reintroduce the old behaviour in a simpler way * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) * Use 1 credit instead of 2 for creating a room: be more lenient than before Notably, the UI in Element Web was still broken after restoring to prior behaviour. After discussion, we agreed that it would be sensible to increase the limit. Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/14314.bugfix | 1 + synapse/api/ratelimiting.py | 8 +++++- synapse/handlers/room.py | 16 ++++++++---- tests/rest/client/test_rooms.py | 54 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14314.bugfix (limited to 'tests') diff --git a/changelog.d/14314.bugfix b/changelog.d/14314.bugfix new file mode 100644 index 0000000000..8be47ee083 --- /dev/null +++ b/changelog.d/14314.bugfix @@ -0,0 +1 @@ +Fix room creation being rate limited too aggressively since Synapse v1.69.0. \ No newline at end of file diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 044c7d4926..511790c7c5 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -343,6 +343,7 @@ class RequestRatelimiter: requester: Requester, update: bool = True, is_admin_redaction: bool = False, + n_actions: int = 1, ) -> None: """Ratelimits requests. @@ -355,6 +356,8 @@ class RequestRatelimiter: is_admin_redaction: Whether this is a room admin/moderator redacting an event. If so then we may apply different ratelimits depending on config. + n_actions: Multiplier for the number of actions to apply to the + rate limiter at once. Raises: LimitExceededError if the request should be ratelimited @@ -383,7 +386,9 @@ class RequestRatelimiter: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) + await self.admin_redaction_ratelimiter.ratelimit( + requester, update=update, n_actions=n_actions + ) else: # Override rate and burst count per-user await self.request_ratelimiter.ratelimit( @@ -391,4 +396,5 @@ class RequestRatelimiter: rate_hz=messages_per_second, burst_count=burst_count, update=update, + n_actions=n_actions, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..d74b675adc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -559,7 +559,6 @@ class RoomCreationHandler: invite_list=[], initial_state=initial_state, creation_content=creation_content, - ratelimit=False, ) # Transfer membership events @@ -753,6 +752,10 @@ class RoomCreationHandler: ) if ratelimit: + # Rate limit once in advance, but don't rate limit the individual + # events in the room — room creation isn't atomic and it's very + # janky if half the events in the initial state don't make it because + # of rate limiting. await self.request_ratelimiter.ratelimit(requester) room_version_id = config.get( @@ -913,7 +916,6 @@ class RoomCreationHandler: room_alias=room_alias, power_level_content_override=power_level_content_override, creator_join_profile=creator_join_profile, - ratelimit=ratelimit, ) if "name" in config: @@ -1037,7 +1039,6 @@ class RoomCreationHandler: room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, - ratelimit: bool = True, ) -> Tuple[int, str, int]: """Sends the initial events into a new room. Sends the room creation, membership, and power level events into the room sequentially, then creates and batches up the @@ -1046,6 +1047,8 @@ class RoomCreationHandler: `power_level_content_override` doesn't apply when initial state has power level state event content. + Rate limiting should already have been applied by this point. + Returns: A tuple containing the stream ID, event ID and depth of the last event sent to the room. @@ -1144,7 +1147,7 @@ class RoomCreationHandler: creator.user, room_id, "join", - ratelimit=ratelimit, + ratelimit=False, content=creator_join_profile, new_room=True, prev_event_ids=[last_sent_event_id], @@ -1269,7 +1272,10 @@ class RoomCreationHandler: events_to_send.append((encryption_event, encryption_context)) last_event = await self.event_creation_handler.handle_new_client_event( - creator, events_to_send, ignore_shadow_ban=True + creator, + events_to_send, + ignore_shadow_ban=True, + ratelimit=False, ) assert last_event.internal_metadata.stream_ordering is not None return last_event.internal_metadata.stream_ordering, last_event.event_id, depth diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 716366eb90..1084d4ad9d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -54,6 +54,7 @@ from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable from tests.test_utils.event_injection import create_event +from tests.unittest import override_config PATH_PREFIX = b"/_matrix/client/api/v1" @@ -871,6 +872,41 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) + def _create_basic_room(self) -> Tuple[int, object]: + """ + Tries to create a basic room and returns the response code. + """ + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + return channel.code, channel.json_body + + @override_config( + { + "rc_message": {"per_second": 0.2, "burst_count": 10}, + } + ) + def test_room_creation_ratelimiting(self) -> None: + """ + Regression test for #14312, where ratelimiting was made too strict. + Clients should be able to create 10 rooms in a row + without hitting rate limits, using default rate limit config. + (We override rate limiting config back to its default value.) + + To ensure we don't make ratelimiting too generous accidentally, + also check that we can't create an 11th room. + """ + + for _ in range(10): + code, json_body = self._create_basic_room() + self.assertEqual(code, HTTPStatus.OK, json_body) + + # The 6th room hits the rate limit. + code, json_body = self._create_basic_room() + self.assertEqual(code, HTTPStatus.TOO_MANY_REQUESTS, json_body) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -1390,10 +1426,22 @@ class RoomJoinRatelimitTestCase(RoomBase): ) def test_join_local_ratelimit(self) -> None: """Tests that local joins are actually rate-limited.""" - for _ in range(3): - self.helper.create_room_as(self.user_id) + # Create 4 rooms + room_ids = [ + self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4) + ] + + joiner_user_id = self.register_user("joiner", "secret") + # Now make a new user try to join some of them. - self.helper.create_room_as(self.user_id, expect_code=429) + # The user can join 3 rooms + for room_id in room_ids[0:3]: + self.helper.join(room_id, joiner_user_id) + + # But the user cannot join a 4th room + self.helper.join( + room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS + ) @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} -- cgit 1.5.1 From 2bb2c32e8ed5642a5bf3ba1e8c49e10cecc88905 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 13:02:07 +0000 Subject: Avoid incrementing bg process utime/stime counters by negative durations (#14323) --- changelog.d/14323.bugfix | 1 + mypy.ini | 4 +- synapse/metrics/background_process_metrics.py | 6 +- tests/metrics/__init__.py | 0 tests/metrics/test_background_process_metrics.py | 19 +++ tests/metrics/test_metrics.py | 206 +++++++++++++++++++++++ tests/test_metrics.py | 200 ---------------------- 7 files changed, 233 insertions(+), 203 deletions(-) create mode 100644 changelog.d/14323.bugfix create mode 100644 tests/metrics/__init__.py create mode 100644 tests/metrics/test_background_process_metrics.py create mode 100644 tests/metrics/test_metrics.py delete mode 100644 tests/test_metrics.py (limited to 'tests') diff --git a/changelog.d/14323.bugfix b/changelog.d/14323.bugfix new file mode 100644 index 0000000000..da39bc020c --- /dev/null +++ b/changelog.d/14323.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 0.34.0rc2 where logs could include error spam when background processes are measured as taking a negative amount of time. diff --git a/mypy.ini b/mypy.ini index 34b4523e00..8f1141a239 100644 --- a/mypy.ini +++ b/mypy.ini @@ -56,7 +56,6 @@ exclude = (?x) |tests/rest/media/v1/test_media_storage.py |tests/server.py |tests/server_notices/test_resource_limits_server_notices.py - |tests/test_metrics.py |tests/test_state.py |tests/test_terms_auth.py |tests/util/caches/test_cached_call.py @@ -106,6 +105,9 @@ disallow_untyped_defs = False [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True +[mypy-tests.metrics.test_background_process_metrics] +disallow_untyped_defs = True + [mypy-tests.push.test_bulk_push_rule_evaluator] disallow_untyped_defs = True diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 7a1516d3a8..9ea4e23b31 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -174,8 +174,10 @@ class _BackgroundProcess: diff = new_stats - self._reported_stats self._reported_stats = new_stats - _background_process_ru_utime.labels(self.desc).inc(diff.ru_utime) - _background_process_ru_stime.labels(self.desc).inc(diff.ru_stime) + # For unknown reasons, the difference in times can be negative. See comment in + # synapse.http.request_metrics.RequestMetrics.update_metrics. + _background_process_ru_utime.labels(self.desc).inc(max(diff.ru_utime, 0)) + _background_process_ru_stime.labels(self.desc).inc(max(diff.ru_stime, 0)) _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count) _background_process_db_txn_duration.labels(self.desc).inc( diff.db_txn_duration_sec diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/metrics/test_background_process_metrics.py b/tests/metrics/test_background_process_metrics.py new file mode 100644 index 0000000000..f0f6cb2912 --- /dev/null +++ b/tests/metrics/test_background_process_metrics.py @@ -0,0 +1,19 @@ +from unittest import TestCase as StdlibTestCase +from unittest.mock import Mock + +from synapse.logging.context import ContextResourceUsage, LoggingContext +from synapse.metrics.background_process_metrics import _BackgroundProcess + + +class TestBackgroundProcessMetrics(StdlibTestCase): + def test_update_metrics_with_negative_time_diff(self) -> None: + """We should ignore negative reported utime and stime differences""" + usage = ContextResourceUsage() + usage.ru_stime = usage.ru_utime = -1.0 + + mock_logging_context = Mock(spec=LoggingContext) + mock_logging_context.get_resource_usage.return_value = usage + + process = _BackgroundProcess("test process", mock_logging_context) + # Should not raise + process.update_metrics() diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py new file mode 100644 index 0000000000..bddc4228bc --- /dev/null +++ b/tests/metrics/test_metrics.py @@ -0,0 +1,206 @@ +# Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing_extensions import Protocol + +try: + from importlib import metadata +except ImportError: + import importlib_metadata as metadata # type: ignore[no-redef] + +from unittest.mock import patch + +from pkg_resources import parse_version + +from synapse.app._base import _set_prometheus_client_use_created_metrics +from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.deferred_cache import DeferredCache + +from tests import unittest + + +def get_sample_labels_value(sample): + """Extract the labels and values of a sample. + + prometheus_client 0.5 changed the sample type to a named tuple with more + members than the plain tuple had in 0.4 and earlier. This function can + extract the labels and value from the sample for both sample types. + + Args: + sample: The sample to get the labels and value from. + Returns: + A tuple of (labels, value) from the sample. + """ + + # If the sample has a labels and value attribute, use those. + if hasattr(sample, "labels") and hasattr(sample, "value"): + return sample.labels, sample.value + # Otherwise fall back to treating it as a plain 3 tuple. + else: + _, labels, value = sample + return labels, value + + +class TestMauLimit(unittest.TestCase): + def test_basic(self): + class MetricEntry(Protocol): + foo: int + bar: int + + gauge: InFlightGauge[MetricEntry] = InFlightGauge( + "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] + ) + + def handle1(metrics): + metrics.foo += 2 + metrics.bar = max(metrics.bar, 5) + + def handle2(metrics): + metrics.foo += 3 + metrics.bar = max(metrics.bar, 7) + + gauge.register(("key1",), handle1) + + self.assert_dict( + { + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.unregister(("key1",), handle1) + + self.assert_dict( + { + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.register(("key1",), handle1) + gauge.register(("key2",), handle2) + + self.assert_dict( + { + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.unregister(("key2",), handle2) + gauge.register(("key1",), handle2) + + self.assert_dict( + { + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) + + def get_metrics_from_gauge(self, gauge): + results = {} + + for r in gauge.collect(): + results[r.name] = { + tuple(labels[x] for x in gauge.labels): value + for labels, value in map(get_sample_labels_value, r.samples) + } + + return results + + +class BuildInfoTests(unittest.TestCase): + def test_get_build(self): + """ + The synapse_build_info metric reports the OS version, Python version, + and Synapse version. + """ + items = list( + filter( + lambda x: b"synapse_build_info{" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + ) + self.assertEqual(len(items), 1) + self.assertTrue(b"osversion=" in items[0]) + self.assertTrue(b"pythonversion=" in items[0]) + self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache: DeferredCache[str, str] = DeferredCache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + +class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase): + if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"): + skip = "prometheus-client too old" + + def test_created_metrics_disabled(self) -> None: + """ + Tests that a brittle hack, to disable `_created` metrics, works. + This involves poking at the internals of prometheus-client. + It's not the end of the world if this doesn't work. + + This test gives us a way to notice if prometheus-client changes + their internals. + """ + import prometheus_client.metrics + + PRIVATE_FLAG_NAME = "_use_created" + + # By default, the pesky `_created` metrics are enabled. + # Check this assumption is still valid. + self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME)) + + with patch("prometheus_client.metrics") as mock: + setattr(mock, PRIVATE_FLAG_NAME, True) + _set_prometheus_client_use_created_metrics(False) + self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False)) diff --git a/tests/test_metrics.py b/tests/test_metrics.py deleted file mode 100644 index 1a70eddc9b..0000000000 --- a/tests/test_metrics.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2018 New Vector Ltd -# Copyright 2019 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -try: - from importlib import metadata -except ImportError: - import importlib_metadata as metadata # type: ignore[no-redef] - -from unittest.mock import patch - -from pkg_resources import parse_version - -from synapse.app._base import _set_prometheus_client_use_created_metrics -from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.deferred_cache import DeferredCache - -from tests import unittest - - -def get_sample_labels_value(sample): - """Extract the labels and values of a sample. - - prometheus_client 0.5 changed the sample type to a named tuple with more - members than the plain tuple had in 0.4 and earlier. This function can - extract the labels and value from the sample for both sample types. - - Args: - sample: The sample to get the labels and value from. - Returns: - A tuple of (labels, value) from the sample. - """ - - # If the sample has a labels and value attribute, use those. - if hasattr(sample, "labels") and hasattr(sample, "value"): - return sample.labels, sample.value - # Otherwise fall back to treating it as a plain 3 tuple. - else: - _, labels, value = sample - return labels, value - - -class TestMauLimit(unittest.TestCase): - def test_basic(self): - gauge = InFlightGauge( - "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] - ) - - def handle1(metrics): - metrics.foo += 2 - metrics.bar = max(metrics.bar, 5) - - def handle2(metrics): - metrics.foo += 3 - metrics.bar = max(metrics.bar, 7) - - gauge.register(("key1",), handle1) - - self.assert_dict( - { - "test1_total": {("key1",): 1}, - "test1_foo": {("key1",): 2}, - "test1_bar": {("key1",): 5}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.unregister(("key1",), handle1) - - self.assert_dict( - { - "test1_total": {("key1",): 0}, - "test1_foo": {("key1",): 0}, - "test1_bar": {("key1",): 0}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.register(("key1",), handle1) - gauge.register(("key2",), handle2) - - self.assert_dict( - { - "test1_total": {("key1",): 1, ("key2",): 1}, - "test1_foo": {("key1",): 2, ("key2",): 3}, - "test1_bar": {("key1",): 5, ("key2",): 7}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.unregister(("key2",), handle2) - gauge.register(("key1",), handle2) - - self.assert_dict( - { - "test1_total": {("key1",): 2, ("key2",): 0}, - "test1_foo": {("key1",): 5, ("key2",): 0}, - "test1_bar": {("key1",): 7, ("key2",): 0}, - }, - self.get_metrics_from_gauge(gauge), - ) - - def get_metrics_from_gauge(self, gauge): - results = {} - - for r in gauge.collect(): - results[r.name] = { - tuple(labels[x] for x in gauge.labels): value - for labels, value in map(get_sample_labels_value, r.samples) - } - - return results - - -class BuildInfoTests(unittest.TestCase): - def test_get_build(self): - """ - The synapse_build_info metric reports the OS version, Python version, - and Synapse version. - """ - items = list( - filter( - lambda x: b"synapse_build_info{" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - ) - self.assertEqual(len(items), 1) - self.assertTrue(b"osversion=" in items[0]) - self.assertTrue(b"pythonversion=" in items[0]) - self.assertTrue(b"version=" in items[0]) - - -class CacheMetricsTests(unittest.HomeserverTestCase): - def test_cache_metric(self): - """ - Caches produce metrics reflecting their state when scraped. - """ - CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = DeferredCache(CACHE_NAME, max_entries=777) - - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } - - self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") - - cache.prefill("1", "hi") - - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } - - self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") - - -class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase): - if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"): - skip = "prometheus-client too old" - - def test_created_metrics_disabled(self) -> None: - """ - Tests that a brittle hack, to disable `_created` metrics, works. - This involves poking at the internals of prometheus-client. - It's not the end of the world if this doesn't work. - - This test gives us a way to notice if prometheus-client changes - their internals. - """ - import prometheus_client.metrics - - PRIVATE_FLAG_NAME = "_use_created" - - # By default, the pesky `_created` metrics are enabled. - # Check this assumption is still valid. - self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME)) - - with patch("prometheus_client.metrics") as mock: - setattr(mock, PRIVATE_FLAG_NAME, True) - _set_prometheus_client_use_created_metrics(False) - self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False)) -- cgit 1.5.1 From cc3a52b33df72bb4230367536b924a6d1f510d36 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 31 Oct 2022 18:07:30 +0100 Subject: Support OIDC backchannel logouts (#11414) If configured an OIDC IdP can log a user's session out of Synapse when they log out of the identity provider. The IdP sends a request directly to Synapse (and must be configured with an endpoint) when a user logs out. --- changelog.d/11414.feature | 1 + docs/openid.md | 14 + docs/usage/configuration/config_documentation.md | 9 + synapse/config/oidc.py | 12 + synapse/handlers/oidc.py | 381 ++++++++++++++++++-- synapse/handlers/sso.py | 71 ++++ synapse/rest/synapse/client/oidc/__init__.py | 4 + .../client/oidc/backchannel_logout_resource.py | 35 ++ synapse/storage/databases/main/registration.py | 21 ++ tests/rest/client/test_auth.py | 390 +++++++++++++++++++-- tests/rest/client/utils.py | 55 ++- tests/server.py | 6 + tests/test_utils/oidc.py | 27 +- 13 files changed, 960 insertions(+), 66 deletions(-) create mode 100644 changelog.d/11414.feature create mode 100644 synapse/rest/synapse/client/oidc/backchannel_logout_resource.py (limited to 'tests') diff --git a/changelog.d/11414.feature b/changelog.d/11414.feature new file mode 100644 index 0000000000..fc035e50a7 --- /dev/null +++ b/changelog.d/11414.feature @@ -0,0 +1 @@ +Support back-channel logouts from OpenID Connect providers. diff --git a/docs/openid.md b/docs/openid.md index 87ebea4c29..37c5eb244d 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -49,6 +49,13 @@ setting in your configuration file. See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as the text below for example configurations for specific providers. +## OIDC Back-Channel Logout + +Synapse supports receiving [OpenID Connect Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) notifications. + +This lets the OpenID Connect Provider notify Synapse when a user logs out, so that Synapse can end that user session. +This feature can be enabled by setting the `backchannel_logout_enabled` property to `true` in the provider configuration, and setting the following URL as destination for Back-Channel Logout notifications in your OpenID Connect Provider: `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` + ## Sample configs Here are a few configs for providers that should work with Synapse. @@ -123,6 +130,9 @@ oidc_providers: [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. +Keycloak supports OIDC Back-Channel Logout, which sends logout notification to Synapse, so that Synapse users get logged out when they log out from Keycloak. +This can be optionally enabled by setting `backchannel_logout_enabled` to `true` in the Synapse configuration, and by setting the "Backchannel Logout URL" in Keycloak. + Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm. 1. Click `Clients` in the sidebar and click `Create` @@ -144,6 +154,8 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to | Client Protocol | `openid-connect` | | Access Type | `confidential` | | Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` | +| Backchannel Logout URL (optional) | `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` | +| Backchannel Logout Session Required (optional) | `On` | 5. Click `Save` 6. On the Credentials tab, update the fields: @@ -167,7 +179,9 @@ oidc_providers: config: localpart_template: "{{ user.preferred_username }}" display_name_template: "{{ user.name }}" + backchannel_logout_enabled: true # Optional ``` + ### Auth0 [Auth0][auth0] is a hosted SaaS IdP solution. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 97fb505a5f..44358faf59 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3021,6 +3021,15 @@ Options for each entry include: which is set to the claims returned by the UserInfo Endpoint and/or in the ID Token. +* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications. + Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`. + Defaults to `false`. + +* `backchannel_logout_ignore_sub`: by default, the OIDC Back-Channel Logout feature checks that the + `sub` claim matches the subject claim received during login. This check can be disabled by setting + this to `true`. Defaults to `false`. + + You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`. It is possible to configure Synapse to only allow logins if certain attributes match particular values in the OIDC userinfo. The requirements can be listed under diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 5418a332da..0bd83f4010 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -123,6 +123,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "userinfo_endpoint": {"type": "string"}, "jwks_uri": {"type": "string"}, "skip_verification": {"type": "boolean"}, + "backchannel_logout_enabled": {"type": "boolean"}, + "backchannel_logout_ignore_sub": {"type": "boolean"}, "user_profile_method": { "type": "string", "enum": ["auto", "userinfo_endpoint"], @@ -292,6 +294,10 @@ def _parse_oidc_config_dict( token_endpoint=oidc_config.get("token_endpoint"), userinfo_endpoint=oidc_config.get("userinfo_endpoint"), jwks_uri=oidc_config.get("jwks_uri"), + backchannel_logout_enabled=oidc_config.get("backchannel_logout_enabled", False), + backchannel_logout_ignore_sub=oidc_config.get( + "backchannel_logout_ignore_sub", False + ), skip_verification=oidc_config.get("skip_verification", False), user_profile_method=oidc_config.get("user_profile_method", "auto"), allow_existing_users=oidc_config.get("allow_existing_users", False), @@ -368,6 +374,12 @@ class OidcProviderConfig: # "openid" scope is used. jwks_uri: Optional[str] + # Whether Synapse should react to backchannel logouts + backchannel_logout_enabled: bool + + # Whether Synapse should ignore the `sub` claim in backchannel logouts or not. + backchannel_logout_ignore_sub: bool + # Whether to skip metadata verification skip_verification: bool diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 9759daf043..867973dcca 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -12,14 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import binascii import inspect +import json import logging -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Type, + TypeVar, + Union, +) from urllib.parse import urlencode, urlparse import attr +import unpaddedbase64 from authlib.common.security import generate_token -from authlib.jose import JsonWebToken, jwt +from authlib.jose import JsonWebToken, JWTClaims +from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oidc.core import CodeIDToken, UserInfo @@ -35,9 +49,12 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from twisted.web.http_headers import Headers +from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes +from synapse.http.server import finish_request +from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart @@ -88,6 +105,8 @@ class Token(TypedDict): #: there is no real point of doing this in our case. JWK = Dict[str, str] +C = TypeVar("C") + #: A JWK Set, as per RFC7517 sec 5. class JWKS(TypedDict): @@ -247,6 +266,80 @@ class OidcHandler: await oidc_provider.handle_oidc_callback(request, session_data, code) + async def handle_backchannel_logout(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + This extracts the logout_token from the request and tries to figure out + which OpenID Provider it is comming from. This works by matching the iss claim + with the issuer and the aud claim with the client_id. + + Since at this point we don't know who signed the JWT, we can't just + decode it using authlib since it will always verifies the signature. We + have to decode it manually without validating the signature. The actual JWT + verification is done in the `OidcProvider.handler_backchannel_logout` method, + once we figured out which provider sent the request. + + Args: + request: the incoming request from the browser. + """ + logout_token = parse_string(request, "logout_token") + if logout_token is None: + raise SynapseError(400, "Missing logout_token in request") + + # A JWT looks like this: + # header.payload.signature + # where all parts are encoded with urlsafe base64. + # The aud and iss claims we care about are in the payload part, which + # is a JSON object. + try: + # By destructuring the list after splitting, we ensure that we have + # exactly 3 segments + _, payload, _ = logout_token.split(".") + except ValueError: + raise SynapseError(400, "Invalid logout_token in request") + + try: + payload_bytes = unpaddedbase64.decode_base64(payload) + claims = json_decoder.decode(payload_bytes.decode("utf-8")) + except (json.JSONDecodeError, binascii.Error, UnicodeError): + raise SynapseError(400, "Invalid logout_token payload in request") + + try: + # Let's extract the iss and aud claims + iss = claims["iss"] + aud = claims["aud"] + # The aud claim can be either a string or a list of string. Here we + # normalize it as a list of strings. + if isinstance(aud, str): + aud = [aud] + + # Check that we have the right types for the aud and the iss claims + if not isinstance(iss, str) or not isinstance(aud, list): + raise TypeError() + for a in aud: + if not isinstance(a, str): + raise TypeError() + + # At this point we properly checked both claims types + issuer: str = iss + audience: List[str] = aud + except (TypeError, KeyError): + raise SynapseError(400, "Invalid issuer/audience in logout_token") + + # Now that we know the audience and the issuer, we can figure out from + # what provider it is coming from + oidc_provider: Optional[OidcProvider] = None + for provider in self._providers.values(): + if provider.issuer == issuer and provider.client_id in audience: + oidc_provider = provider + break + + if oidc_provider is None: + raise SynapseError(400, "Could not find the OP that issued this event") + + # Ask the provider to handle the logout request. + await oidc_provider.handle_backchannel_logout(request, logout_token) + class OidcError(Exception): """Used to catch errors when calling the token_endpoint""" @@ -342,6 +435,7 @@ class OidcProvider: self.idp_brand = provider.idp_brand self._sso_handler = hs.get_sso_handler() + self._device_handler = hs.get_device_handler() self._sso_handler.register_identity_provider(self) @@ -400,6 +494,41 @@ class OidcProvider: # If we're not using userinfo, we need a valid jwks to validate the ID token m.validate_jwks_uri() + if self._config.backchannel_logout_enabled: + if not m.get("backchannel_logout_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled for issuer %r" + "but it does not advertise support for it", + self.issuer, + ) + + elif not m.get("backchannel_logout_session_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled and supported " + "by issuer %r but it might not send a session ID with " + "logout tokens, which is required for the logouts to work", + self.issuer, + ) + + if not self._config.backchannel_logout_ignore_sub: + # If OIDC backchannel logouts are enabled, the provider mapping provider + # should use the `sub` claim. We verify that by mapping a dumb user and + # see if we get back the sub claim + user = UserInfo({"sub": "thisisasubject"}) + try: + subject = self._user_mapping_provider.get_remote_user_id(user) + if subject != user["sub"]: + raise ValueError("Unexpected subject") + except Exception: + logger.warning( + f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} " + "but it looks like the configured `user_mapping_provider` " + "does not use the `sub` claim as subject. If it is the case, " + "and you want Synapse to ignore the `sub` claim in OIDC " + "Back-Channel Logouts, set `backchannel_logout_ignore_sub` " + "to `true` in the issuer config." + ) + @property def _uses_userinfo(self) -> bool: """Returns True if the ``userinfo_endpoint`` should be used. @@ -415,6 +544,16 @@ class OidcProvider: or self._user_profile_method == "userinfo_endpoint" ) + @property + def issuer(self) -> str: + """The issuer identifying this provider.""" + return self._config.issuer + + @property + def client_id(self) -> str: + """The client_id used when interacting with this provider.""" + return self._config.client_id + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: """Return the provider metadata. @@ -662,6 +801,59 @@ class OidcProvider: return UserInfo(resp) + async def _verify_jwt( + self, + alg_values: List[str], + token: str, + claims_cls: Type[C], + claims_options: Optional[dict] = None, + claims_params: Optional[dict] = None, + ) -> C: + """Decode and validate a JWT, re-fetching the JWKS as needed. + + Args: + alg_values: list of `alg` values allowed when verifying the JWT. + token: the JWT. + claims_cls: the JWTClaims class to use to validate the claims. + claims_options: dict of options passed to the `claims_cls` constructor. + claims_params: dict of params passed to the `claims_cls` constructor. + + Returns: + The decoded claims in the JWT. + """ + jwt = JsonWebToken(alg_values) + + logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token) + + # Try to decode the keys in cache first, then retry by forcing the keys + # to be reloaded + jwk_set = await self.load_jwks() + try: + claims = jwt.decode( + token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + except ValueError: + logger.info("Reloading JWKS after decode error") + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + claims = jwt.decode( + token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + + logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims) + + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew + return claims + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. @@ -675,13 +867,13 @@ class OidcProvider: The decoded claims in the ID token. """ id_token = token.get("id_token") - logger.debug("Attempting to decode JWT id_token %r", id_token) # That has been theoritically been checked by the caller, so even though # assertion are not enabled in production, it is mainly here to appease mypy assert id_token is not None metadata = await self.load_metadata() + claims_params = { "nonce": nonce, "client_id": self._client_auth.client_id, @@ -691,38 +883,17 @@ class OidcProvider: # in the `id_token` that we can check against. claims_params["access_token"] = token["access_token"] - alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - jwt = JsonWebToken(alg_values) - - claim_options = {"iss": {"values": [metadata["issuer"]]}} + claims_options = {"iss": {"values": [metadata["issuer"]]}} - # Try to decode the keys in cache first, then retry by forcing the keys - # to be reloaded - jwk_set = await self.load_jwks() - try: - claims = jwt.decode( - id_token, - key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, - claims_params=claims_params, - ) - except ValueError: - logger.info("Reloading JWKS after decode error") - jwk_set = await self.load_jwks(force=True) # try reloading the jwks - claims = jwt.decode( - id_token, - key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, - claims_params=claims_params, - ) - - logger.debug("Decoded id_token JWT %r; validating", claims) + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - claims.validate( - now=self._clock.time(), leeway=120 - ) # allows 2 min of clock skew + claims = await self._verify_jwt( + alg_values=alg_values, + token=id_token, + claims_cls=CodeIDToken, + claims_options=claims_options, + claims_params=claims_params, + ) return claims @@ -1043,6 +1214,146 @@ class OidcProvider: # to be strings. return str(remote_user_id) + async def handle_backchannel_logout( + self, request: SynapseRequest, logout_token: str + ) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + The OIDC Provider posts a logout token to this endpoint when a user + session ends. That token is a JWT signed with the same keys as + ID tokens. The OpenID Connect Back-Channel Logout draft explains how to + validate the JWT and figure out what session to end. + + Args: + request: The request to respond to + logout_token: The logout token (a JWT) extracted from the request body + """ + # Back-Channel Logout can be disabled in the config, hence this check. + # This is not that important for now since Synapse is registered + # manually to the OP, so not specifying the backchannel-logout URI is + # as effective than disabling it here. It might make more sense if we + # support dynamic registration in Synapse at some point. + if not self._config.backchannel_logout_enabled: + logger.warning( + f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config" + ) + + # TODO: this responds with a 400 status code, which is what the OIDC + # Back-Channel Logout spec expects, but spec also suggests answering with + # a JSON object, with the `error` and `error_description` fields set, which + # we are not doing here. + # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse + raise SynapseError( + 400, "OpenID Connect Back-Channel Logout is disabled for this provider" + ) + + metadata = await self.load_metadata() + + # As per OIDC Back-Channel Logout 1.0 sec. 2.4: + # A Logout Token MUST be signed and MAY also be encrypted. The same + # keys are used to sign and encrypt Logout Tokens as are used for ID + # Tokens. If the Logout Token is encrypted, it SHOULD replicate the + # iss (issuer) claim in the JWT Header Parameters, as specified in + # Section 5.3 of [JWT]. + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + # As per sec. 2.6: + # 3. Validate the iss, aud, and iat Claims in the same way they are + # validated in ID Tokens. + # Which means the audience should contain Synapse's client_id and the + # issuer should be the IdP issuer + claims_options = { + "iss": {"values": [metadata["issuer"]]}, + "aud": {"values": [self.client_id]}, + } + + try: + claims = await self._verify_jwt( + alg_values=alg_values, + token=logout_token, + claims_cls=LogoutToken, + claims_options=claims_options, + ) + except JoseError: + logger.exception("Invalid logout_token") + raise SynapseError(400, "Invalid logout_token") + + # As per sec. 2.6: + # 4. Verify that the Logout Token contains a sub Claim, a sid Claim, + # or both. + # 5. Verify that the Logout Token contains an events Claim whose + # value is JSON object containing the member name + # http://schemas.openid.net/event/backchannel-logout. + # 6. Verify that the Logout Token does not contain a nonce Claim. + # This is all verified by the LogoutToken claims class, so at this + # point the `sid` claim exists and is a string. + sid: str = claims.get("sid") + + # If the `sub` claim was included in the logout token, we check that it matches + # that it matches the right user. We can have cases where the `sub` claim is not + # the ID saved in database, so we let admins disable this check in config. + sub: Optional[str] = claims.get("sub") + expected_user_id: Optional[str] = None + if sub is not None and not self._config.backchannel_logout_ignore_sub: + expected_user_id = await self._store.get_user_by_external_id( + self.idp_id, sub + ) + + # Invalidate any running user-mapping sessions, in-flight login tokens and + # active devices + await self._sso_handler.revoke_sessions_for_provider_session_id( + auth_provider_id=self.idp_id, + auth_provider_session_id=sid, + expected_user_id=expected_user_id, + ) + + request.setResponseCode(200) + request.setHeader(b"Cache-Control", b"no-cache, no-store") + request.setHeader(b"Pragma", b"no-cache") + finish_request(request) + + +class LogoutToken(JWTClaims): + """ + Holds and verify claims of a logout token, as per + https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken + """ + + REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"] + + def validate(self, now: Optional[int] = None, leeway: int = 0) -> None: + """Validate everything in claims payload.""" + super().validate(now, leeway) + self.validate_sid() + self.validate_events() + self.validate_nonce() + + def validate_sid(self) -> None: + """Ensure the sid claim is present""" + sid = self.get("sid") + if not sid: + raise MissingClaimError("sid") + + if not isinstance(sid, str): + raise InvalidClaimError("sid") + + def validate_nonce(self) -> None: + """Ensure the nonce claim is absent""" + if "nonce" in self: + raise InvalidClaimError("nonce") + + def validate_events(self) -> None: + """Ensure the events claim is present and with the right value""" + events = self.get("events") + if not events: + raise MissingClaimError("events") + + if not isinstance(events, dict): + raise InvalidClaimError("events") + + if "http://schemas.openid.net/event/backchannel-logout" not in events: + raise InvalidClaimError("events") + # number of seconds a newly-generated client secret should be valid for CLIENT_SECRET_VALIDITY_SECONDS = 3600 @@ -1112,6 +1423,7 @@ class JwtClientSecret: logger.info( "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload ) + jwt = JsonWebToken(header["alg"]) self._cached_secret = jwt.encode(header, payload, self._key.key) self._cached_secret_replacement_time = ( expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS @@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict): emails: List[str] -C = TypeVar("C") - - class OidcMappingProvider(Generic[C]): """A mapping provider maps a UserInfo object to user attributes. diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 5943f08e91..749d7e93b0 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -191,6 +191,7 @@ class SsoHandler: self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() self._error_template = hs.config.sso.sso_error_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() @@ -1026,6 +1027,76 @@ class SsoHandler: return True + async def revoke_sessions_for_provider_session_id( + self, + auth_provider_id: str, + auth_provider_session_id: str, + expected_user_id: Optional[str] = None, + ) -> None: + """Revoke any devices and in-flight logins tied to a provider session. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + auth_provider_session_id: The session ID from the provider to logout + expected_user_id: The user we're expecting to logout. If set, it will ignore + sessions belonging to other users and log an error. + """ + # Invalidate any running user-mapping sessions + to_delete = [] + for session_id, session in self._username_mapping_sessions.items(): + if ( + session.auth_provider_id == auth_provider_id + and session.auth_provider_session_id == auth_provider_session_id + ): + to_delete.append(session_id) + + for session_id in to_delete: + logger.info("Revoking mapping session %s", session_id) + del self._username_mapping_sessions[session_id] + + # Invalidate any in-flight login tokens + await self._store.invalidate_login_tokens_by_session_id( + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + # Fetch any device(s) in the store associated with the session ID. + devices = await self._store.get_devices_by_auth_provider_session_id( + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + # We have no guarantee that all the devices of that session are for the same + # `user_id`. Hence, we have to iterate over the list of devices and log them out + # one by one. + for device in devices: + user_id = device["user_id"] + device_id = device["device_id"] + + # If the user_id associated with that device/session is not the one we got + # out of the `sub` claim, skip that device and show log an error. + if expected_user_id is not None and user_id != expected_user_id: + logger.error( + "Received a logout notification from SSO provider " + f"{auth_provider_id!r} for the user {expected_user_id!r}, but with " + f"a session ID ({auth_provider_session_id!r}) which belongs to " + f"{user_id!r}. This may happen when the SSO provider user mapper " + "uses something else than the standard attribute as mapping ID. " + "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` " + "in the provider config if that is the case." + ) + continue + + logger.info( + "Logging out %r (device %r) via SSO (%r) logout notification (session %r).", + user_id, + device_id, + auth_provider_id, + auth_provider_session_id, + ) + await self._device_handler.delete_devices(user_id, [device_id]) + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py index 81fec39659..e4b28ce3df 100644 --- a/synapse/rest/synapse/client/oidc/__init__.py +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -17,6 +17,9 @@ from typing import TYPE_CHECKING from twisted.web.resource import Resource +from synapse.rest.synapse.client.oidc.backchannel_logout_resource import ( + OIDCBackchannelLogoutResource, +) from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource if TYPE_CHECKING: @@ -29,6 +32,7 @@ class OIDCResource(Resource): def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.putChild(b"callback", OIDCCallbackResource(hs)) + self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs)) __all__ = ["OIDCResource"] diff --git a/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py new file mode 100644 index 0000000000..e07e76855a --- /dev/null +++ b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py @@ -0,0 +1,35 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.server import DirectServeJsonResource +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class OIDCBackchannelLogoutResource(DirectServeJsonResource): + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + async def _async_render_POST(self, request: SynapseRequest) -> None: + await self._oidc_handler.handle_backchannel_logout(request) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 0255295317..5167089e03 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1920,6 +1920,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): self._clock.time_msec(), ) + async def invalidate_login_tokens_by_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> None: + """Invalidate login tokens with the given IdP session ID. + + Args: + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider + """ + await self.db_pool.simple_update( + table="login_tokens", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + updatevalues={"used_ts": self._clock.time_msec()}, + desc="invalidate_login_tokens_by_session_id", + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index ebf653d018..847294dc8e 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple, Union @@ -21,7 +22,7 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType -from synapse.api.errors import Codes +from synapse.api.errors import Codes, SynapseError from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -32,8 +33,8 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC -from tests.rest.client.utils import TEST_OIDC_CONFIG -from tests.server import FakeChannel +from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER +from tests.server import FakeChannel, make_request from tests.unittest import override_config, skip_unless @@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token: str) -> bool: - """ - Checks whether an access token is valid, returning whether it is or not. - """ - code = self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code - - # Either 200 or 401 is what we get back; anything else is a bug. - assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} - - return code == HTTPStatus.OK - def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. @@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.reactor.advance(59.0) # Both tokens should still be valid. - self.assertTrue(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 61 s (just past 1 minute, the time of expiry) self.reactor.advance(2.0) # Only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 599 s (just shy of 10 minutes, the time of expiry) self.reactor.advance(599.0 - 61.0) # It's still the case that only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 601 s (just past 10 minutes, the time of expiry) self.reactor.advance(2.0) # Now neither token is valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami( + nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} @@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # and no refresh token self.assertEqual(_table_length("access_tokens"), 0) self.assertEqual(_table_length("refresh_tokens"), 0) + + +def oidc_config( + id: str, with_localpart_template: bool, **kwargs: Any +) -> Dict[str, Any]: + """Sample OIDC provider config used in backchannel logout tests. + + Args: + id: IDP ID for this provider + with_localpart_template: Set to `true` to have a default localpart_template in + the `user_mapping_provider` config and skip the user mapping session + **kwargs: rest of the config + + Returns: + A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of + the HS config + """ + config: Dict[str, Any] = { + "idp_id": id, + "idp_name": id, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + } + + if with_localpart_template: + config["user_mapping_provider"] = { + "config": {"localpart_template": "{{ user.sub }}"} + } + else: + config["user_mapping_provider"] = {"config": {}} + + config.update(kwargs) + + return config + + +@skip_unless(HAS_OIDC, "Requires OIDC") +class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" + + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + resource_dict = super().create_resource_dict() + resource_dict.update(build_synapse_client_resource_tree(self.hs)) + return resource_dict + + def submit_logout_token(self, logout_token: str) -> FakeChannel: + return self.make_request( + "POST", + "/_synapse/client/oidc/backchannel_logout", + content=f"logout_token={logout_token}", + content_is_form=True, + ) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_simple_logout(self) -> None: + """ + Receiving a logout token should logout the user + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + self.assertNotEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = fake_oidc_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = fake_oidc_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_login(self) -> None: + """ + It should revoke login tokens when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # expect a confirmation page + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now try to exchange the login token + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + # It should have failed + self.assertEqual(channel.code, 403) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=False, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_mapping(self) -> None: + """ + It should stop ongoing user mapping session when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # Expect a user mapping page + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + + # We should have a user_mapping_session cookie + cookie_headers = channel.headers.getRawHeaders("Set-Cookie") + assert cookie_headers + cookies: Dict[str, str] = {} + for h in cookie_headers: + key, value = h.split(";")[0].split("=", maxsplit=1) + cookies[key] = value + + user_mapping_session_id = cookies["username_mapping_session"] + + # Getting that session should not raise + session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + self.assertIsNotNone(session) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now it should raise + with self.assertRaises(SynapseError): + self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=False, + ) + ] + } + ) + def test_disabled(self) -> None: + """ + Receiving a logout token should do nothing if it is disabled in the config + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_no_sid(self) -> None: + """ + Receiving a logout token without `sid` during the login should do nothing + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=False + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + "first", + issuer="https://first-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + oidc_config( + "second", + issuer="https://second-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + ] + } + ) + def test_multiple_providers(self) -> None: + """ + It should be able to distinguish login tokens from two different IdPs + """ + first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/") + second_server = self.helper.fake_oidc_server( + issuer="https://second-issuer.com/" + ) + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + first_server, user, with_sid=True, idp_id="oidc-first" + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + second_server, user, with_sid=True, idp_id="oidc-second" + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # `sid` in the fake providers are generated by a counter, so the first grant of + # each provider should give the same SID + self.assertEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = first_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = second_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 967d229223..706399fae5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -553,6 +553,34 @@ class RestHelper: return channel.json_body + def whoami( + self, + access_token: str, + expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK, + ) -> JsonDict: + """Perform a 'whoami' request, which can be a quick way to check for access + token validity + + Args: + access_token: The user token to use during the request + expect_code: The return code to expect from attempting the whoami request + """ + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "account/whoami", + access_token=access_token, + ) + + assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: """Create a ``FakeOidcServer``. @@ -572,6 +600,7 @@ class RestHelper: fake_server: FakeOidcServer, remote_user_id: str, with_sid: bool = False, + idp_id: Optional[str] = None, expected_status: int = 200, ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC @@ -588,7 +617,11 @@ class RestHelper: client_redirect_url = "https://x" userinfo = {"sub": remote_user_id} channel, grant = self.auth_via_oidc( - fake_server, userinfo, client_redirect_url, with_sid=with_sid + fake_server, + userinfo, + client_redirect_url, + with_sid=with_sid, + idp_id=idp_id, ) # expect a confirmation page @@ -623,6 +656,7 @@ class RestHelper: client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, with_sid: bool = False, + idp_id: Optional[str] = None, ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. @@ -648,6 +682,7 @@ class RestHelper: ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. with_sid: if True, generates a random `sid` (OIDC session ID) + idp_id: if set, explicitely chooses one specific IDP Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -665,7 +700,9 @@ class RestHelper: oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) else: # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + oauth_uri = self.initiate_sso_login( + client_redirect_url, cookies, idp_id=idp_id + ) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -742,7 +779,10 @@ class RestHelper: return channel, grant def initiate_sso_login( - self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + self, + client_redirect_url: Optional[str], + cookies: MutableMapping[str, str], + idp_id: Optional[str] = None, ) -> str: """Make a request to the login-via-sso redirect endpoint, and return the target @@ -753,6 +793,7 @@ class RestHelper: client_redirect_url: the client redirect URL to pass to the login redirect endpoint cookies: any cookies returned will be added to this dict + idp_id: if set, explicitely chooses one specific IDP Returns: the URI that the client gets redirected to (ie, the SSO server) @@ -761,6 +802,12 @@ class RestHelper: if client_redirect_url: params["redirectUrl"] = client_redirect_url + uri = "/_matrix/client/r0/login/sso/redirect" + if idp_id is not None: + uri = f"{uri}/{idp_id}" + + uri = f"{uri}?{urllib.parse.urlencode(params)}" + # hit the redirect url (which should redirect back to the redirect url. This # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. @@ -768,7 +815,7 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", - "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), + uri, ) assert channel.code == 302 diff --git a/tests/server.py b/tests/server.py index 8b1d186219..b1730fcc8d 100644 --- a/tests/server.py +++ b/tests/server.py @@ -362,6 +362,12 @@ def make_request( # Twisted expects to be at the end of the content when parsing the request. req.content.seek(0, SEEK_END) + # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded + # bodies if the Content-Length header is missing + req.requestHeaders.addRawHeader( + b"Content-Length", str(len(content)).encode("ascii") + ) + if access_token: req.requestHeaders.addRawHeader( b"Authorization", b"Bearer " + access_token.encode("ascii") diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index de134bbc89..1461d23ee8 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -51,6 +51,8 @@ class FakeOidcServer: get_userinfo_handler: Mock post_token_handler: Mock + sid_counter: int = 0 + def __init__(self, clock: Clock, issuer: str): from authlib.jose import ECKey, KeySet @@ -146,7 +148,7 @@ class FakeOidcServer: return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: - now = self._clock.time() + now = int(self._clock.time()) id_token = { **grant.userinfo, "iss": self.issuer, @@ -166,6 +168,26 @@ class FakeOidcServer: return self._sign(id_token) + def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str: + now = int(self._clock.time()) + logout_token = { + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "jti": random_string(10), + "events": { + "http://schemas.openid.net/event/backchannel-logout": {}, + }, + } + + if grant.sid is not None: + logout_token["sid"] = grant.sid + + if "sub" in grant.userinfo: + logout_token["sub"] = grant.userinfo["sub"] + + return self._sign(logout_token) + def id_token_override(self, overrides: dict): """Temporarily patch the ID token generated by the token endpoint.""" return patch.object(self, "_id_token_overrides", overrides) @@ -183,7 +205,8 @@ class FakeOidcServer: code = random_string(10) sid = None if with_sid: - sid = random_string(10) + sid = str(self.sid_counter) + self.sid_counter += 1 grant = FakeAuthorizationGrant( userinfo=userinfo, -- cgit 1.5.1 From dbfc9b803ee32f7b31c2b5ccbc53a1bfcaa95983 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 20:31:43 +0000 Subject: Fix dehydrated device REST checks (#14336) --- changelog.d/14336.bugfix | 1 + synapse/rest/client/devices.py | 5 ++--- tests/rest/client/test_devices.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14336.bugfix (limited to 'tests') diff --git a/changelog.d/14336.bugfix b/changelog.d/14336.bugfix new file mode 100644 index 0000000000..d44ff1bbc7 --- /dev/null +++ b/changelog.d/14336.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70 where clients were unable to PUT new [dehydrated devices](https://github.com/matrix-org/matrix-spec-proposals/pull/2697). diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 90828c95c4..8f3cbd4ea2 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -231,7 +231,7 @@ class DehydratedDeviceServlet(RestServlet): } } - PUT /org.matrix.msc2697/dehydrated_device + PUT /org.matrix.msc2697.v2/dehydrated_device Content-Type: application/json { @@ -271,7 +271,6 @@ class DehydratedDeviceServlet(RestServlet): raise errors.NotFoundError("No dehydrated device available") class PutBody(RequestBodyModel): - device_id: StrictStr device_data: DehydratedDeviceDataModel initial_device_display_name: Optional[StrictStr] @@ -281,7 +280,7 @@ class DehydratedDeviceServlet(RestServlet): device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), - submission.device_data, + submission.device_data.dict(), submission.initial_device_display_name, ) return 200, {"device_id": device_id} diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index aa98222434..d80eea17d3 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -200,3 +200,37 @@ class DevicesTestCase(unittest.HomeserverTestCase): self.reactor.advance(43200) self.get_success(self.handler.get_device(user_id, "abc")) self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError) + + +class DehydratedDeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + devices.register_servlets, + ] + + def test_PUT(self) -> None: + """Sanity-check that we can PUT a dehydrated device. + + Detects https://github.com/matrix-org/synapse/issues/14334. + """ + alice = self.register_user("alice", "correcthorse") + token = self.login(alice, "correcthorse") + + # Have alice update their device list + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device", + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device", + } + }, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + device_id = channel.json_body.get("device_id") + self.assertIsInstance(device_id, str) -- cgit 1.5.1 From 5905ba12d0c652b77ffc677196c7e31addb0eedf Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Nov 2022 13:07:54 +0000 Subject: Run trial tests against Python 3.11 (#13812) --- .ci/scripts/calculate_jobs.py | 10 +- .github/workflows/tests.yml | 2 +- changelog.d/13812.misc | 1 + poetry.lock | 222 ++++++++++++++++++------------ tests/federation/transport/test_client.py | 11 ++ 5 files changed, 149 insertions(+), 97 deletions(-) create mode 100644 changelog.d/13812.misc (limited to 'tests') diff --git a/.ci/scripts/calculate_jobs.py b/.ci/scripts/calculate_jobs.py index c2198e0dd4..f82eec231a 100755 --- a/.ci/scripts/calculate_jobs.py +++ b/.ci/scripts/calculate_jobs.py @@ -33,7 +33,7 @@ IS_PR = os.environ["GITHUB_REF"].startswith("refs/pull/") trial_sqlite_tests = [ { - "python-version": "3.7", + "python-version": "3.11", "database": "sqlite", "extras": "all", } @@ -46,13 +46,13 @@ if not IS_PR: "database": "sqlite", "extras": "all", } - for version in ("3.8", "3.9", "3.10") + for version in ("3.8", "3.9", "3.10", "3.11") ) trial_postgres_tests = [ { - "python-version": "3.7", + "python-version": "3.11", "database": "postgres", "postgres-version": "10", "extras": "all", @@ -62,7 +62,7 @@ trial_postgres_tests = [ if not IS_PR: trial_postgres_tests.append( { - "python-version": "3.10", + "python-version": "3.11", "database": "postgres", "postgres-version": "14", "extras": "all", @@ -71,7 +71,7 @@ if not IS_PR: trial_no_extra_tests = [ { - "python-version": "3.7", + "python-version": "3.11", "database": "sqlite", "extras": "", } diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ff5cf0c534..27fef6b5bd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -393,7 +393,7 @@ jobs: - python-version: "3.7" postgres-version: "10" - - python-version: "3.10" + - python-version: "3.11" postgres-version: "14" services: diff --git a/changelog.d/13812.misc b/changelog.d/13812.misc new file mode 100644 index 0000000000..667fdee2b7 --- /dev/null +++ b/changelog.d/13812.misc @@ -0,0 +1 @@ +Run unit tests against Python 3.11. diff --git a/poetry.lock b/poetry.lock index c01cfcfa58..ebabd3b833 100644 --- a/poetry.lock +++ b/poetry.lock @@ -25,7 +25,7 @@ cryptography = ">=3.2" [[package]] name = "automat" -version = "20.2.0" +version = "22.10.0" description = "Self-service finite-state machines for the programmer on the go." category = "main" optional = false @@ -671,12 +671,16 @@ python-versions = "*" [[package]] name = "pillow" -version = "9.0.1" +version = "9.2.0" description = "Python Imaging Library (Fork)" category = "main" optional = false python-versions = ">=3.7" +[package.extras] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-issues (>=3.0.1)", "sphinx-removed-in", "sphinxext-opengraph"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] + [[package]] name = "pkginfo" version = "1.8.2" @@ -1534,7 +1538,7 @@ python-versions = "*" [[package]] name = "wrapt" -version = "1.13.3" +version = "1.14.1" description = "Module for decorators, wrappers and monkey patching." category = "dev" optional = false @@ -1646,8 +1650,8 @@ Authlib = [ {file = "Authlib-1.1.0.tar.gz", hash = "sha256:0a270c91409fc2b7b0fbee6996e09f2ee3187358762111a9a4225c874b94e891"}, ] automat = [ - {file = "Automat-20.2.0-py2.py3-none-any.whl", hash = "sha256:b6feb6455337df834f6c9962d6ccf771515b7d939bca142b29c20c2376bc6111"}, - {file = "Automat-20.2.0.tar.gz", hash = "sha256:7979803c74610e11ef0c0d68a2942b152df52da55336e0c9d58daf1831cbdf33"}, + {file = "Automat-22.10.0-py2.py3-none-any.whl", hash = "sha256:c3164f8742b9dc440f3682482d32aaff7bb53f71740dd018533f9de286b64180"}, + {file = "Automat-22.10.0.tar.gz", hash = "sha256:e56beb84edad19dcc11d30e8d9b895f75deeb5ef5e96b84a467066b3b84bb04e"}, ] bcrypt = [ {file = "bcrypt-4.0.1-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:b1023030aec778185a6c16cf70f359cbb6e0c289fd564a7cfa29e727a1c38f8f"}, @@ -2251,41 +2255,64 @@ phonenumbers = [ {file = "phonenumbers-8.12.56.tar.gz", hash = "sha256:82a4f226c930d02dcdf6d4b29e4cfd8678991fe65c2efd5fdd143557186f0868"}, ] pillow = [ - {file = "Pillow-9.0.1-1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a5d24e1d674dd9d72c66ad3ea9131322819ff86250b30dc5821cbafcfa0b96b4"}, - {file = "Pillow-9.0.1-1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2632d0f846b7c7600edf53c48f8f9f1e13e62f66a6dbc15191029d950bfed976"}, - {file = "Pillow-9.0.1-1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b9618823bd237c0d2575283f2939655f54d51b4527ec3972907a927acbcc5bfc"}, - {file = "Pillow-9.0.1-cp310-cp310-macosx_10_10_universal2.whl", hash = "sha256:9bfdb82cdfeccec50aad441afc332faf8606dfa5e8efd18a6692b5d6e79f00fd"}, - {file = "Pillow-9.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5100b45a4638e3c00e4d2320d3193bdabb2d75e79793af7c3eb139e4f569f16f"}, - {file = "Pillow-9.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:528a2a692c65dd5cafc130de286030af251d2ee0483a5bf50c9348aefe834e8a"}, - {file = "Pillow-9.0.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0f29d831e2151e0b7b39981756d201f7108d3d215896212ffe2e992d06bfe049"}, - {file = "Pillow-9.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:855c583f268edde09474b081e3ddcd5cf3b20c12f26e0d434e1386cc5d318e7a"}, - {file = "Pillow-9.0.1-cp310-cp310-win32.whl", hash = "sha256:d9d7942b624b04b895cb95af03a23407f17646815495ce4547f0e60e0b06f58e"}, - {file = "Pillow-9.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:81c4b81611e3a3cb30e59b0cf05b888c675f97e3adb2c8672c3154047980726b"}, - {file = "Pillow-9.0.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:413ce0bbf9fc6278b2d63309dfeefe452835e1c78398efb431bab0672fe9274e"}, - {file = "Pillow-9.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80fe64a6deb6fcfdf7b8386f2cf216d329be6f2781f7d90304351811fb591360"}, - {file = "Pillow-9.0.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cef9c85ccbe9bee00909758936ea841ef12035296c748aaceee535969e27d31b"}, - {file = "Pillow-9.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d19397351f73a88904ad1aee421e800fe4bbcd1aeee6435fb62d0a05ccd1030"}, - {file = "Pillow-9.0.1-cp37-cp37m-win32.whl", hash = "sha256:d21237d0cd37acded35154e29aec853e945950321dd2ffd1a7d86fe686814669"}, - {file = "Pillow-9.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:ede5af4a2702444a832a800b8eb7f0a7a1c0eed55b644642e049c98d589e5092"}, - {file = "Pillow-9.0.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:b5b3f092fe345c03bca1e0b687dfbb39364b21ebb8ba90e3fa707374b7915204"}, - {file = "Pillow-9.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:335ace1a22325395c4ea88e00ba3dc89ca029bd66bd5a3c382d53e44f0ccd77e"}, - {file = "Pillow-9.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db6d9fac65bd08cea7f3540b899977c6dee9edad959fa4eaf305940d9cbd861c"}, - {file = "Pillow-9.0.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f154d173286a5d1863637a7dcd8c3437bb557520b01bddb0be0258dcb72696b5"}, - {file = "Pillow-9.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14d4b1341ac07ae07eb2cc682f459bec932a380c3b122f5540432d8977e64eae"}, - {file = "Pillow-9.0.1-cp38-cp38-win32.whl", hash = "sha256:effb7749713d5317478bb3acb3f81d9d7c7f86726d41c1facca068a04cf5bb4c"}, - {file = "Pillow-9.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:7f7609a718b177bf171ac93cea9fd2ddc0e03e84d8fa4e887bdfc39671d46b00"}, - {file = "Pillow-9.0.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:80ca33961ced9c63358056bd08403ff866512038883e74f3a4bf88ad3eb66838"}, - {file = "Pillow-9.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1c3c33ac69cf059bbb9d1a71eeaba76781b450bc307e2291f8a4764d779a6b28"}, - {file = "Pillow-9.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12875d118f21cf35604176872447cdb57b07126750a33748bac15e77f90f1f9c"}, - {file = "Pillow-9.0.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:514ceac913076feefbeaf89771fd6febde78b0c4c1b23aaeab082c41c694e81b"}, - {file = "Pillow-9.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3c5c79ab7dfce6d88f1ba639b77e77a17ea33a01b07b99840d6ed08031cb2a7"}, - {file = "Pillow-9.0.1-cp39-cp39-win32.whl", hash = "sha256:718856856ba31f14f13ba885ff13874be7fefc53984d2832458f12c38205f7f7"}, - {file = "Pillow-9.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:f25ed6e28ddf50de7e7ea99d7a976d6a9c415f03adcaac9c41ff6ff41b6d86ac"}, - {file = "Pillow-9.0.1-pp37-pypy37_pp73-macosx_10_10_x86_64.whl", hash = "sha256:011233e0c42a4a7836498e98c1acf5e744c96a67dd5032a6f666cc1fb97eab97"}, - {file = "Pillow-9.0.1-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:253e8a302a96df6927310a9d44e6103055e8fb96a6822f8b7f514bb7ef77de56"}, - {file = "Pillow-9.0.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6295f6763749b89c994fcb6d8a7f7ce03c3992e695f89f00b741b4580b199b7e"}, - {file = "Pillow-9.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:a9f44cd7e162ac6191491d7249cceb02b8116b0f7e847ee33f739d7cb1ea1f70"}, - {file = "Pillow-9.0.1.tar.gz", hash = "sha256:6c8bc8238a7dfdaf7a75f5ec5a663f4173f8c367e5a39f87e720495e1eed75fa"}, + {file = "Pillow-9.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:a9c9bc489f8ab30906d7a85afac4b4944a572a7432e00698a7239f44a44e6efb"}, + {file = "Pillow-9.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:510cef4a3f401c246cfd8227b300828715dd055463cdca6176c2e4036df8bd4f"}, + {file = "Pillow-9.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7888310f6214f19ab2b6df90f3f06afa3df7ef7355fc025e78a3044737fab1f5"}, + {file = "Pillow-9.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:831e648102c82f152e14c1a0938689dbb22480c548c8d4b8b248b3e50967b88c"}, + {file = "Pillow-9.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cc1d2451e8a3b4bfdb9caf745b58e6c7a77d2e469159b0d527a4554d73694d1"}, + {file = "Pillow-9.2.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:136659638f61a251e8ed3b331fc6ccd124590eeff539de57c5f80ef3a9594e58"}, + {file = "Pillow-9.2.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6e8c66f70fb539301e064f6478d7453e820d8a2c631da948a23384865cd95544"}, + {file = "Pillow-9.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:37ff6b522a26d0538b753f0b4e8e164fdada12db6c6f00f62145d732d8a3152e"}, + {file = "Pillow-9.2.0-cp310-cp310-win32.whl", hash = "sha256:c79698d4cd9318d9481d89a77e2d3fcaeff5486be641e60a4b49f3d2ecca4e28"}, + {file = "Pillow-9.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:254164c57bab4b459f14c64e93df11eff5ded575192c294a0c49270f22c5d93d"}, + {file = "Pillow-9.2.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:adabc0bce035467fb537ef3e5e74f2847c8af217ee0be0455d4fec8adc0462fc"}, + {file = "Pillow-9.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:336b9036127eab855beec9662ac3ea13a4544a523ae273cbf108b228ecac8437"}, + {file = "Pillow-9.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50dff9cc21826d2977ef2d2a205504034e3a4563ca6f5db739b0d1026658e004"}, + {file = "Pillow-9.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb6259196a589123d755380b65127ddc60f4c64b21fc3bb46ce3a6ea663659b0"}, + {file = "Pillow-9.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b0554af24df2bf96618dac71ddada02420f946be943b181108cac55a7a2dcd4"}, + {file = "Pillow-9.2.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:15928f824870535c85dbf949c09d6ae7d3d6ac2d6efec80f3227f73eefba741c"}, + {file = "Pillow-9.2.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:bdd0de2d64688ecae88dd8935012c4a72681e5df632af903a1dca8c5e7aa871a"}, + {file = "Pillow-9.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5b87da55a08acb586bad5c3aa3b86505f559b84f39035b233d5bf844b0834b1"}, + {file = "Pillow-9.2.0-cp311-cp311-win32.whl", hash = "sha256:b6d5e92df2b77665e07ddb2e4dbd6d644b78e4c0d2e9272a852627cdba0d75cf"}, + {file = "Pillow-9.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:6bf088c1ce160f50ea40764f825ec9b72ed9da25346216b91361eef8ad1b8f8c"}, + {file = "Pillow-9.2.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:2c58b24e3a63efd22554c676d81b0e57f80e0a7d3a5874a7e14ce90ec40d3069"}, + {file = "Pillow-9.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eef7592281f7c174d3d6cbfbb7ee5984a671fcd77e3fc78e973d492e9bf0eb3f"}, + {file = "Pillow-9.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dcd7b9c7139dc8258d164b55696ecd16c04607f1cc33ba7af86613881ffe4ac8"}, + {file = "Pillow-9.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a138441e95562b3c078746a22f8fca8ff1c22c014f856278bdbdd89ca36cff1b"}, + {file = "Pillow-9.2.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:93689632949aff41199090eff5474f3990b6823404e45d66a5d44304e9cdc467"}, + {file = "Pillow-9.2.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:f3fac744f9b540148fa7715a435d2283b71f68bfb6d4aae24482a890aed18b59"}, + {file = "Pillow-9.2.0-cp37-cp37m-win32.whl", hash = "sha256:fa768eff5f9f958270b081bb33581b4b569faabf8774726b283edb06617101dc"}, + {file = "Pillow-9.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:69bd1a15d7ba3694631e00df8de65a8cb031911ca11f44929c97fe05eb9b6c1d"}, + {file = "Pillow-9.2.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:030e3460861488e249731c3e7ab59b07c7853838ff3b8e16aac9561bb345da14"}, + {file = "Pillow-9.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:74a04183e6e64930b667d321524e3c5361094bb4af9083db5c301db64cd341f3"}, + {file = "Pillow-9.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d33a11f601213dcd5718109c09a52c2a1c893e7461f0be2d6febc2879ec2402"}, + {file = "Pillow-9.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fd6f5e3c0e4697fa7eb45b6e93996299f3feee73a3175fa451f49a74d092b9f"}, + {file = "Pillow-9.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a647c0d4478b995c5e54615a2e5360ccedd2f85e70ab57fbe817ca613d5e63b8"}, + {file = "Pillow-9.2.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:4134d3f1ba5f15027ff5c04296f13328fecd46921424084516bdb1b2548e66ff"}, + {file = "Pillow-9.2.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:bc431b065722a5ad1dfb4df354fb9333b7a582a5ee39a90e6ffff688d72f27a1"}, + {file = "Pillow-9.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:1536ad017a9f789430fb6b8be8bf99d2f214c76502becc196c6f2d9a75b01b76"}, + {file = "Pillow-9.2.0-cp38-cp38-win32.whl", hash = "sha256:2ad0d4df0f5ef2247e27fc790d5c9b5a0af8ade9ba340db4a73bb1a4a3e5fb4f"}, + {file = "Pillow-9.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:ec52c351b35ca269cb1f8069d610fc45c5bd38c3e91f9ab4cbbf0aebc136d9c8"}, + {file = "Pillow-9.2.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0ed2c4ef2451de908c90436d6e8092e13a43992f1860275b4d8082667fbb2ffc"}, + {file = "Pillow-9.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ad2f835e0ad81d1689f1b7e3fbac7b01bb8777d5a985c8962bedee0cc6d43da"}, + {file = "Pillow-9.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea98f633d45f7e815db648fd7ff0f19e328302ac36427343e4432c84432e7ff4"}, + {file = "Pillow-9.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7761afe0126d046974a01e030ae7529ed0ca6a196de3ec6937c11df0df1bc91c"}, + {file = "Pillow-9.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a54614049a18a2d6fe156e68e188da02a046a4a93cf24f373bffd977e943421"}, + {file = "Pillow-9.2.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:5aed7dde98403cd91d86a1115c78d8145c83078e864c1de1064f52e6feb61b20"}, + {file = "Pillow-9.2.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:13b725463f32df1bfeacbf3dd197fb358ae8ebcd8c5548faa75126ea425ccb60"}, + {file = "Pillow-9.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:808add66ea764ed97d44dda1ac4f2cfec4c1867d9efb16a33d158be79f32b8a4"}, + {file = "Pillow-9.2.0-cp39-cp39-win32.whl", hash = "sha256:337a74fd2f291c607d220c793a8135273c4c2ab001b03e601c36766005f36885"}, + {file = "Pillow-9.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:fac2d65901fb0fdf20363fbd345c01958a742f2dc62a8dd4495af66e3ff502a4"}, + {file = "Pillow-9.2.0-pp37-pypy37_pp73-macosx_10_10_x86_64.whl", hash = "sha256:ad2277b185ebce47a63f4dc6302e30f05762b688f8dc3de55dbae4651872cdf3"}, + {file = "Pillow-9.2.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c7b502bc34f6e32ba022b4a209638f9e097d7a9098104ae420eb8186217ebbb"}, + {file = "Pillow-9.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d1f14f5f691f55e1b47f824ca4fdcb4b19b4323fe43cc7bb105988cad7496be"}, + {file = "Pillow-9.2.0-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:dfe4c1fedfde4e2fbc009d5ad420647f7730d719786388b7de0999bf32c0d9fd"}, + {file = "Pillow-9.2.0-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:f07f1f00e22b231dd3d9b9208692042e29792d6bd4f6639415d2f23158a80013"}, + {file = "Pillow-9.2.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1802f34298f5ba11d55e5bb09c31997dc0c6aed919658dfdf0198a2fe75d5490"}, + {file = "Pillow-9.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17d4cafe22f050b46d983b71c707162d63d796a1235cdf8b9d7a112e97b15bac"}, + {file = "Pillow-9.2.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:96b5e6874431df16aee0c1ba237574cb6dff1dcb173798faa6a9d8b399a05d0e"}, + {file = "Pillow-9.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:0030fdbd926fb85844b8b92e2f9449ba89607231d3dd597a21ae72dc7fe26927"}, + {file = "Pillow-9.2.0.tar.gz", hash = "sha256:75e636fd3e0fb872693f23ccb8a5ff2cd578801251f3a4f6854c6a5d437d3c04"}, ] pkginfo = [ {file = "pkginfo-1.8.2-py2.py3-none-any.whl", hash = "sha256:c24c487c6a7f72c66e816ab1796b96ac6c3d14d49338293d2141664330b55ffc"}, @@ -2814,57 +2841,70 @@ webencodings = [ {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, ] wrapt = [ - {file = "wrapt-1.13.3-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:e05e60ff3b2b0342153be4d1b597bbcfd8330890056b9619f4ad6b8d5c96a81a"}, - {file = "wrapt-1.13.3-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:85148f4225287b6a0665eef08a178c15097366d46b210574a658c1ff5b377489"}, - {file = "wrapt-1.13.3-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:2dded5496e8f1592ec27079b28b6ad2a1ef0b9296d270f77b8e4a3a796cf6909"}, - {file = "wrapt-1.13.3-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:e94b7d9deaa4cc7bac9198a58a7240aaf87fe56c6277ee25fa5b3aa1edebd229"}, - {file = "wrapt-1.13.3-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:498e6217523111d07cd67e87a791f5e9ee769f9241fcf8a379696e25806965af"}, - {file = "wrapt-1.13.3-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:ec7e20258ecc5174029a0f391e1b948bf2906cd64c198a9b8b281b811cbc04de"}, - {file = "wrapt-1.13.3-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:87883690cae293541e08ba2da22cacaae0a092e0ed56bbba8d018cc486fbafbb"}, - {file = "wrapt-1.13.3-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:f99c0489258086308aad4ae57da9e8ecf9e1f3f30fa35d5e170b4d4896554d80"}, - {file = "wrapt-1.13.3-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:6a03d9917aee887690aa3f1747ce634e610f6db6f6b332b35c2dd89412912bca"}, - {file = "wrapt-1.13.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:936503cb0a6ed28dbfa87e8fcd0a56458822144e9d11a49ccee6d9a8adb2ac44"}, - {file = "wrapt-1.13.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f9c51d9af9abb899bd34ace878fbec8bf357b3194a10c4e8e0a25512826ef056"}, - {file = "wrapt-1.13.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:220a869982ea9023e163ba915077816ca439489de6d2c09089b219f4e11b6785"}, - {file = "wrapt-1.13.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0877fe981fd76b183711d767500e6b3111378ed2043c145e21816ee589d91096"}, - {file = "wrapt-1.13.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:43e69ffe47e3609a6aec0fe723001c60c65305784d964f5007d5b4fb1bc6bf33"}, - {file = "wrapt-1.13.3-cp310-cp310-win32.whl", hash = "sha256:78dea98c81915bbf510eb6a3c9c24915e4660302937b9ae05a0947164248020f"}, - {file = "wrapt-1.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:ea3e746e29d4000cd98d572f3ee2a6050a4f784bb536f4ac1f035987fc1ed83e"}, - {file = "wrapt-1.13.3-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:8c73c1a2ec7c98d7eaded149f6d225a692caa1bd7b2401a14125446e9e90410d"}, - {file = "wrapt-1.13.3-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:086218a72ec7d986a3eddb7707c8c4526d677c7b35e355875a0fe2918b059179"}, - {file = "wrapt-1.13.3-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:e92d0d4fa68ea0c02d39f1e2f9cb5bc4b4a71e8c442207433d8db47ee79d7aa3"}, - {file = "wrapt-1.13.3-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:d4a5f6146cfa5c7ba0134249665acd322a70d1ea61732723c7d3e8cc0fa80755"}, - {file = "wrapt-1.13.3-cp35-cp35m-win32.whl", hash = "sha256:8aab36778fa9bba1a8f06a4919556f9f8c7b33102bd71b3ab307bb3fecb21851"}, - {file = "wrapt-1.13.3-cp35-cp35m-win_amd64.whl", hash = "sha256:944b180f61f5e36c0634d3202ba8509b986b5fbaf57db3e94df11abee244ba13"}, - {file = "wrapt-1.13.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:2ebdde19cd3c8cdf8df3fc165bc7827334bc4e353465048b36f7deeae8ee0918"}, - {file = "wrapt-1.13.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:610f5f83dd1e0ad40254c306f4764fcdc846641f120c3cf424ff57a19d5f7ade"}, - {file = "wrapt-1.13.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5601f44a0f38fed36cc07db004f0eedeaadbdcec90e4e90509480e7e6060a5bc"}, - {file = "wrapt-1.13.3-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:e6906d6f48437dfd80464f7d7af1740eadc572b9f7a4301e7dd3d65db285cacf"}, - {file = "wrapt-1.13.3-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:766b32c762e07e26f50d8a3468e3b4228b3736c805018e4b0ec8cc01ecd88125"}, - {file = "wrapt-1.13.3-cp36-cp36m-win32.whl", hash = "sha256:5f223101f21cfd41deec8ce3889dc59f88a59b409db028c469c9b20cfeefbe36"}, - {file = "wrapt-1.13.3-cp36-cp36m-win_amd64.whl", hash = "sha256:f122ccd12fdc69628786d0c947bdd9cb2733be8f800d88b5a37c57f1f1d73c10"}, - {file = "wrapt-1.13.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:46f7f3af321a573fc0c3586612db4decb7eb37172af1bc6173d81f5b66c2e068"}, - {file = "wrapt-1.13.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:778fd096ee96890c10ce96187c76b3e99b2da44e08c9e24d5652f356873f6709"}, - {file = "wrapt-1.13.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0cb23d36ed03bf46b894cfec777eec754146d68429c30431c99ef28482b5c1df"}, - {file = "wrapt-1.13.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:96b81ae75591a795d8c90edc0bfaab44d3d41ffc1aae4d994c5aa21d9b8e19a2"}, - {file = "wrapt-1.13.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7dd215e4e8514004c8d810a73e342c536547038fb130205ec4bba9f5de35d45b"}, - {file = "wrapt-1.13.3-cp37-cp37m-win32.whl", hash = "sha256:47f0a183743e7f71f29e4e21574ad3fa95676136f45b91afcf83f6a050914829"}, - {file = "wrapt-1.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fd76c47f20984b43d93de9a82011bb6e5f8325df6c9ed4d8310029a55fa361ea"}, - {file = "wrapt-1.13.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b73d4b78807bd299b38e4598b8e7bd34ed55d480160d2e7fdaabd9931afa65f9"}, - {file = "wrapt-1.13.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ec9465dd69d5657b5d2fa6133b3e1e989ae27d29471a672416fd729b429eb554"}, - {file = "wrapt-1.13.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dd91006848eb55af2159375134d724032a2d1d13bcc6f81cd8d3ed9f2b8e846c"}, - {file = "wrapt-1.13.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ae9de71eb60940e58207f8e71fe113c639da42adb02fb2bcbcaccc1ccecd092b"}, - {file = "wrapt-1.13.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:51799ca950cfee9396a87f4a1240622ac38973b6df5ef7a41e7f0b98797099ce"}, - {file = "wrapt-1.13.3-cp38-cp38-win32.whl", hash = "sha256:4b9c458732450ec42578b5642ac53e312092acf8c0bfce140ada5ca1ac556f79"}, - {file = "wrapt-1.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:7dde79d007cd6dfa65afe404766057c2409316135cb892be4b1c768e3f3a11cb"}, - {file = "wrapt-1.13.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:981da26722bebb9247a0601e2922cedf8bb7a600e89c852d063313102de6f2cb"}, - {file = "wrapt-1.13.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:705e2af1f7be4707e49ced9153f8d72131090e52be9278b5dbb1498c749a1e32"}, - {file = "wrapt-1.13.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:25b1b1d5df495d82be1c9d2fad408f7ce5ca8a38085e2da41bb63c914baadff7"}, - {file = "wrapt-1.13.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:77416e6b17926d953b5c666a3cb718d5945df63ecf922af0ee576206d7033b5e"}, - {file = "wrapt-1.13.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:865c0b50003616f05858b22174c40ffc27a38e67359fa1495605f96125f76640"}, - {file = "wrapt-1.13.3-cp39-cp39-win32.whl", hash = "sha256:0a017a667d1f7411816e4bf214646d0ad5b1da2c1ea13dec6c162736ff25a374"}, - {file = "wrapt-1.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:81bd7c90d28a4b2e1df135bfbd7c23aee3050078ca6441bead44c42483f9ebfb"}, - {file = "wrapt-1.13.3.tar.gz", hash = "sha256:1fea9cd438686e6682271d36f3481a9f3636195578bab9ca3382e2f5f01fc185"}, + {file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"}, + {file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"}, + {file = "wrapt-1.14.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28"}, + {file = "wrapt-1.14.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59"}, + {file = "wrapt-1.14.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87"}, + {file = "wrapt-1.14.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1"}, + {file = "wrapt-1.14.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b"}, + {file = "wrapt-1.14.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462"}, + {file = "wrapt-1.14.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1"}, + {file = "wrapt-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320"}, + {file = "wrapt-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2"}, + {file = "wrapt-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4"}, + {file = "wrapt-1.14.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069"}, + {file = "wrapt-1.14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310"}, + {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f"}, + {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656"}, + {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"}, + {file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"}, + {file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"}, + {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"}, + {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"}, + {file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"}, + {file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d"}, + {file = "wrapt-1.14.1-cp35-cp35m-win32.whl", hash = "sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7"}, + {file = "wrapt-1.14.1-cp35-cp35m-win_amd64.whl", hash = "sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00"}, + {file = "wrapt-1.14.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4"}, + {file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1"}, + {file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1"}, + {file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff"}, + {file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d"}, + {file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1"}, + {file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569"}, + {file = "wrapt-1.14.1-cp36-cp36m-win32.whl", hash = "sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed"}, + {file = "wrapt-1.14.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471"}, + {file = "wrapt-1.14.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248"}, + {file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68"}, + {file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d"}, + {file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77"}, + {file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7"}, + {file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015"}, + {file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a"}, + {file = "wrapt-1.14.1-cp37-cp37m-win32.whl", hash = "sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853"}, + {file = "wrapt-1.14.1-cp37-cp37m-win_amd64.whl", hash = "sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c"}, + {file = "wrapt-1.14.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456"}, + {file = "wrapt-1.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f"}, + {file = "wrapt-1.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc"}, + {file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1"}, + {file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af"}, + {file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b"}, + {file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0"}, + {file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57"}, + {file = "wrapt-1.14.1-cp38-cp38-win32.whl", hash = "sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5"}, + {file = "wrapt-1.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d"}, + {file = "wrapt-1.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383"}, + {file = "wrapt-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7"}, + {file = "wrapt-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86"}, + {file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735"}, + {file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b"}, + {file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3"}, + {file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3"}, + {file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe"}, + {file = "wrapt-1.14.1-cp39-cp39-win32.whl", hash = "sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5"}, + {file = "wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb"}, + {file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"}, ] xmlschema = [ {file = "xmlschema-1.10.0-py3-none-any.whl", hash = "sha256:dbd68bded2fef00c19cf37110ca0565eca34cf0b6c9e1d3b62ad0de8cbb582ca"}, diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index dd4d1b56de..b84c74fc0e 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -15,6 +15,8 @@ import json from unittest.mock import Mock +import ijson.common + from synapse.api.room_versions import RoomVersions from synapse.federation.transport.client import SendJoinParser from synapse.util import ExceptionBundle @@ -117,8 +119,17 @@ class SendJoinParserTestCase(TestCase): coro_3 = Mock() coro_3.close = Mock(side_effect=RuntimeError("Couldn't close coro 3")) + original_coros = parser._coros parser._coros = [coro_1, coro_2, coro_3] + # Close the original coroutines. If we don't, when we garbage collect them + # they will throw, failing the test. (Oddly, this only started in CPython 3.11). + for coro in original_coros: + try: + coro.close() + except ijson.common.IncompleteJSONError: + pass + # Send half of the data to the parser parser.write(serialisation[: len(serialisation) // 2]) -- cgit 1.5.1 From 86c5a710d8b4212f8a8a668d7d4a79c0bb371508 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Nov 2022 16:21:31 +0000 Subject: Implement MSC3912: Relation-based redactions (#14260) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14260.feature | 1 + synapse/api/constants.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/message.py | 47 ++++- synapse/handlers/relations.py | 56 +++++- synapse/rest/client/room.py | 57 ++++-- synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/relations.py | 36 ++++ tests/rest/client/test_redactions.py | 273 +++++++++++++++++++++++++++- tests/rest/client/utils.py | 37 ++++ 10 files changed, 486 insertions(+), 28 deletions(-) create mode 100644 changelog.d/14260.feature (limited to 'tests') diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature new file mode 100644 index 0000000000..102dc7b3e0 --- /dev/null +++ b/changelog.d/14260.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 44c5ffc6a5..bc04a0755b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,6 +125,8 @@ class EventTypes: MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + Reaction: Final = "m.reaction" + class ToDeviceEventTypes: RoomKeyRequest: Final = "m.room_key_request" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d9bdd66d55..d4b71d1673 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): self.msc3886_endpoint: Optional[str] = experimental.get( "msc3886_endpoint", None ) + + # MSC3912: Relation-based redactions. + self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 468900a07f..4cf593cfdc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -877,6 +877,36 @@ class EventCreationHandler: return prev_event return None + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + if requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + return None + async def create_and_send_nonmember_event( self, requester: Requester, @@ -956,18 +986,17 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. async with self.limiter.queue(event_dict["room_id"]): - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - event_dict["room_id"], - requester.user.to_string(), - requester.access_token_id, - txn_id, + if txn_id: + event = await self.get_event_from_transaction( + requester, txn_id, event_dict["room_id"] ) - if existing_event_id: - event = await self.store.get_event(existing_event_id) + if event: # we know it was persisted, so must have a stream ordering assert event.internal_metadata.stream_ordering - return event, event.internal_metadata.stream_ordering + return ( + event, + event.internal_metadata.stream_ordering, + ) event, context = await self.create_event( requester, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0a0c6d938e..8e71dda970 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace @@ -75,6 +75,7 @@ class RelationsHandler: self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._event_creation_handler = hs.get_event_creation_handler() async def get_relations( self, @@ -205,6 +206,59 @@ class RelationsHandler: return related_events, next_token + async def redact_events_related_to( + self, + requester: Requester, + event_id: str, + initial_redaction_event: EventBase, + relation_types: List[str], + ) -> None: + """Redacts all events related to the given event ID with one of the given + relation types. + + This method is expected to be called when redacting the event referred to by + the given event ID. + + If an event cannot be redacted (e.g. because of insufficient permissions), log + the error and try to redact the next one. + + Args: + requester: The requester to redact events on behalf of. + event_id: The event IDs to look and redact relations of. + initial_redaction_event: The redaction for the event referred to by + event_id. + relation_types: The types of relations to look for. + + Raises: + ShadowBanError if the requester is shadow-banned + """ + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) + ) + + for related_event_id in related_event_ids: + try: + await self._event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": initial_redaction_event.content, + "room_id": initial_redaction_event.room_id, + "sender": requester.user.to_string(), + "redacts": related_event_id, + }, + ratelimit=False, + ) + except SynapseError as e: + logger.warning( + "Failed to redact event %s (related to event %s): %s", + related_event_id, + event_id, + e.msg, + ) + async def get_annotations_for_event( self, event_id: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 01e5079963..91cb791139 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -52,6 +52,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.state import StateFilter @@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() + self._relation_handler = hs.get_relations_handler() + self._msc3912_enabled = hs.config.experimental.msc3912_enabled def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" @@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + with_relations = None + if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content: + with_relations = content["org.matrix.msc3912.with_relations"] + del content["org.matrix.msc3912.with_relations"] + + # Check if there's an existing event for this transaction now (even though + # create_and_send_nonmember_event also does it) because, if there's one, + # then we want to skip the call to redact_events_related_to. + event = None + if txn_id: + event = await self.event_creation_handler.get_event_from_transaction( + requester, txn_id, room_id + ) + + if event is None: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + + if with_relations: + run_as_background_process( + "redact_related_events", + self._relation_handler.redact_events_related_to, + requester=requester, + event_id=event_id, + initial_redaction_event=event, + relation_types=with_relations, + ) + event_id = event.event_id except ShadowBanError: event_id = "$" + random_string(43) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 9b1b72c68a..180a11ef88 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet): # Adds support for simple HTTP rendezvous as per MSC3886 "org.matrix.msc3886": self.config.experimental.msc3886_endpoint is not None, + # Adds support for relation-based redactions as per MSC3912. + "org.matrix.msc3912": self.config.experimental.msc3912_enabled, }, }, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c022510e76..ca431002c8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index be4c67d68e..5dfe44defb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RedactionsTestCase(HomeserverTestCase): @@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase): ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = 200, + with_relations: Optional[List[str]] = None, ) -> JsonDict: """Helper function to send a redaction event. @@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - channel = self.make_request("POST", path, content={}, access_token=access_token) + request_content = {} + if with_relations: + request_content["org.matrix.msc3912.with_relations"] = with_relations + + channel = self.make_request( + "POST", path, request_content, access_token=access_token + ) self.assertEqual(channel.code, expect_code) return channel.json_body @@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase): # These should all succeed, even though this would be denied by # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations(self) -> None: + """Tests that we can redact the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete events that + # relate to it with m.replace. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=[ + RelationTypes.REPLACE, + RelationTypes.THREAD, + ], + ) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction did not get redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertNotIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_no_perms(self) -> None: + """Tests that, when redacting a message along with its relations, if not all + the related messages can be redacted because of insufficient permissions, the + server still redacts all the ones that can be. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.other_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message, this one from the moderator. We do this for the + # first message with the m.thread relation (and not the last one) to ensure + # that, when the server fails to redact it, it doesn't stop there, and it + # instead goes on to redact the other one. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + first_threaded_event_id = res["event_id"] + + # Send a second threaded message, this time from the user who'll perform the + # redaction. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 2", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.other_access_token, + ) + second_threaded_event_id = res["event_id"] + + # Redact the thread's root, and request that all threaded messages are also + # redacted. Send that request from the non-mod user, so that the first threaded + # event cannot be redacted. + self._redact_event( + self.other_access_token, + self.room_id, + root_event_id, + with_relations=[RelationTypes.THREAD], + ) + + # Check that the thread root got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the last message in the thread got redacted, despite failing to + # redact the one before it. + event_dict = self.helper.get_event( + self.room_id, second_threaded_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the message that was sent into the tread by the mod user is not + # redacted. + event_dict = self.helper.get_event( + self.room_id, first_threaded_event_id, self.other_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("message 1", event_dict["content"]["body"]) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_txn_id_reuse(self) -> None: + """Tests that redacting a message using a transaction ID, then reusing the same + transaction ID but providing an additional list of relations to redact, is + effectively a no-op. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "I'm in a thread!", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Send a first redaction request which redacts only the root event. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Send a second redaction request which redacts the root event as well as + # threaded messages. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict) + + # Check that the threaded message didn't get redacted (since that wasn't part of + # the original redaction). + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 706399fae5..8d6f2b6ff9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -410,6 +410,43 @@ class RestHelper: return channel.json_body + def get_event( + self, + room_id: str, + event_id: str, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Request a specific event from the server. + + Args: + room_id: the room in which the event was sent. + event_id: the event's ID. + tok: the token to request the event with. + expect_code: the expected HTTP status for the response. + + Returns: + The event as a dict. + """ + path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}" + if tok: + path = path + f"?access_token={tok}" + + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + path, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def _read_write_state( self, room_id: str, -- cgit 1.5.1 From a4b1f6456276e62b3f4d6b060c289b6413b8a5c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 4 Nov 2022 18:43:51 +0200 Subject: Fix /refresh endpoint version (#14364) --- changelog.d/14364.bugfix | 1 + synapse/rest/client/login.py | 2 +- tests/rest/client/test_auth.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14364.bugfix (limited to 'tests') diff --git a/changelog.d/14364.bugfix b/changelog.d/14364.bugfix new file mode 100644 index 0000000000..514bf859bb --- /dev/null +++ b/changelog.d/14364.bugfix @@ -0,0 +1 @@ +Fix refresh token endpoint to be under /r0 and /v3 instead of /v1. Contributed by Tulir @ Beeper. diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 7774f1967d..05706b598c 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -536,7 +536,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: class RefreshTokenServlet(RestServlet): - PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),) + PATTERNS = client_patterns("/refresh$") def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 847294dc8e..208ec44829 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -635,7 +635,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): """ return self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": refresh_token}, ) @@ -724,7 +724,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) @@ -765,7 +765,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) @@ -1002,7 +1002,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This first refresh should work properly first_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1012,7 +1012,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one as well, since the token in the first one was never used second_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1022,7 +1022,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one should not, since the token from the first refresh is not valid anymore third_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1056,7 +1056,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1068,7 +1068,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # But refreshing from the last valid refresh token still works fifth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( -- cgit 1.5.1 From e980982b59dea38ec10a5c58993d09e02f845d28 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 7 Nov 2022 13:49:31 +0000 Subject: Do not reject `/sync` requests with unrecognised filter fields (#14369) For forward compatibility, Synapse needs to ignore fields it does not recognise instead of raising an error. Fixes #14365. Signed-off-by: Sean Quah --- changelog.d/14369.bugfix | 1 + synapse/api/filtering.py | 8 ++++---- tests/api/test_filtering.py | 21 +++++++++++++++++++-- 3 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14369.bugfix (limited to 'tests') diff --git a/changelog.d/14369.bugfix b/changelog.d/14369.bugfix new file mode 100644 index 0000000000..e6709f4eec --- /dev/null +++ b/changelog.d/14369.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would raise an error when encountering an unrecognised field in a `/sync` filter, instead of ignoring it for forward compatibility. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 26be377d03..a9888381b4 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -43,7 +43,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer FILTER_SCHEMA = { - "additionalProperties": False, + "additionalProperties": True, # Allow new fields for forward compatibility "type": "object", "properties": { "limit": {"type": "number"}, @@ -63,7 +63,7 @@ FILTER_SCHEMA = { } ROOM_FILTER_SCHEMA = { - "additionalProperties": False, + "additionalProperties": True, # Allow new fields for forward compatibility "type": "object", "properties": { "not_rooms": {"$ref": "#/definitions/room_id_array"}, @@ -77,7 +77,7 @@ ROOM_FILTER_SCHEMA = { } ROOM_EVENT_FILTER_SCHEMA = { - "additionalProperties": False, + "additionalProperties": True, # Allow new fields for forward compatibility "type": "object", "properties": { "limit": {"type": "number"}, @@ -143,7 +143,7 @@ USER_FILTER_SCHEMA = { }, }, }, - "additionalProperties": False, + "additionalProperties": True, # Allow new fields for forward compatibility } diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index a82c4eed86..d5524d296e 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -46,19 +46,36 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.datastore = hs.get_datastores().main def test_errors_on_invalid_filters(self): + # See USER_FILTER_SCHEMA for the filter schema. invalid_filters = [ - {"boom": {}}, + # `account_data` must be a dictionary {"account_data": "Hello World"}, + # `event_fields` entries must not contain backslashes {"event_fields": [r"\\foo"]}, - {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, + # `event_format` must be "client" or "federation" {"event_format": "other"}, + # `not_rooms` must contain valid room IDs {"room": {"not_rooms": ["#foo:pik-test"]}}, + # `senders` must contain valid user IDs {"presence": {"senders": ["@bar;pik.test.com"]}}, ] for filter in invalid_filters: with self.assertRaises(SynapseError): self.filtering.check_valid_filter(filter) + def test_ignores_unknown_filter_fields(self): + # For forward compatibility, we must ignore unknown filter fields. + # See USER_FILTER_SCHEMA for the filter schema. + filters = [ + {"org.matrix.msc9999.future_option": True}, + {"presence": {"org.matrix.msc9999.future_option": True}}, + {"room": {"org.matrix.msc9999.future_option": True}}, + {"room": {"timeline": {"org.matrix.msc9999.future_option": True}}}, + ] + for filter in filters: + self.filtering.check_valid_filter(filter) + # Must not raise. + def test_valid_filters(self): valid_filters = [ { -- cgit 1.5.1 From 7894251bcea7714b47e3849e509ea717bb18e9f5 Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 7 Nov 2022 13:38:50 -0800 Subject: Correctly create power level event during initial room creation (#14361) --- changelog.d/14361.bugfix | 1 + synapse/handlers/room.py | 25 +++++++++++++++++++++++-- tests/rest/client/test_rooms.py | 4 ++-- 3 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14361.bugfix (limited to 'tests') diff --git a/changelog.d/14361.bugfix b/changelog.d/14361.bugfix new file mode 100644 index 0000000000..33ba1d92af --- /dev/null +++ b/changelog.d/14361.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.71.0rc1 where the power level event was incorrectly created during initial room creation. \ No newline at end of file diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f10cfca073..66a50bca6e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1080,6 +1080,19 @@ class RoomCreationHandler: for_batch: bool, **kwargs: Any, ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + """ + Creates an event and associated event context. + Args: + etype: the type of event to be created + content: content of the event + for_batch: whether the event is being created for batch persisting. If + bool for_batch is true, this will create an event using the prev_event_ids, + and will create an event context for the event using the parameters state_map + and current_state_group, thus these parameters must be provided in this + case if for_batch is True. The subsequently created event and context + are suitable for being batched up and bulk persisted to the database + with other similarly created events. + """ nonlocal depth nonlocal prev_event @@ -1139,13 +1152,21 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + # we need the state group of the membership event as it is the current state group + event_to_state = ( + await self._storage_controllers.state.get_state_group_for_events( + [member_event_id] + ) + ) + current_state_group = event_to_state[member_event_id] + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: power_event, power_context = await create_event( - EventTypes.PowerLevels, pl_content, False + EventTypes.PowerLevels, pl_content, True ) current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) @@ -1194,7 +1215,7 @@ class RoomCreationHandler: pl_event, pl_context = await create_event( EventTypes.PowerLevels, power_level_content, - False, + True, ) current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 1084d4ad9d..e919e089cb 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -715,7 +715,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -728,7 +728,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(37, channel.resource_usage.db_txn_count) + self.assertEqual(36, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From e9a4343cb2daa55503bb2a2d1431d83bf9773e68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Nov 2022 09:55:34 -0500 Subject: Drop support for Postgres 10 in full text search code. (#14397) --- changelog.d/14397.removal | 1 + synapse/storage/databases/main/search.py | 50 +++++++++++------------ synapse/storage/engines/postgres.py | 16 -------- tests/storage/test_room_search.py | 69 ++++++++------------------------ 4 files changed, 41 insertions(+), 95 deletions(-) create mode 100644 changelog.d/14397.removal (limited to 'tests') diff --git a/changelog.d/14397.removal b/changelog.d/14397.removal new file mode 100644 index 0000000000..e96b3de2bd --- /dev/null +++ b/changelog.d/14397.removal @@ -0,0 +1 @@ +Remove support for PostgreSQL 10. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index e9588d1755..3fe433f66c 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -463,18 +463,17 @@ class SearchStore(SearchBackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): search_query = search_term - tsquery_func = self.database_engine.tsquery_func - sql = f""" - SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank, + sql = """ + SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank, room_id, event_id FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) + WHERE vector @@ websearch_to_tsquery('english', ?) """ args = [search_query, search_query] + args - count_sql = f""" + count_sql = """ SELECT room_id, count(*) as count FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) + WHERE vector @@ websearch_to_tsquery('english', ?) """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -523,9 +522,7 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres( - search_query, events, tsquery_func - ) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -604,18 +601,17 @@ class SearchStore(SearchBackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): search_query = search_term - tsquery_func = self.database_engine.tsquery_func - sql = f""" - SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank, + sql = """ + SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank, origin_server_ts, stream_ordering, room_id, event_id FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) AND + WHERE vector @@ websearch_to_tsquery('english', ?) AND """ args = [search_query, search_query] + args - count_sql = f""" + count_sql = """ SELECT room_id, count(*) as count FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) AND + WHERE vector @@ websearch_to_tsquery('english', ?) AND """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -686,9 +682,7 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres( - search_query, events, tsquery_func - ) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -714,7 +708,7 @@ class SearchStore(SearchBackgroundUpdateStore): } async def _find_highlights_in_postgres( - self, search_query: str, events: List[EventBase], tsquery_func: str + self, search_query: str, events: List[EventBase] ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -725,7 +719,6 @@ class SearchStore(SearchBackgroundUpdateStore): Args: search_query events: A list of events - tsquery_func: The tsquery_* function to use when making queries Returns: A set of strings. @@ -758,13 +751,16 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % ( - _to_postgres_options( - { - "StartSel": start_sel, - "StopSel": stop_sel, - "MaxFragments": "50", - } + query = ( + "SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)" + % ( + _to_postgres_options( + { + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxFragments": "50", + } + ) ) ) txn.execute(query, (value, search_query)) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 0c4fd88914..719a517336 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -170,22 +170,6 @@ class PostgresEngine( """Do we support the `RETURNING` clause in insert/update/delete?""" return True - @property - def tsquery_func(self) -> str: - """ - Selects a tsquery_* func to use. - - Ref: https://www.postgresql.org/docs/current/textsearch-controls.html - - Returns: - The function name. - """ - # Postgres 11 added support for websearch_to_tsquery. - assert self._version is not None - if self._version >= 110000: - return "websearch_to_tsquery" - return "plainto_tsquery" - def is_deadlock(self, error: Exception) -> bool: if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 868b5bee84..ef850daa73 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import List, Tuple from unittest.case import SkipTest -from unittest.mock import PropertyMock, patch from twisted.test.proto_helpers import MemoryReactor @@ -220,10 +219,8 @@ class MessageSearchTest(HomeserverTestCase): PHRASE = "the quick brown fox jumps over the lazy dog" - # Each entry is a search query, followed by either a boolean of whether it is - # in the phrase OR a tuple of booleans: whether it matches using websearch - # and using plain search. - COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + # Each entry is a search query, followed by a boolean of whether it is in the phrase. + COMMON_CASES = [ ("nope", False), ("brown", True), ("quick brown", True), @@ -231,13 +228,13 @@ class MessageSearchTest(HomeserverTestCase): ("quick \t brown", True), ("jump", True), ("brown nope", False), - ('"brown quick"', (False, True)), + ('"brown quick"', False), ('"jumps over"', True), - ('"quick fox"', (False, True)), + ('"quick fox"', False), ("nope OR doublenope", False), - ("furphy OR fox", (True, False)), - ("fox -nope", (True, False)), - ("fox -brown", (False, True)), + ("furphy OR fox", True), + ("fox -nope", True), + ("fox -brown", False), ('"fox" quick', True), ('"quick brown', True), ('" quick "', True), @@ -246,11 +243,11 @@ class MessageSearchTest(HomeserverTestCase): # TODO Test non-ASCII cases. # Case that fail on SQLite. - POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + POSTGRES_CASES = [ # SQLite treats NOT as a binary operator. - ("- fox", (False, True)), - ("- nope", (True, False)), - ('"-fox quick', (False, True)), + ("- fox", False), + ("- nope", True), + ('"-fox quick', False), # PostgreSQL skips stop words. ('"the quick brown"', True), ('"over lazy"', True), @@ -275,7 +272,7 @@ class MessageSearchTest(HomeserverTestCase): if isinstance(main_store.database_engine, PostgresEngine): assert main_store.database_engine._version is not None found = main_store.database_engine._version < 140000 - self.COMMON_CASES.append(('"fox quick', (found, True))) + self.COMMON_CASES.append(('"fox quick', found)) def test_tokenize_query(self) -> None: """Test the custom logic to tokenize a user's query.""" @@ -315,16 +312,10 @@ class MessageSearchTest(HomeserverTestCase): ) def _check_test_cases( - self, - store: DataStore, - cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]], - index=0, + self, store: DataStore, cases: List[Tuple[str, bool]] ) -> None: # Run all the test cases versus search_msgs for query, expect_to_contain in cases: - if isinstance(expect_to_contain, tuple): - expect_to_contain = expect_to_contain[index] - result = self.get_success( store.search_msgs([self.room_id], query, ["content.body"]) ) @@ -343,9 +334,6 @@ class MessageSearchTest(HomeserverTestCase): # Run them again versus search_rooms for query, expect_to_contain in cases: - if isinstance(expect_to_contain, tuple): - expect_to_contain = expect_to_contain[index] - result = self.get_success( store.search_rooms([self.room_id], query, ["content.body"], 10) ) @@ -366,38 +354,15 @@ class MessageSearchTest(HomeserverTestCase): """ Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. This test is skipped unless the postgres instance supports websearch_to_tsquery. - """ - - store = self.hs.get_datastores().main - if not isinstance(store.database_engine, PostgresEngine): - raise SkipTest("Test only applies when postgres is used as the database") - - if store.database_engine.tsquery_func != "websearch_to_tsquery": - raise SkipTest( - "Test only applies when postgres supporting websearch_to_tsquery is used as the database" - ) - self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0) - - def test_postgres_non_web_search_for_phrase(self): - """ - Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't - supported by the current postgres version. + See https://www.postgresql.org/docs/current/textsearch-controls.html """ store = self.hs.get_datastores().main if not isinstance(store.database_engine, PostgresEngine): raise SkipTest("Test only applies when postgres is used as the database") - # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path. - with patch( - "synapse.storage.engines.postgres.PostgresEngine.tsquery_func", - new_callable=PropertyMock, - ) as supports_websearch_to_tsquery: - supports_websearch_to_tsquery.return_value = "plainto_tsquery" - self._check_test_cases( - store, self.COMMON_CASES + self.POSTGRES_CASES, index=1 - ) + self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES) def test_sqlite_search(self): """ @@ -407,4 +372,4 @@ class MessageSearchTest(HomeserverTestCase): if not isinstance(store.database_engine, Sqlite3Engine): raise SkipTest("Test only applies when sqlite is used as the database") - self._check_test_cases(store, self.COMMON_CASES, index=0) + self._check_test_cases(store, self.COMMON_CASES) -- cgit 1.5.1 From 3a4f80f8c6f39c5549c56c044e10b35064d8d22f Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 11 Nov 2022 10:51:49 +0000 Subject: Merge/remove `Slaved*` stores into `WorkerStores` (#14375) --- changelog.d/14375.misc | 1 + synapse/app/admin_cmd.py | 36 ++++++++--- synapse/app/generic_worker.py | 44 ++++++++++---- synapse/replication/slave/storage/devices.py | 79 ------------------------ synapse/replication/slave/storage/events.py | 79 ------------------------ synapse/replication/slave/storage/filtering.py | 35 ----------- synapse/replication/slave/storage/keys.py | 20 ------ synapse/replication/slave/storage/push_rule.py | 35 ----------- synapse/replication/slave/storage/pushers.py | 47 -------------- synapse/storage/databases/main/__init__.py | 35 ----------- synapse/storage/databases/main/devices.py | 81 ++++++++++++++++++++++--- synapse/storage/databases/main/events_worker.py | 16 +++++ synapse/storage/databases/main/filtering.py | 4 +- synapse/storage/databases/main/push_rule.py | 19 ++++-- synapse/storage/databases/main/pusher.py | 41 +++++++++++-- synapse/storage/databases/main/stream.py | 1 + tests/replication/slave/storage/test_events.py | 6 +- 17 files changed, 202 insertions(+), 377 deletions(-) create mode 100644 changelog.d/14375.misc delete mode 100644 synapse/replication/slave/storage/devices.py delete mode 100644 synapse/replication/slave/storage/events.py delete mode 100644 synapse/replication/slave/storage/filtering.py delete mode 100644 synapse/replication/slave/storage/keys.py delete mode 100644 synapse/replication/slave/storage/push_rule.py delete mode 100644 synapse/replication/slave/storage/pushers.py (limited to 'tests') diff --git a/changelog.d/14375.misc b/changelog.d/14375.misc new file mode 100644 index 0000000000..d0369b9b8c --- /dev/null +++ b/changelog.d/14375.misc @@ -0,0 +1 @@ +Cleanup old worker datastore classes. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 3c8c00ea5b..165d1c5db0 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -28,10 +28,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.events import EventBase from synapse.handlers.admin import ExfiltrationWriter -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.server import HomeServer from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.account_data import AccountDataWorkerStore @@ -40,10 +36,24 @@ from synapse.storage.databases.main.appservice import ( ApplicationServiceWorkerStore, ) from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.databases.main.devices import DeviceWorkerStore +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore +from synapse.storage.databases.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore +from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.types import StateMap from synapse.util import SYNAPSE_VERSION from synapse.util.logcontext import LoggingContext @@ -52,17 +62,25 @@ logger = logging.getLogger("synapse.app.admin_cmd") class AdminCmdSlavedStore( - SlavedFilteringStore, - SlavedPushRuleStore, - SlavedEventStore, - SlavedDeviceStore, + FilteringWorkerStore, + DeviceWorkerStore, TagsWorkerStore, DeviceInboxWorkerStore, AccountDataWorkerStore, + PushRulesWorkerStore, ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, - RegistrationWorkerStore, + RoomMemberWorkerStore, + RelationsWorkerStore, + EventFederationWorkerStore, + EventPushActionsWorkerStore, + StateGroupWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, ReceiptsWorkerStore, + StreamWorkerStore, + EventsWorkerStore, + RegistrationWorkerStore, RoomWorkerStore, ): def __init__( diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index cb5892f041..51446b49cd 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -48,12 +48,6 @@ from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client import ( account_data, @@ -101,8 +95,16 @@ from synapse.storage.databases.main.appservice import ( from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.directory import DirectoryWorkerStore from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore +from synapse.storage.databases.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.keys import KeyStore from synapse.storage.databases.main.lock import LockStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.metrics import ServerMetricsStore @@ -111,17 +113,25 @@ from synapse.storage.databases.main.monthly_active_users import ( ) from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.profile import ProfileWorkerStore +from synapse.storage.databases.main.push_rule import PushRulesWorkerStore +from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.room_batch import RoomBatchStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.session import SessionStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore +from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.types import JsonDict from synapse.util import SYNAPSE_VERSION from synapse.util.httpresourcetree import create_resource_tree @@ -232,26 +242,36 @@ class GenericWorkerSlavedStore( EndToEndRoomKeyStore, PresenceStore, DeviceInboxWorkerStore, - SlavedDeviceStore, - SlavedPushRuleStore, + DeviceWorkerStore, TagsWorkerStore, AccountDataWorkerStore, - SlavedPusherStore, CensorEventsStore, ClientIpWorkerStore, - SlavedEventStore, - SlavedKeyStore, + # KeyStore isn't really safe to use from a worker, but for now we do so and hope that + # the races it creates aren't too bad. + KeyStore, RoomWorkerStore, RoomBatchStore, DirectoryWorkerStore, + PushRulesWorkerStore, ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ProfileWorkerStore, - SlavedFilteringStore, + FilteringWorkerStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, ServerMetricsStore, + PusherWorkerStore, + RoomMemberWorkerStore, + RelationsWorkerStore, + EventFederationWorkerStore, + EventPushActionsWorkerStore, + StateGroupWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, ReceiptsWorkerStore, + StreamWorkerStore, + EventsWorkerStore, RegistrationWorkerStore, SearchStore, TransactionWorkerStore, diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py deleted file mode 100644 index 6fcade510a..0000000000 --- a/synapse/replication/slave/storage/devices.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, Any, Iterable - -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.devices import DeviceWorkerStore - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SlavedDeviceStore(DeviceWorkerStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - self.hs = hs - - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - - super().__init__(database, db_conn, hs) - - def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() - - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] - ) -> None: - if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(instance_name, token) - self._invalidate_caches_for_devices(token, rows) - elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) - for row in rows: - self._user_signature_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) - - def _invalidate_caches_for_devices( - self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] - ) -> None: - for row in rows: - # The entities are either user IDs (starting with '@') whose devices - # have changed, or remote servers that we need to tell about - # changes. - if row.entity.startswith("@"): - self._device_list_stream_cache.entity_has_changed(row.entity, token) - self.get_cached_devices_for_user.invalidate((row.entity,)) - self._get_cached_user_device.invalidate((row.entity,)) - self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) - - else: - self._device_list_federation_stream_cache.entity_has_changed( - row.entity, token - ) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py deleted file mode 100644 index fe47778cb1..0000000000 --- a/synapse/replication/slave/storage/events.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import TYPE_CHECKING - -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.event_federation import EventFederationWorkerStore -from synapse.storage.databases.main.event_push_actions import ( - EventPushActionsWorkerStore, -) -from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.databases.main.relations import RelationsWorkerStore -from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.storage.databases.main.signatures import SignatureWorkerStore -from synapse.storage.databases.main.state import StateGroupWorkerStore -from synapse.storage.databases.main.stream import StreamWorkerStore -from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -# So, um, we want to borrow a load of functions intended for reading from -# a DataStore, but we don't want to take functions that either write to the -# DataStore or are cached and don't have cache invalidation logic. -# -# Rather than write duplicate versions of those functions, or lift them to -# a common base class, we going to grab the underlying __func__ object from -# the method descriptor on the DataStore and chuck them into our class. - - -class SlavedEventStore( - EventFederationWorkerStore, - RoomMemberWorkerStore, - EventPushActionsWorkerStore, - StreamWorkerStore, - StateGroupWorkerStore, - SignatureWorkerStore, - EventsWorkerStore, - UserErasureWorkerStore, - RelationsWorkerStore, -): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py deleted file mode 100644 index c52679cd60..0000000000 --- a/synapse/replication/slave/storage/filtering.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.filtering import FilteringStore - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SlavedFilteringStore(SQLBaseStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - # Filters are immutable so this cache doesn't need to be expired - get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py deleted file mode 100644 index a00b38c512..0000000000 --- a/synapse/replication/slave/storage/keys.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.databases.main.keys import KeyStore - -# KeyStore isn't really safe to use from a worker, but for now we do so and hope that -# the races it creates aren't too bad. - -SlavedKeyStore = KeyStore diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py deleted file mode 100644 index 5e65eaf1e0..0000000000 --- a/synapse/replication/slave/storage/push_rule.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Iterable - -from synapse.replication.tcp.streams import PushRulesStream -from synapse.storage.databases.main.push_rule import PushRulesWorkerStore - -from .events import SlavedEventStore - - -class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_max_push_rules_stream_id(self) -> int: - return self._push_rules_stream_id_gen.get_current_token() - - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] - ) -> None: - if stream_name == PushRulesStream.NAME: - self._push_rules_stream_id_gen.advance(instance_name, token) - for row in rows: - self.get_push_rules_for_user.invalidate((row.user_id,)) - self.push_rules_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py deleted file mode 100644 index 44ed20e424..0000000000 --- a/synapse/replication/slave/storage/pushers.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING, Any, Iterable - -from synapse.replication.tcp.streams import PushersStream -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.pusher import PusherWorkerStore - -from ._slaved_id_tracker import SlavedIdTracker - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SlavedPusherStore(PusherWorkerStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - self._pushers_id_gen = SlavedIdTracker( # type: ignore - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) - - def get_pushers_stream_token(self) -> int: - return self._pushers_id_gen.get_current_token() - - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] - ) -> None: - if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index cfaedf5e0c..0e47592be3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -26,9 +26,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict, get_domain_from_id -from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore @@ -138,41 +136,8 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._device_list_id_gen = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - super().__init__(database, db_conn, hs) - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - self._stream_order_on_start = self.get_room_max_stream_ordering() - self._min_stream_order_on_start = self.get_room_min_stream_ordering() - - def get_device_stream_token(self) -> int: - # TODO: shouldn't this be moved to `DeviceWorkerStore`? - return self._device_list_id_gen.get_current_token() - async def get_users(self) -> List[JsonDict]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 979dd4e17e..aa58c2adc3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import ( TYPE_CHECKING, @@ -39,6 +38,8 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -49,6 +50,11 @@ from synapse.storage.database import ( from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -80,9 +86,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) + if hs.config.worker.worker_app is None: + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + ) + else: + self._device_list_id_gen = SlavedIdTracker( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + ) + # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). - device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined] + device_list_max = self._device_list_id_gen.get_current_token() device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict( db_conn, "device_lists_stream", @@ -136,6 +165,39 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + self._invalidate_caches_for_devices(token, rows) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + for row in rows: + self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) + + def _invalidate_caches_for_devices( + self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] + ) -> None: + for row in rows: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. + if row.entity.startswith("@"): + self._device_list_stream_cache.entity_has_changed(row.entity, token) + self.get_cached_devices_for_user.invalidate((row.entity,)) + self._get_cached_user_device.invalidate((row.entity,)) + self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) + + else: + self._device_list_federation_stream_cache.entity_has_changed( + row.entity, token + ) + + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. @@ -677,11 +739,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): }, ) - @abc.abstractmethod - def get_device_stream_token(self) -> int: - """Get the current stream id from the _device_list_id_gen""" - ... - @trace @cancellable async def get_user_devices_from_cache( @@ -1481,6 +1538,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): + # Because we have write access, this will be a StreamIdGenerator + # (see DeviceWorkerStore.__init__) + _device_list_id_gen: AbstractStreamIdGenerator + def __init__( self, database: DatabasePool, @@ -1805,7 +1866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context, ) - async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -2044,7 +2105,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [], ) - async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: return await self.db_pool.runInteraction( "add_device_list_outbound_pokes", add_device_list_outbound_pokes_txn, @@ -2058,7 +2119,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates during partial joins. """ - async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.simple_upsert( table="device_lists_remote_pending", keyvalues={ diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 69fea452ad..a79091952a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import AsyncLruCache +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -233,6 +234,21 @@ class EventsWorkerStore(SQLBaseStore): db_conn, "events", "stream_ordering", step=-1 ) + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache: StreamChangeCache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + if hs.config.worker.run_background_tasks: # We periodically clean out old transaction ID mappings self._clock.looping_call( diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index cb9ee08fa8..12f3b601f1 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -24,7 +24,7 @@ from synapse.types import JsonDict from synapse.util.caches.descriptors import cached -class FilteringStore(SQLBaseStore): +class FilteringWorkerStore(SQLBaseStore): @cached(num_args=2) async def get_user_filter( self, user_localpart: str, filter_id: Union[int, str] @@ -46,6 +46,8 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) + +class FilteringStore(FilteringWorkerStore): async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index b6c15f29f8..8ae10f6127 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, + Iterable, List, Mapping, Optional, @@ -31,6 +31,7 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -90,8 +91,6 @@ def _load_rules( return filtered_rules -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. class PushRulesWorkerStore( ApplicationServiceWorkerStore, PusherWorkerStore, @@ -99,7 +98,6 @@ class PushRulesWorkerStore( ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore, - metaclass=abc.ABCMeta, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. @@ -136,14 +134,23 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - @abc.abstractmethod def get_max_push_rules_stream_id(self) -> int: """Get the position of the push rules stream. Returns: int """ - raise NotImplementedError() + return self._push_rules_stream_id_gen.get_current_token() + + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + for row in rows: + self.get_push_rules_for_user.invalidate((row.user_id,)) + self.push_rules_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 01206950a9..4a01562d45 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,13 +27,19 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -52,9 +58,21 @@ class PusherWorkerStore(SQLBaseStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) + + if hs.config.worker.worker_app is None: + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) + else: + self._pushers_id_gen = SlavedIdTracker( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", @@ -96,6 +114,16 @@ class PusherWorkerStore(SQLBaseStore): yield PusherConfig(**r) + def get_pushers_stream_token(self) -> int: + return self._pushers_id_gen.get_current_token() + + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == PushersStream.NAME: + self._pushers_id_gen.advance(instance_name, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) + async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str ) -> Iterator[PusherConfig]: @@ -545,8 +573,9 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): - def get_pushers_stream_token(self) -> int: - return self._pushers_id_gen.get_current_token() + # Because we have write access, this will be a StreamIdGenerator + # (see PusherWorkerStore.__init__) + _pushers_id_gen: AbstractStreamIdGenerator async def add_pusher( self, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 09ce855aa8..cc27ec3804 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -415,6 +415,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() def get_room_max_stream_ordering(self) -> int: """Get the stream_ordering of regular events that we have committed up to diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index d42e36cdf1..96f3880923 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -21,11 +21,11 @@ from synapse.api.constants import ReceiptTypes from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource -from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.databases.main.event_push_actions import ( NotifCounts, RoomNotifCounts, ) +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition @@ -58,9 +58,9 @@ def patch__eq__(cls): return unpatch -class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): +class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): - STORE_TYPE = SlavedEventStore + STORE_TYPE = EventsWorkerStore def setUp(self): # Patch up the equality operator for events so that we can check -- cgit 1.5.1 From a3623af74e0af0d2f6cbd37b47dc54a1acd314d5 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Fri, 11 Nov 2022 19:38:17 +0400 Subject: Add an Admin API endpoint for looking up users based on 3PID (#14405) --- changelog.d/14405.feature | 1 + docs/admin_api/user_admin_api.md | 39 ++++++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 25 +++++++++ tests/rest/admin/test_user.py | 107 ++++++++++++++++++++++++++++++++++----- 5 files changed, 161 insertions(+), 13 deletions(-) create mode 100644 changelog.d/14405.feature (limited to 'tests') diff --git a/changelog.d/14405.feature b/changelog.d/14405.feature new file mode 100644 index 0000000000..d3ba89b597 --- /dev/null +++ b/changelog.d/14405.feature @@ -0,0 +1 @@ +Add an [Admin API](https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html) endpoint for user lookup based on third-party ID (3PID). Contributed by @ashfame. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index c95d6c9b05..880bef4194 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -1197,3 +1197,42 @@ Returns a `404` HTTP status code if no user was found, with a response body like ``` _Added in Synapse 1.68.0._ + + +### Find a user based on their Third Party ID (ThreePID or 3PID) + +The API is: + +``` +GET /_synapse/admin/v1/threepid/$medium/users/$address +``` + +When a user matched the given address for the given medium, an HTTP code `200` with a response body like the following is returned: + +```json +{ + "user_id": "@hello:example.org" +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `medium` - Kind of third-party ID, either `email` or `msisdn`. +- `address` - Value of the third-party ID. + +The `address` may have characters that are not URL-safe, so it is advised to URL-encode those parameters. + +**Errors** + +Returns a `404` HTTP status code if no user was found, with a response body like this: + +```json +{ + "errcode":"M_NOT_FOUND", + "error":"User not found" +} +``` + +_Added in Synapse 1.72.0._ diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 885669f9c7..c62ea22116 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -81,6 +81,7 @@ from synapse.rest.admin.users import ( ShadowBanRestServlet, UserAdminServlet, UserByExternalId, + UserByThreePid, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -277,6 +278,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMessagesRestServlet(hs).register(http_server) RoomTimestampToEventRestServlet(hs).register(http_server) UserByExternalId(hs).register(http_server) + UserByThreePid(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 15ac2059aa..1951b8a9f2 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1224,3 +1224,28 @@ class UserByExternalId(RestServlet): raise NotFoundError("User not found") return HTTPStatus.OK, {"user_id": user_id} + + +class UserByThreePid(RestServlet): + """Find a user based on 3PID of a particular medium""" + + PATTERNS = admin_patterns("/threepid/(?P[^/]*)/users/(?P
[^/]*)") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, + request: SynapseRequest, + medium: str, + address: str, + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + user_id = await self._store.get_user_id_by_threepid(medium, address) + + if user_id is None: + raise NotFoundError("User not found") + + return HTTPStatus.OK, {"user_id": user_id} diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 63410ffdf1..e8c9457794 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -41,14 +41,12 @@ from tests.unittest import override_config class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, profile.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = "/_synapse/admin/v1/register" self.registration_handler = Mock() @@ -446,7 +444,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): class UsersListTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1108,7 +1105,6 @@ class UserDevicesTestCase(unittest.HomeserverTestCase): class DeactivateAccountTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1382,7 +1378,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): class UserRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -2803,7 +2798,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): class UserMembershipRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -2960,7 +2954,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): class PushersRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3089,7 +3082,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): class UserMediaRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3881,7 +3873,6 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ], ) class WhoisRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3961,7 +3952,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): class ShadowBanRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4042,7 +4032,6 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): class RateLimitTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4268,7 +4257,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): class AccountDataTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4358,7 +4346,6 @@ class AccountDataTestCase(unittest.HomeserverTestCase): class UsersByExternalIdTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4442,3 +4429,97 @@ class UsersByExternalIdTestCase(unittest.HomeserverTestCase): {"user_id": self.other_user}, channel.json_body, ) + + +class UsersByThreePidTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.get_success( + self.store.user_add_threepid( + self.other_user, "email", "user@email.com", 1, 1 + ) + ) + self.get_success( + self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1) + ) + + def test_no_auth(self) -> None: + """Try to look up a user without authentication.""" + url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" + + channel = self.make_request( + "GET", + url, + ) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_medium_does_not_exist(self) -> None: + """Tests that both a lookup for a medium that does not exist and a user that + doesn't exist with that third party ID returns a 404""" + # test for unknown medium + url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + # test for unknown user with a known medium + url = "/_synapse/admin/v1/threepid/email/users/unknown" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_success(self) -> None: + """Tests a successful medium + address lookup""" + # test for email medium with encoded value of user@email.com + url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) + + # test for msidn medium with encoded value of +1-12345678 + url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) -- cgit 1.5.1 From 1eed795fc56d95df3968e37f3a4db92f24513e15 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 15 Nov 2022 17:35:19 +0000 Subject: Include heroes in partial join responses' state (#14442) * Pull out hero selection logic * Include heroes in partial join response's state * Changelog * Fixup trial test * Remove TODO --- changelog.d/14442.feature | 1 + synapse/federation/federation_server.py | 23 +++++++++++++++++---- synapse/handlers/sync.py | 20 +++---------------- synapse/storage/databases/main/roommember.py | 30 ++++++++++++++++++++++++++++ tests/federation/test_federation_server.py | 11 ++++++---- 5 files changed, 60 insertions(+), 25 deletions(-) create mode 100644 changelog.d/14442.feature (limited to 'tests') diff --git a/changelog.d/14442.feature b/changelog.d/14442.feature new file mode 100644 index 0000000000..917e7edfb3 --- /dev/null +++ b/changelog.d/14442.feature @@ -0,0 +1 @@ +Faster joins: include heroes' membership events in the partial join response, for rooms without a name or canonical alias. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 59e351595b..bb20af6e91 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -74,6 +74,8 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.lock import Lock +from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary +from synapse.storage.roommember import MemberSummary from synapse.types import JsonDict, StateMap, get_domain_from_id from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results @@ -691,8 +693,9 @@ class FederationServer(FederationBase): state_event_ids: Collection[str] servers_in_room: Optional[Collection[str]] if caller_supports_partial_state: + summary = await self.store.get_room_summary(room_id) state_event_ids = _get_event_ids_for_partial_state_join( - event, prev_state_ids + event, prev_state_ids, summary ) servers_in_room = await self.state.get_hosts_in_room_at_events( room_id, event_ids=event.prev_event_ids() @@ -1495,6 +1498,7 @@ class FederationHandlerRegistry: def _get_event_ids_for_partial_state_join( join_event: EventBase, prev_state_ids: StateMap[str], + summary: Dict[str, MemberSummary], ) -> Collection[str]: """Calculate state to be retuned in a partial_state send_join @@ -1521,8 +1525,19 @@ def _get_event_ids_for_partial_state_join( if current_membership_event_id is not None: state_event_ids.add(current_membership_event_id) - # TODO: return a few more members: - # - those with invites - # - those that are kicked? / banned + name_id = prev_state_ids.get((EventTypes.Name, "")) + canonical_alias_id = prev_state_ids.get((EventTypes.CanonicalAlias, "")) + if not name_id and not canonical_alias_id: + # Also include the hero members of the room (for DM rooms without a title). + # To do this properly, we should select the correct subset of membership events + # from `prev_state_ids`. Instead, we are lazier and use the (cached) + # `get_room_summary` function, which is based on the current state of the room. + # This introduces races; we choose to ignore them because a) they should be rare + # and b) even if it's wrong, joining servers will get the full state eventually. + heroes = extract_heroes_from_room_summary(summary, join_event.state_key) + for hero in heroes: + membership_event_id = prev_state_ids.get((EventTypes.Member, hero)) + if membership_event_id: + state_event_ids.add(membership_event_id) return state_event_ids diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1db5d68021..259456b55d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -41,6 +41,7 @@ from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts +from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -805,18 +806,6 @@ class SyncHandler: if canonical_alias and canonical_alias.content.get("alias"): return summary - me = sync_config.user.to_string() - - joined_user_ids = [ - r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me - ] - invited_user_ids = [ - r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me - ] - gone_user_ids = [ - r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me - ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] - # FIXME: only build up a member_ids list for our heroes member_ids = {} for membership in ( @@ -828,11 +817,8 @@ class SyncHandler: for user_id, event_id in details.get(membership, empty_ms).members: member_ids[user_id] = event_id - # FIXME: order by stream ordering rather than as returned by SQL - if joined_user_ids or invited_user_ids: - summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5] - else: - summary["m.heroes"] = sorted(gone_user_ids)[0:5] + me = sync_config.user.to_string() + summary["m.heroes"] = extract_heroes_from_room_summary(details, me) if not sync_config.filter_collection.lazy_load_members(): return summary diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e56a13f21e..f02c1d7ea7 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1517,6 +1517,36 @@ class RoomMemberStore( await self.db_pool.runInteraction("forget_membership", f) +def extract_heroes_from_room_summary( + details: Mapping[str, MemberSummary], me: str +) -> List[str]: + """Determine the users that represent a room, from the perspective of the `me` user. + + The rules which say which users we select are specified in the "Room Summary" + section of + https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync + + Returns a list (possibly empty) of heroes' mxids. + """ + empty_ms = MemberSummary([], 0) + + joined_user_ids = [ + r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me + ] + invited_user_ids = [ + r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me + ] + gone_user_ids = [ + r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me + ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] + + # FIXME: order by stream ordering rather than as returned by SQL + if joined_user_ids or invited_user_ids: + return sorted(joined_user_ids + invited_user_ids)[0:5] + else: + return sorted(gone_user_ids)[0:5] + + @attr.s(slots=True, auto_attribs=True) class _JoinedHostsCache: """The cached data used by the `_get_joined_hosts_cache`.""" diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 3a6ef221ae..177e5b5afc 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -212,7 +212,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): self.assertEqual(r[("m.room.member", joining_user)].membership, "join") @override_config({"experimental_features": {"msc3706_enabled": True}}) - def test_send_join_partial_state(self): + def test_send_join_partial_state(self) -> None: """When MSC3706 support is enabled, /send_join should return partial state""" joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME join_result = self._make_join(joining_user) @@ -240,6 +240,9 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ("m.room.power_levels", ""), ("m.room.join_rules", ""), ("m.room.history_visibility", ""), + # Users included here because they're heroes. + ("m.room.member", "@kermit:test"), + ("m.room.member", "@fozzie:test"), ], ) @@ -249,9 +252,9 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ] self.assertCountEqual( returned_auth_chain_events, - [ - ("m.room.member", "@kermit:test"), - ], + # TODO: change the test so that we get at least one event in the auth chain + # here. + [], ) # the room should show that the new user is a member -- cgit 1.5.1 From 882277008c7b43ab26e3445ab94a38aa25ad0965 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:01:22 +0000 Subject: Fix background updates failing to add unique indexes on receipts (#14453) As part of the database migration to support threaded receipts, there is a possible window in between `73/08thread_receipts_non_null.sql.postgres` removing the original unique constraints on `receipts_linearized` and `receipts_graph` and the `reeipts_linearized_unique_index` and `receipts_graph_unique_index` background updates from `72/08thread_receipts.sql` completing where the unique constraints on `receipts_linearized` and `receipts_graph` are missing. Any emulated upserts on these tables must therefore be performed with a lock held, otherwise duplicate rows can end up in the tables when there are concurrent emulated upserts. Fix the missing lock. Note that emulated upserts no longer happen by default on sqlite, since the minimum supported version of sqlite supports native upserts by default now. Finally, clean up any duplicate receipts that may have crept in before trying to create the `receipts_graph_unique_index` and `receipts_linearized_unique_index` unique indexes. Signed-off-by: Sean Quah --- changelog.d/14453.bugfix | 1 + synapse/storage/databases/main/receipts.py | 171 ++++++++++++++++++--- tests/storage/databases/main/test_receipts.py | 209 ++++++++++++++++++++++++++ 3 files changed, 357 insertions(+), 24 deletions(-) create mode 100644 changelog.d/14453.bugfix create mode 100644 tests/storage/databases/main/test_receipts.py (limited to 'tests') diff --git a/changelog.d/14453.bugfix b/changelog.d/14453.bugfix new file mode 100644 index 0000000000..4969e5450c --- /dev/null +++ b/changelog.d/14453.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 where the background updates to add non-thread unique indexes on receipts could fail when upgrading from 1.67.0 or earlier. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dc6989527e..fbf27497ec 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -113,24 +113,6 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) - self.db_pool.updates.register_background_index_update( - "receipts_linearized_unique_index", - index_name="receipts_linearized_unique_index", - table="receipts_linearized", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - - self.db_pool.updates.register_background_index_update( - "receipts_graph_unique_index", - index_name="receipts_graph_unique_index", - table="receipts_graph", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @@ -702,9 +684,6 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_linearized has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) return rx_ts @@ -862,14 +841,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_graph has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) class ReceiptsBackgroundUpdateStore(SQLBaseStore): POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering" + RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME = "receipts_linearized_unique_index" + RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME = "receipts_graph_unique_index" def __init__( self, @@ -883,6 +861,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, self._populate_receipt_event_stream_ordering, ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_linearized_unique_index, + ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_graph_unique_index, + ) async def _populate_receipt_event_stream_ordering( self, progress: JsonDict, batch_size: int @@ -938,6 +924,143 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): return batch_size + async def _create_receipts_index(self, index_name: str, table: str) -> None: + """Adds a unique index on `(room_id, receipt_type, user_id)` to the given + receipts table, for non-thread receipts.""" + + def _create_index(conn: LoggingDatabaseConnection) -> None: + conn.rollback() + + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. + if isinstance(self.database_engine, PostgresEngine): + conn.set_session(autocommit=True) + + try: + c = conn.cursor() + + # Now that the duplicates are gone, we can create the index. + concurrently = ( + "CONCURRENTLY" + if isinstance(self.database_engine, PostgresEngine) + else "" + ) + sql = f""" + CREATE UNIQUE INDEX {concurrently} {index_name} + ON {table}(room_id, receipt_type, user_id) + WHERE thread_id IS NULL + """ + c.execute(sql) + finally: + if isinstance(self.database_engine, PostgresEngine): + conn.set_session(autocommit=False) + + await self.db_pool.runWithConnection(_create_index) + + async def _background_receipts_linearized_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_linearized`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT MAX(stream_id), room_id, receipt_type, user_id + FROM receipts_linearized + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn)) + + # Then remove duplicate receipts, keeping the one with the highest + # `stream_id`. There should only be a single receipt with any given + # `stream_id`. + for max_stream_id, room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_linearized + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL AND + stream_id < ? + """ + txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self._create_receipts_index( + "receipts_linearized_unique_index", + "receipts_linearized", + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + + async def _background_receipts_graph_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_graph`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT room_id, receipt_type, user_id FROM receipts_graph + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[str, str, str]], list(txn)) + + # Then remove all duplicate receipts. + # We could be clever and try to keep the latest receipt out of every set of + # duplicates, but it's far simpler to remove them all. + for room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_graph + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL + """ + txn.execute(sql, (room_id, receipt_type, user_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self._create_receipts_index( + "receipts_graph_unique_index", + "receipts_graph", + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore): pass diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py new file mode 100644 index 0000000000..c4f12d81d7 --- /dev/null +++ b/tests/storage/databases/main/test_receipts.py @@ -0,0 +1,209 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Sequence, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.store = hs.get_datastores().main + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + self.other_room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def _test_background_receipts_unique_index( + self, + update_name: str, + index_name: str, + table: str, + receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]], + expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]], + ): + """Test that the background update to uniqueify non-thread receipts in + the given receipts table works properly. + + Args: + update_name: The name of the background update to test. + index_name: The name of the index that the background update creates. + table: The table of receipts that the background update fixes. + receipts: The test data containing duplicate receipts. + A list of receipt rows to insert, grouped by + `(room_id, receipt_type, user_id)`. + expected_unique_receipts: A dictionary of `(room_id, receipt_type, user_id)` + keys and expected receipt key-values after duplicate receipts have been + removed. + """ + # First, undo the background update. + def drop_receipts_unique_index(txn: LoggingTransaction) -> None: + txn.execute(f"DROP INDEX IF EXISTS {index_name}") + + self.get_success( + self.store.db_pool.runInteraction( + "drop_receipts_unique_index", + drop_receipts_unique_index, + ) + ) + + # Populate the receipts table, including duplicates. + for (room_id, receipt_type, user_id), rows in receipts.items(): + for row in rows: + self.get_success( + self.store.db_pool.simple_insert( + table, + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "thread_id": None, + "data": "{}", + **row, + }, + ) + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": update_name, + "progress_json": "{}", + }, + ) + ) + + self.store.db_pool.updates._all_done = False + + self.wait_for_background_updates() + + # Check that the remaining receipts match expectations. + for ( + room_id, + receipt_type, + user_id, + ), expected_row in expected_unique_receipts.items(): + # Include the receipt key in the returned columns, for more informative + # assertion messages. + columns = ["room_id", "receipt_type", "user_id"] + if expected_row is not None: + columns += expected_row.keys() + + rows = self.get_success( + self.store.db_pool.simple_select_list( + table=table, + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + # `simple_select_onecol` does not support NULL filters, + # so skip the filter on `thread_id`. + }, + retcols=columns, + desc="get_receipt", + ) + ) + + if expected_row is not None: + self.assertEqual( + len(rows), + 1, + f"Background update did not leave behind latest receipt in {table}", + ) + self.assertEqual( + rows[0], + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + **expected_row, + }, + ) + else: + self.assertEqual( + len(rows), + 0, + f"Background update did not remove all duplicate receipts from {table}", + ) + + def test_background_receipts_linearized_unique_index(self): + """Test that the background update to uniqueify non-thread receipts in + `receipts_linearized` works properly. + """ + self._test_background_receipts_unique_index( + "receipts_linearized_unique_index", + "receipts_linearized_unique_index", + "receipts_linearized", + receipts={ + (self.room_id, "m.read", self.user_id): [ + {"stream_id": 5, "event_id": "$some_event"}, + {"stream_id": 6, "event_id": "$some_event"}, + ], + (self.other_room_id, "m.read", self.user_id): [ + {"stream_id": 7, "event_id": "$some_event"} + ], + }, + expected_unique_receipts={ + (self.room_id, "m.read", self.user_id): {"stream_id": 6}, + (self.other_room_id, "m.read", self.user_id): {"stream_id": 7}, + }, + ) + + def test_background_receipts_graph_unique_index(self): + """Test that the background update to uniqueify non-thread receipts in + `receipts_graph` works properly. + """ + self._test_background_receipts_unique_index( + "receipts_graph_unique_index", + "receipts_graph_unique_index", + "receipts_graph", + receipts={ + (self.room_id, "m.read", self.user_id): [ + { + "event_ids": '["$some_event"]', + }, + { + "event_ids": '["$some_event"]', + }, + ], + (self.other_room_id, "m.read", self.user_id): [ + { + "event_ids": '["$some_event"]', + } + ], + }, + expected_unique_receipts={ + (self.room_id, "m.read", self.user_id): None, + (self.other_room_id, "m.read", self.user_id): { + "event_ids": '["$some_event"]' + }, + }, + ) -- cgit 1.5.1 From d8cc86eff484b6f570f55a5badb337080c6e4dcd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Nov 2022 10:25:24 -0500 Subject: Remove redundant types from comments. (#14412) Remove type hints from comments which have been added as Python type hints. This helps avoid drift between comments and reality, as well as removing redundant information. Also adds some missing type hints which were simple to fill in. --- changelog.d/14412.misc | 1 + synapse/api/errors.py | 2 +- synapse/config/logger.py | 5 ++- synapse/crypto/keyring.py | 9 +++-- synapse/events/__init__.py | 3 +- synapse/federation/transport/client.py | 11 +++--- synapse/federation/transport/server/_base.py | 4 +-- synapse/handlers/e2e_keys.py | 2 +- synapse/handlers/e2e_room_keys.py | 5 +-- synapse/handlers/federation.py | 4 +-- synapse/handlers/identity.py | 2 +- synapse/handlers/oidc.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/handlers/saml.py | 4 +-- synapse/http/additional_resource.py | 3 +- synapse/http/federation/matrix_federation_agent.py | 9 +++-- synapse/http/matrixfederationclient.py | 3 +- synapse/http/proxyagent.py | 20 +++++------ synapse/http/server.py | 2 +- synapse/http/site.py | 2 +- synapse/logging/context.py | 39 +++++++++++----------- synapse/logging/opentracing.py | 4 +-- synapse/module_api/__init__.py | 7 ++-- synapse/replication/http/_base.py | 2 +- synapse/rest/admin/users.py | 5 +-- synapse/rest/client/login.py | 2 +- synapse/rest/media/v1/media_repository.py | 4 +-- synapse/rest/media/v1/thumbnailer.py | 4 +-- synapse/server_notices/consent_server_notices.py | 5 ++- .../resource_limits_server_notices.py | 12 ++++--- synapse/storage/controllers/persist_events.py | 5 ++- synapse/storage/databases/main/devices.py | 2 +- synapse/storage/databases/main/e2e_room_keys.py | 8 ++--- synapse/storage/databases/main/end_to_end_keys.py | 7 ++-- synapse/storage/databases/main/events.py | 22 ++++++------ synapse/storage/databases/main/events_worker.py | 2 +- .../storage/databases/main/monthly_active_users.py | 8 ++--- synapse/storage/databases/main/registration.py | 6 ++-- synapse/storage/databases/main/room.py | 8 +++-- synapse/storage/databases/main/user_directory.py | 9 +++-- synapse/types.py | 4 +-- synapse/util/async_helpers.py | 3 +- synapse/util/caches/__init__.py | 2 +- synapse/util/caches/deferred_cache.py | 2 +- synapse/util/caches/dictionary_cache.py | 9 ++--- synapse/util/caches/expiringcache.py | 2 +- synapse/util/caches/lrucache.py | 8 ++--- synapse/util/ratelimitutils.py | 2 +- synapse/util/threepids.py | 2 +- synapse/util/wheel_timer.py | 4 +-- tests/http/__init__.py | 7 ++-- tests/replication/slave/storage/test_events.py | 7 ++-- tests/replication/test_multi_media_repo.py | 14 ++++---- .../test_resource_limits_server_notices.py | 10 +++--- tests/unittest.py | 18 +++++----- 55 files changed, 174 insertions(+), 176 deletions(-) create mode 100644 changelog.d/14412.misc (limited to 'tests') diff --git a/changelog.d/14412.misc b/changelog.d/14412.misc new file mode 100644 index 0000000000..4da061d461 --- /dev/null +++ b/changelog.d/14412.misc @@ -0,0 +1 @@ +Remove duplicated type information from type hints. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 400dd12aba..e2cfcea0f2 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -713,7 +713,7 @@ class HttpResponseException(CodeMessageException): set to the reason code from the HTTP response. Returns: - SynapseError: + The error converted to a SynapseError. """ # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 94d1150415..5468b963a2 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -317,10 +317,9 @@ def setup_logging( Set up the logging subsystem. Args: - config (LoggingConfig | synapse.config.worker.WorkerConfig): - configuration data + config: configuration data - use_worker_options (bool): True to use the 'worker_log_config' option + use_worker_options: True to use the 'worker_log_config' option instead of 'log_config'. logBeginner: The Twisted logBeginner to use. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index c88afb2986..dd9b8089ec 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -213,7 +213,7 @@ class Keyring: def verify_json_objects_for_server( self, server_and_json: Iterable[Tuple[str, dict, int]] - ) -> List[defer.Deferred]: + ) -> List["defer.Deferred[None]"]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. @@ -226,10 +226,9 @@ class Keyring: valid. Returns: - List: for each input triplet, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. + For each input triplet, a deferred indicating success or failure to + verify each json object's signature for the given server_name. The + deferreds run their callbacks in the sentinel logcontext. """ return [ run_in_background( diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 030c3ca408..8aca9a3ab9 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -597,8 +597,7 @@ def _event_type_from_format_version( format_version: The event format version Returns: - type: A type that can be initialized as per the initializer of - `FrozenEvent` + A type that can be initialized as per the initializer of `FrozenEvent` """ if format_version == EventFormatVersions.ROOM_V1_V2: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cd39d4d111..a3cfc701cd 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -280,12 +280,11 @@ class TransportLayerClient: Note that this does not append any events to any graphs. Args: - destination (str): address of remote homeserver - room_id (str): room to join/leave - user_id (str): user to be joined/left - membership (str): one of join/leave - params (dict[str, str|Iterable[str]]): Query parameters to include in the - request. + destination: address of remote homeserver + room_id: room to join/leave + user_id: user to be joined/left + membership: one of join/leave + params: Query parameters to include in the request. Returns: Succeeds when we get a 2xx HTTP response. The result diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 1db8009d6c..cdaf0d5de7 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -224,10 +224,10 @@ class BaseFederationServlet: With arguments: - origin (unicode|None): The authenticated server_name of the calling server, + origin (str|None): The authenticated server_name of the calling server, unless REQUIRE_AUTH is set to False and authentication failed. - content (unicode|None): decoded json body of the request. None if the + content (str|None): decoded json body of the request. None if the request was a GET. query (dict[bytes, list[bytes]]): Query params from the request. url-decoded diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a9912c467d..bf1221f523 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -870,7 +870,7 @@ class E2eKeysHandler: - signatures of the user's master key by the user's devices. Args: - user_id (string): the user uploading the keys + user_id: the user uploading the keys signatures (dict[string, dict]): map of devices to signed keys Returns: diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 28dc08c22a..83f53ceb88 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -377,8 +377,9 @@ class E2eRoomKeysHandler: """Deletes a given version of the user's e2e_room_keys backup Args: - user_id(str): the user whose current backup version we're deleting - version(str): the version id of the backup being deleted + user_id: the user whose current backup version we're deleting + version: Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. Raises: NotFoundError: if this backup version doesn't exist """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5fc3b8bc8c..188f0956ef 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1596,8 +1596,8 @@ class FederationHandler: Fetch the complexity of a remote room over federation. Args: - remote_room_hosts (list[str]): The remote servers to ask. - room_id (str): The room ID to ask about. + remote_room_hosts: The remote servers to ask. + room_id: The room ID to ask about. Returns: Dict contains the complexity diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 93d09e9939..848e46eb9b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -711,7 +711,7 @@ class IdentityHandler: inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. - id_access_token (str): The access token to authenticate to the identity + id_access_token: The access token to authenticate to the identity server with Returns: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 867973dcca..41c675f408 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -787,7 +787,7 @@ class OidcProvider: Must include an ``access_token`` field. Returns: - UserInfo: an object representing the user. + an object representing the user. """ logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0066d63987..b7bc787636 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -201,7 +201,7 @@ class BasePresenceHandler(abc.ABC): """Get the current presence state for multiple users. Returns: - dict: `user_id` -> `UserPresenceState` + A mapping of `user_id` -> `UserPresenceState` """ states = {} missing = [] diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 9602f0d0bb..874860d461 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -441,7 +441,7 @@ class DefaultSamlMappingProvider: client_redirect_url: where the client wants to redirect to Returns: - dict: A dict containing new user attributes. Possible keys: + A dict containing new user attributes. Possible keys: * mxid_localpart (str): Required. The localpart of the user's mxid * displayname (str): The displayname of the user * emails (list[str]): Any emails for the user @@ -483,7 +483,7 @@ class DefaultSamlMappingProvider: Args: config: A dictionary containing configuration options for this provider Returns: - SamlConfig: A custom config object for this module + A custom config object for this module """ # Parse config options and use defaults where necessary mxid_source_attribute = config.get("mxid_source_attribute", "uid") diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 6a9f6635d2..8729630581 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -45,8 +45,7 @@ class AdditionalResource(DirectServeJsonResource): Args: hs: homeserver - handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): - function to be called to handle the request. + handler: function to be called to handle the request. """ super().__init__() self._handler = handler diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 2f0177f1e2..0359231e7d 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -155,11 +155,10 @@ class MatrixFederationAgent: a file for a file upload). Or None if the request is to have no body. Returns: - Deferred[twisted.web.iweb.IResponse]: - fires when the header of the response has been received (regardless of the - response status code). Fails if there is any problem which prevents that - response from being received (including problems that prevent the request - from being sent). + A deferred which fires when the header of the response has been received + (regardless of the response status code). Fails if there is any problem + which prevents that response from being received (including problems that + prevent the request from being sent). """ # We use urlparse as that will set `port` to None if there is no # explicit port. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3c35b1d2c7..b92f1d3d1a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -951,8 +951,7 @@ class MatrixFederationHttpClient: args: query params Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The - result will be the decoded JSON body. + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: HttpResponseException: If we get an HTTP response code >= 300 diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 1f8227896f..18899bc6d1 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -34,7 +34,7 @@ from twisted.web.client import ( ) from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS +from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials @@ -134,7 +134,7 @@ class ProxyAgent(_AgentBase): uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> defer.Deferred: + ) -> "defer.Deferred[IResponse]": """ Issue a request to the server indicated by the given uri. @@ -157,17 +157,17 @@ class ProxyAgent(_AgentBase): a file upload). Or, None if the request is to have no body. Returns: - Deferred[IResponse]: completes when the header of the response has - been received (regardless of the response status code). + A deferred which completes when the header of the response has + been received (regardless of the response status code). - Can fail with: - SchemeNotSupported: if the uri is not http or https + Can fail with: + SchemeNotSupported: if the uri is not http or https - twisted.internet.error.TimeoutError if the server we are connecting - to (proxy or destination) does not accept a connection before - connectTimeout. + twisted.internet.error.TimeoutError if the server we are connecting + to (proxy or destination) does not accept a connection before + connectTimeout. - ... other things too. + ... other things too. """ uri = uri.strip() if not _VALID_URI.match(uri): diff --git a/synapse/http/server.py b/synapse/http/server.py index b26e34bceb..051a1899a0 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -267,7 +267,7 @@ class HttpServer(Protocol): request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. This should return either tuple of (code, response), or None. - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ diff --git a/synapse/http/site.py b/synapse/http/site.py index 3dbd541fed..6a1dbf7f33 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -400,7 +400,7 @@ class SynapseRequest(Request): be sure to call finished_processing. Args: - servlet_name (str): the name of the servlet which will be + servlet_name: the name of the servlet which will be processing this request. This is used in the metrics. It is possible to update this afterwards by updating diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 6a08ffed64..f62bea968f 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -117,8 +117,7 @@ class ContextResourceUsage: """Create a new ContextResourceUsage Args: - copy_from (ContextResourceUsage|None): if not None, an object to - copy stats from + copy_from: if not None, an object to copy stats from """ if copy_from is None: self.reset() @@ -162,7 +161,7 @@ class ContextResourceUsage: """Add another ContextResourceUsage's stats to this one's. Args: - other (ContextResourceUsage): the other resource usage object + other: the other resource usage object """ self.ru_utime += other.ru_utime self.ru_stime += other.ru_stime @@ -342,7 +341,7 @@ class LoggingContext: called directly. Returns: - LoggingContext: the current logging context + The current logging context """ warnings.warn( "synapse.logging.context.LoggingContext.current_context() is deprecated " @@ -362,7 +361,8 @@ class LoggingContext: called directly. Args: - context(LoggingContext): The context to activate. + context: The context to activate. + Returns: The context that was previously active """ @@ -474,8 +474,7 @@ class LoggingContext: """Get resources used by this logcontext so far. Returns: - ContextResourceUsage: a *copy* of the object tracking resource - usage so far + A *copy* of the object tracking resource usage so far """ # we always return a copy, for consistency res = self._resource_usage.copy() @@ -663,7 +662,8 @@ def current_context() -> LoggingContextOrSentinel: def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: """Set the current logging context in thread local storage Args: - context(LoggingContext): The context to activate. + context: The context to activate. + Returns: The context that was previously active """ @@ -700,7 +700,7 @@ def nested_logging_context(suffix: str) -> LoggingContext: suffix: suffix to add to the parent context's 'name'. Returns: - LoggingContext: new logging context. + A new logging context. """ curr_context = current_context() if not curr_context: @@ -898,20 +898,19 @@ def defer_to_thread( on it. Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked, and whose threadpool we should use for the - function. + reactor: The reactor in whose main thread the Deferred will be invoked, + and whose threadpool we should use for the function. Normally this will be hs.get_reactor(). - f (callable): The function to call. + f: The function to call. args: positional arguments to pass to f. kwargs: keyword arguments to pass to f. Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an + A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) @@ -939,20 +938,20 @@ def defer_to_threadpool( on it. Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked. Normally this will be hs.get_reactor(). + reactor: The reactor in whose main thread the Deferred will be invoked. + Normally this will be hs.get_reactor(). - threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for - running `f`. Normally this will be hs.get_reactor().getThreadPool(). + threadpool: The threadpool to use for running `f`. Normally this will be + hs.get_reactor().getThreadPool(). - f (callable): The function to call. + f: The function to call. args: positional arguments to pass to f. kwargs: keyword arguments to pass to f. Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an + A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ curr_context = current_context() diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8ce5a2a338..b69060854f 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -721,7 +721,7 @@ def inject_header_dict( destination: address of entity receiving the span context. Must be given unless check_destination is False. The context will only be injected if the destination matches the opentracing whitelist - check_destination (bool): If false, destination will be ignored and the context + check_destination: If false, destination will be ignored and the context will always be injected. Note: @@ -780,7 +780,7 @@ def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str destination: the name of the remote server. Returns: - dict: the active span's context if opentracing is enabled, otherwise empty. + the active span's context if opentracing is enabled, otherwise empty. """ if destination and not whitelisted_homeserver(destination): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 30e689d00d..1adc1fd64f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -787,7 +787,7 @@ class ModuleApi: Added in Synapse v0.25.0. Args: - access_token(str): access token + access_token: access token Returns: twisted.internet.defer.Deferred - resolves once the access token @@ -832,7 +832,7 @@ class ModuleApi: **kwargs: named args to be passed to func Returns: - Deferred[object]: result of func + Result of func """ # type-ignore: See https://github.com/python/mypy/issues/8862 return defer.ensureDeferred( @@ -924,8 +924,7 @@ class ModuleApi: to represent 'any') of the room state to acquire. Returns: - twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: - The filtered state events in the room. + The filtered state events in the room. """ state_ids = yield defer.ensureDeferred( self._storage_controllers.state.get_current_state_ids( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 5e661f8c73..3f4d3fc51a 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -153,7 +153,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): argument list. Returns: - dict: If POST/PUT request then dictionary must be JSON serialisable, + If POST/PUT request then dictionary must be JSON serialisable, otherwise must be appropriate for adding as query args. """ return {} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 1951b8a9f2..6e0c44be2a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -903,8 +903,9 @@ class PushersRestServlet(RestServlet): @user:server/pushers Returns: - pushers: Dictionary containing pushers information. - total: Number of pushers in dictionary `pushers`. + A dictionary with keys: + pushers: Dictionary containing pushers information. + total: Number of pushers in dictionary `pushers`. """ PATTERNS = admin_patterns("/users/(?P[^/]*)/pushers$") diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 05706b598c..8adced41e5 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -350,7 +350,7 @@ class LoginRestServlet(RestServlet): auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: - result: Dictionary of account information after successful login. + Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 328c0c5477..40b0d39eb2 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -344,8 +344,8 @@ class MediaRepository: download from remote server. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 9b93b9b4f6..a48a4de92a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -138,7 +138,7 @@ class Thumbnailer: """Rescales the image to the given dimensions. Returns: - BytesIO: the bytes of the encoded image ready to be written to disk + The bytes of the encoded image ready to be written to disk """ with self._resize(width, height) as scaled: return self._encode_image(scaled, output_type) @@ -155,7 +155,7 @@ class Thumbnailer: max_height: The largest possible height. Returns: - BytesIO: the bytes of the encoded image ready to be written to disk + The bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_width = width diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 698ca742ed..94025ba41f 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -113,9 +113,8 @@ def copy_with_str_subst(x: Any, substitutions: Any) -> Any: """Deep-copy a structure, carrying out string substitutions on any strings Args: - x (object): structure to be copied - substitutions (object): substitutions to be made - passed into the - string '%' operator + x: structure to be copied + substitutions: substitutions to be made - passed into the string '%' operator Returns: copy of x diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 3134cd2d3d..a31a2c99a7 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -170,11 +170,13 @@ class ResourceLimitsServerNotices: room_id: The room id of the server notices room Returns: - bool: Is the room currently blocked - list: The list of pinned event IDs that are unrelated to limit blocking - This list can be used as a convenience in the case where the block - is to be lifted and the remaining pinned event references need to be - preserved + Tuple of: + Is the room currently blocked + + The list of pinned event IDs that are unrelated to limit blocking + This list can be used as a convenience in the case where the block + is to be lifted and the remaining pinned event references need to be + preserved """ currently_blocked = False pinned_state_event = None diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 48976dc570..33ffef521b 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -204,9 +204,8 @@ class _EventPeristenceQueue(Generic[_PersistResult]): process to to so, calling the per_item_callback for each item. Args: - room_id (str): - task (_EventPersistQueueTask): A _PersistEventsTask or - _UpdateCurrentStateTask to process. + room_id: + task: A _PersistEventsTask or _UpdateCurrentStateTask to process. Returns: the result returned by the `_per_item_callback` passed to diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index aa58c2adc3..e114c733d1 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -535,7 +535,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): limit: Maximum number of device updates to return Returns: - List: List of device update tuples: + List of device update tuples: - user_id - device_id - stream_id diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index af59be6b48..6240f9a75e 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -391,10 +391,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): Returns: A dict giving the info metadata for this backup version, with fields including: - version(str) - algorithm(str) - auth_data(object): opaque dict supplied by the client - etag(int): tag of the keys in the backup + version (str) + algorithm (str) + auth_data (object): opaque dict supplied by the client + etag (int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2a4f58ed92..cf33e73e2b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -412,10 +412,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """Retrieve a number of one-time keys for a user Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - key_ids(list[str]): list of key ids (excluding algorithm) to - retrieve + user_id: id of user to get keys for + device_id: id of device to get keys for + key_ids: list of key ids (excluding algorithm) to retrieve Returns: A map from (algorithm, key_id) to json string for key diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index c4acff5be6..d68f127f9b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1279,9 +1279,10 @@ class PersistEventsStore: Pick the earliest non-outlier if there is one, else the earliest one. Args: - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts: + Returns: - list[(EventBase, EventContext)]: filtered list + filtered list """ new_events_and_contexts: OrderedDict[ str, Tuple[EventBase, EventContext] @@ -1307,9 +1308,8 @@ class PersistEventsStore: """Update min_depth for each room Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: @@ -1580,13 +1580,11 @@ class PersistEventsStore: """Update all the miscellaneous tables for new events Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. + txn: db connection + events_and_contexts: events we are persisting + all_events_and_contexts: all events that we were going to persist. + This includes events we've already persisted, etc, that wouldn't + appear in events_and_context. inhibit_local_membership_updates: Stop the local_current_membership from being updated by these events. This should be set to True for backfilled events because backfilled events in the past do diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 467d20253d..8a104f7e93 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1589,7 +1589,7 @@ class EventsWorkerStore(SQLBaseStore): room_id: The room ID to query. Returns: - dict[str:float] of complexity version to complexity. + Map of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index efd136a864..db9a24db5e 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -217,7 +217,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: - reserved_users (tuple): reserved users to preserve + reserved_users: reserved users to preserve """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) @@ -370,8 +370,8 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): should not appear in the MAU stats). Args: - txn (cursor): - user_id (str): user to add/update + txn: + user_id: user to add/update """ assert ( self._update_on_this_worker @@ -401,7 +401,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): add the user to the monthly active tables Args: - user_id(str): the user_id to query + user_id: the user_id to query """ assert ( self._update_on_this_worker diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 5167089e03..31f0f2bd3d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -953,7 +953,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Returns user id from threepid Args: - txn (cursor): + txn: medium: threepid medium e.g. email address: threepid address e.g. me@example.com @@ -1283,8 +1283,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Sets an expiration date to the account with the given user ID. Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be + user_id: User ID to set an expiration date for. + use_delta: If set to False, the expiration date for the user will be now + validity period. If set to True, this expiration date will be a random value in the [now + period - d ; now + period] range, d being a delta equal to 10% of the validity period. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d97f8f60e..4fbaefad73 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2057,7 +2057,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): Args: report_id: ID of reported event in database Returns: - event_report: json list of information from event report + JSON dict of information from an event report or None if the + report does not exist. """ def _get_event_report_txn( @@ -2130,8 +2131,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None Returns: - event_reports: json list of event reports - count: total number of event reports matching the filter criteria + Tuple of: + json list of event reports + total number of event reports matching the filter criteria """ def _get_event_reports_paginate_txn( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ddb25b5cea..698d6f7515 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -185,9 +185,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): - who should be in the user_directory. Args: - progress (dict) - batch_size (int): Maximum number of state events to process - per cycle. + progress + batch_size: Maximum number of state events to process per cycle. Returns: number of events processed. @@ -708,10 +707,10 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns the rooms that a user is in. Args: - user_id(str): Must be a local user + user_id: Must be a local user Returns: - list: user_id + List of room IDs """ rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", diff --git a/synapse/types.py b/synapse/types.py index 773f0438d5..f2d436ddc3 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -143,8 +143,8 @@ class Requester: Requester. Args: - store (DataStore): Used to convert AS ID to AS object - input (dict): A dict produced by `serialize` + store: Used to convert AS ID to AS object + input: A dict produced by `serialize` Returns: Requester diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7f1d41eb3c..d24c4f68c4 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -217,7 +217,8 @@ async def concurrently_execute( limit: Maximum number of conccurent executions. Returns: - Deferred: Resolved when all function invocations have finished. + None, when all function invocations have finished. The return values + from those functions are discarded. """ it = iter(args) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f7c3a6794e..9387632d0d 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -197,7 +197,7 @@ def register_cache( resize_callback: A function which can be called to resize the cache. Returns: - CacheMetric: an object which provides inc_{hits,misses,evictions} methods + an object which provides inc_{hits,misses,evictions} methods """ if resizable: if not resize_callback: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index bcb1cba362..bf7bd351e0 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -153,7 +153,7 @@ class DeferredCache(Generic[KT, VT]): Args: key: callback: Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics + update_metrics: whether to update the cache hit rate metrics Returns: A Deferred which completes with the result. Note that this may later fail diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index fa91479c97..5eaf70c7ab 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -169,10 +169,11 @@ class DictionaryCache(Generic[KT, DKT, DV]): if it is in the cache. Returns: - DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry` - will contain include the keys that are in the cache. If None then - will either return the full dict if in the cache, or the empty - dict (with `full` set to False) if it isn't. + If `dict_keys` is not None then `DictionaryEntry` will contain include + the keys that are in the cache. + + If None then will either return the full dict if in the cache, or the + empty dict (with `full` set to False) if it isn't. """ if dict_keys is None: # The caller wants the full set of dictionary keys for this cache key diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index c6a5d0dfc0..01ad02af67 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -207,7 +207,7 @@ class ExpiringCache(Generic[KT, VT]): items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ new_size = int(self._original_max_size * factor) if new_size != self._max_size: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index aa93109d13..dcf0eac3bf 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -389,11 +389,11 @@ class LruCache(Generic[KT, VT]): cache_name: The name of this cache, for the prometheus metrics. If unset, no metrics will be reported on this cache. - cache_type (type): + cache_type: type of underlying cache to be used. Typically one of dict or TreeCache. - size_callback (func(V) -> int | None): + size_callback: metrics_collection_callback: metrics collection callback. This is called early in the metrics @@ -403,7 +403,7 @@ class LruCache(Generic[KT, VT]): Ignored if cache_name is None. - apply_cache_factor_from_config (bool): If true, `max_size` will be + apply_cache_factor_from_config: If true, `max_size` will be multiplied by a cache factor derived from the homeserver config clock: @@ -796,7 +796,7 @@ class LruCache(Generic[KT, VT]): items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ if not self.apply_cache_factor_from_config: return False diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 9f64fed0d7..2aceb1a47f 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -183,7 +183,7 @@ class FederationRateLimiter: # Handle request ... Args: - host (str): Origin of incoming request. + host: Origin of incoming request. Returns: context manager which returns a deferred. diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 1e9c2faa64..54bc7589fd 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -48,7 +48,7 @@ async def check_3pid_allowed( registration: whether we want to bind the 3PID as part of registering a new user. Returns: - bool: whether the 3PID medium/address is allowed to be added to this HS + whether the 3PID medium/address is allowed to be added to this HS """ if not await hs.get_password_auth_provider().is_3pid_allowed( medium, address, registration diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 177e198e7e..b1ec7f4bd8 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -90,10 +90,10 @@ class WheelTimer(Generic[T]): """Fetch any objects that have timed out Args: - now (ms): Current time in msec + now: Current time in msec Returns: - list: List of objects that have timed out + List of objects that have timed out """ now_key = int(now / self.bucket_size) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index e74f7f5b48..093537adef 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import os.path import subprocess +from typing import List from zope.interface import implementer @@ -70,14 +71,14 @@ subjectAltName = %(sanentries)s """ -def create_test_cert_file(sanlist): +def create_test_cert_file(sanlist: List[bytes]) -> str: """build an x509 certificate file Args: - sanlist: list[bytes]: a list of subjectAltName values for the cert + sanlist: a list of subjectAltName values for the cert Returns: - str: the path to the file + The path to the file """ global cert_file_count csr_filename = "server.csr" diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 96f3880923..dce71f7334 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -143,6 +143,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): self.persist(type="m.room.create", key="", creator=USER_ID) self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + assert event.internal_metadata.stream_ordering is not None self.replicate() @@ -230,6 +231,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): j2 = self.persist( type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) + assert j2.internal_metadata.stream_ordering is not None self.replicate() expected_pos = PersistedEventPosition( @@ -287,6 +289,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): ) ) self.replicate() + assert j2.internal_metadata.stream_ordering is not None event_source = RoomEventSource(self.hs) event_source.store = self.slaved_store @@ -336,10 +339,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): event_id = 0 - def persist(self, backfill=False, **kwargs): + def persist(self, backfill=False, **kwargs) -> FrozenEvent: """ Returns: - synapse.events.FrozenEvent: The event that was persisted. + The event that was persisted. """ event, context = self.build_event(**kwargs) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 13aa5eb51a..96cdf2c45b 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -15,8 +15,9 @@ import logging import os from typing import Optional, Tuple +from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from twisted.web.server import Request @@ -102,7 +103,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): ) # fish the test server back out of the server-side TLS protocol. - http_server = server_tls_protocol.wrappedProtocol + http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment] # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) @@ -238,16 +239,15 @@ def get_connection_factory(): return test_server_connection_factory -def _build_test_server(connection_creator): +def _build_test_server( + connection_creator: IOpenSSLServerConnectionCreator, +) -> TLSMemoryBIOProtocol: """Construct a test server This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + connection_creator: thing to build SSL connections Returns: TLSMemoryBIOProtocol diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index bf403045e9..7cbc40736c 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -350,14 +351,15 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.assertTrue(notice_in_room, "No server notice in room") - def _trigger_notice_and_join(self): + def _trigger_notice_and_join(self) -> Tuple[str, str, str]: """Creates enough active users to hit the MAU limit and trigger a system notice about it, then joins the system notices room with one of the users created. Returns: - user_id (str): The ID of the user that joined the room. - tok (str): The access token of the user that joined the room. - room_id (str): The ID of the room that's been joined. + A tuple of: + user_id: The ID of the user that joined the room. + tok: The access token of the user that joined the room. + room_id: The ID of the room that's been joined. """ user_id = None tok = None diff --git a/tests/unittest.py b/tests/unittest.py index 5116be338e..a120c2976c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -360,13 +360,13 @@ class HomeserverTestCase(TestCase): store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock): """ Make and return a homeserver. Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. + clock: The Clock, associated with the reactor. Returns: A homeserver suitable for testing. @@ -426,9 +426,8 @@ class HomeserverTestCase(TestCase): Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. - homeserver (synapse.server.HomeServer): The HomeServer to test - against. + clock: The Clock, associated with the reactor. + homeserver: The HomeServer to test against. Function to optionally be overridden in subclasses. """ @@ -452,11 +451,10 @@ class HomeserverTestCase(TestCase): given content. Args: - method (bytes/unicode): The HTTP request method ("verb"). - path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. - escaped UTF-8 & spaces and such). - content (bytes or dict): The body of the request. JSON-encoded, if - a dict. + method: The HTTP request method ("verb"). + path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces + and such). content (bytes or dict): The body of the request. + JSON-encoded, if a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake -- cgit 1.5.1 From 618e4ab81b70e37bdb8e9224bd84fcfe4b15bdea Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:25:35 +0000 Subject: Fix an invalid comparison of `UserPresenceState` to `str` (#14393) --- changelog.d/14393.bugfix | 1 + synapse/handlers/presence.py | 2 +- tests/handlers/test_presence.py | 41 +++++++++++++++++++++++++++++++++++------ tests/module_api/test_api.py | 3 +++ tests/replication/_base.py | 7 ++++++- 5 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14393.bugfix (limited to 'tests') diff --git a/changelog.d/14393.bugfix b/changelog.d/14393.bugfix new file mode 100644 index 0000000000..97177bc62f --- /dev/null +++ b/changelog.d/14393.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.58.0 where a user with presence state 'org.matrix.msc3026.busy' would mistakenly be set to 'online' when calling `/sync` or `/events` on a worker process. \ No newline at end of file diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b7bc787636..cf08737d11 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -478,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler): return _NullContextManager() prev_state = await self.current_state_for_user(user_id) - if prev_state != PresenceState.BUSY: + if prev_state.state != PresenceState.BUSY: # We set state here but pass ignore_status_msg = True as we don't want to # cause the status message to be cleared. # Note that this causes last_active_ts to be incremented which is not diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c96dc6caf2..c5981ff965 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -15,6 +15,7 @@ from typing import Optional from unittest.mock import Mock, call +from parameterized import parameterized from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState @@ -37,6 +38,7 @@ from synapse.rest.client import room from synapse.types import UserID, get_domain_from_id from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase class PresenceUpdateTestCase(unittest.HomeserverTestCase): @@ -505,7 +507,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(state, new_state) -class PresenceHandlerTestCase(unittest.HomeserverTestCase): +class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): def prepare(self, reactor, clock, hs): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() @@ -716,20 +718,47 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): # our status message should be the same as it was before self.assertEqual(state.status_msg, status_msg) - def test_set_presence_from_syncing_keeps_busy(self): - """Test that presence set by syncing doesn't affect busy status""" - # while this isn't the default - self.presence_handler._busy_presence_enabled = True + @parameterized.expand([(False,), (True,)]) + @unittest.override_config( + { + "experimental_features": { + "msc3026_enabled": True, + }, + } + ) + def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): + """Test that presence set by syncing doesn't affect busy status + Args: + test_with_workers: If True, check the presence state of the user by calling + /sync against a worker, rather than the main process. + """ user_id = "@test:server" status_msg = "I'm busy!" + # By default, we call /sync against the main process. + worker_to_sync_against = self.hs + if test_with_workers: + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. + worker_to_sync_against = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "presence_writer"} + ) + + # Set presence to BUSY self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg) + # Perform a sync with a presence state other than busy. This should NOT change + # our presence status; we only change from busy if we explicitly set it via + # /presence/*. self.get_success( - self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE) + worker_to_sync_against.get_presence_handler().user_syncing( + user_id, True, PresenceState.ONLINE + ) ) + # Check against the main process that the user's presence did not change. state = self.get_success( self.presence_handler.get_state(UserID.from_string(user_id)) ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 02cef6f876..058ca57e55 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -778,8 +778,11 @@ def _test_sending_local_online_presence_to_local_user( worker process. The test users will still sync with the main process. The purpose of testing with a worker is to check whether a Synapse module running on a worker can inform other workers/ the main process that they should include additional presence when a user next syncs. + If this argument is True, `test_case` MUST be an instance of BaseMultiWorkerStreamTestCase. """ if test_with_workers: + assert isinstance(test_case, BaseMultiWorkerStreamTestCase) + # Create a worker process to make module_api calls against worker_hs = test_case.make_worker_hs( "synapse.app.generic_worker", {"worker_name": "presence_writer"} diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 121f3d8d65..3029a16dda 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -542,8 +542,13 @@ class FakeRedisPubSubProtocol(Protocol): self.send("OK") elif command == b"GET": self.send(None) + + # Connection keep-alives. + elif command == b"PING": + self.send("PONG") + else: - raise Exception("Unknown command") + raise Exception(f"Unknown command: {command}") def send(self, msg): """Send a message back to the client.""" -- cgit 1.5.1 From 115f0eb2334b13665e5c112bd87f95ea393c9047 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 16 Nov 2022 22:16:46 +0000 Subject: Reintroduce #14376, with bugfix for monoliths (#14468) * Add tests for StreamIdGenerator * Drive-by: annotate all defs * Revert "Revert "Remove slaved id tracker (#14376)" (#14463)" This reverts commit d63814fd736fed5d3d45ff3af5e6d3bfae50c439, which in turn reverted 36097e88c4da51fce6556a58c49bd675f4cf20ab. This restores the latter. * Fix StreamIdGenerator not handling unpersisted IDs Spotted by @erikjohnston. Closes #14456. * Changelog Co-authored-by: Nick Mills-Barrett Co-authored-by: Erik Johnston --- changelog.d/14376.misc | 1 + changelog.d/14468.misc | 1 + mypy.ini | 3 + synapse/replication/slave/__init__.py | 13 -- synapse/replication/slave/storage/__init__.py | 13 -- .../slave/storage/_slaved_id_tracker.py | 50 ------- synapse/storage/databases/main/account_data.py | 30 ++-- synapse/storage/databases/main/devices.py | 36 ++--- synapse/storage/databases/main/events_worker.py | 35 ++--- synapse/storage/databases/main/push_rule.py | 17 +-- synapse/storage/databases/main/pusher.py | 24 ++- synapse/storage/databases/main/receipts.py | 18 +-- synapse/storage/util/id_generators.py | 13 +- tests/storage/test_id_generators.py | 162 +++++++++++++++++++-- 14 files changed, 230 insertions(+), 186 deletions(-) create mode 100644 changelog.d/14376.misc create mode 100644 changelog.d/14468.misc delete mode 100644 synapse/replication/slave/__init__.py delete mode 100644 synapse/replication/slave/storage/__init__.py delete mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py (limited to 'tests') diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14376.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/changelog.d/14468.misc b/changelog.d/14468.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14468.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/mypy.ini b/mypy.ini index 8f1141a239..53512b2584 100644 --- a/mypy.ini +++ b/mypy.ini @@ -117,6 +117,9 @@ disallow_untyped_defs = True [mypy-tests.state.test_profile] disallow_untyped_defs = True +[mypy-tests.storage.test_id_generators] +disallow_untyped_defs = True + [mypy-tests.storage.test_profile] disallow_untyped_defs = True diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py deleted file mode 100644 index 8f3f953ed4..0000000000 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple - -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id - - -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - table: str, - column: str, - extra_tables: Optional[List[Tuple[str, str]]] = None, - step: int = 1, - ): - self.step = step - self._current = _load_current_id(db_conn, table, column, step) - if extra_tables: - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) - - def advance(self, instance_name: Optional[str], new_id: int) -> None: - self._current = (max if self.step > 0 else min)(self._current, new_id) - - def get_current_token(self) -> int: - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c38b8a9e5a..282687ebce 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e114c733d1..57230df5ae 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,6 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -86,28 +85,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - else: - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8a104f7e93..01e935edef 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -213,26 +212,20 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8ae10f6127..12ad44dbb3 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,6 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -111,14 +110,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 4a01562d45..fee37b9ce4 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -59,20 +58,15 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) - else: - self._pushers_id_gen = SlavedIdTracker( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index fbf27497ec..a580e4bdda 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import EduTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2dfe4c0b66..0d7108f01b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, + is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) + self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") + # Advance should never be called on a writer instance, only over replication + if self._is_writer: + raise Exception("Replication is not supported by writer StreamIdGenerator") + + self._current = (max if self._step > 0 else min)(self._current, new_id) def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + if not self._is_writer: + return self._current + with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 2d8d1f860f..d6a2b8d274 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -16,15 +16,157 @@ from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import IncorrectDatabaseSetup -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS +class StreamIdGeneratorTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + data TEXT + ); + """ + ) + txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") + + def _create_id_generator(self) -> StreamIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: + return StreamIdGenerator( + db_conn=conn, + table="foobar", + column="stream_id", + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def test_initial_value(self) -> None: + """Check that we read the current token from the DB.""" + id_gen = self._create_id_generator() + self.assertEqual(id_gen.get_current_token(), 123) + + def test_single_gen_next(self) -> None: + """Check that we correctly increment the current token from the DB.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + async with id_gen.get_next() as next_id: + # We haven't persisted `next_id` yet; current token is still 123 + self.assertEqual(id_gen.get_current_token(), 123) + # But we did learn what the next value is + self.assertEqual(next_id, 124) + + # Once the context manager closes we assume that the `next_id` has been + # written to the DB. + self.assertEqual(id_gen.get_current_token(), 124) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts(self) -> None: + """Check that we handle overlapping calls to gen_next sensibly.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist each in turn. + await ctx1.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 124) + await ctx2.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 125) + await ctx3.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts_closed_in_different_order(self) -> None: + """Check that we handle overlapping calls to gen_next, even when their IDs + created and persisted in different orders.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist them in a different order, starting with 126 from ctx3. + await ctx3.__aexit__(None, None, None) + # We haven't persisted 124 from ctx1 yet---current token is still 123. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now persist 124 from ctx1. + await ctx1.__aexit__(None, None, None) + # Current token is then 124, waiting for 125 to be persisted. + self.assertEqual(id_gen.get_current_token(), 124) + + # Finally persist 125 from ctx2. + await ctx2.__aexit__(None, None, None) + # Current token is then 126 (skipping over 125). + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_gen_next_while_still_waiting_for_persistence(self) -> None: + """Check that we handle overlapping calls to gen_next.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request two new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + + # Persist ctx2 first. + await ctx2.__aexit__(None, None, None) + # Still waiting on ctx1's ID to be persisted. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now request a third stream ID. It should be 126 (the smallest ID that + # we've not yet handed out.) + self.assertEqual(await ctx3.__aenter__(), 126) + + self.get_success(test_gen_next()) + + class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" @@ -48,9 +190,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -446,7 +588,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("master", 3) # Now we add a row *without* updating the stream ID - def _insert(txn): + def _insert(txn: Cursor) -> None: txn.execute("INSERT INTO foobar VALUES (26, 'master')") self.get_success(self.db_pool.runInteraction("_insert", _insert)) @@ -481,9 +623,9 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -617,9 +759,9 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -641,7 +783,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name: str, number: int, update_stream_table: bool = True, - ): + ) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ -- cgit 1.5.1 From e1b15f25f3ad4b45b381544ca6b3cd2caf43d25d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 18 Nov 2022 19:56:42 +0000 Subject: Fix /key/v2/server calls with URL-unsafe key IDs (#14490) Co-authored-by: Patrick Cloke --- changelog.d/14490.misc | 1 + synapse/crypto/keyring.py | 2 +- tests/crypto/test_keyring.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14490.misc (limited to 'tests') diff --git a/changelog.d/14490.misc b/changelog.d/14490.misc new file mode 100644 index 0000000000..c0a4daa885 --- /dev/null +++ b/changelog.d/14490.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 0.9 where it would fail to fetch server keys whose IDs contain a forward slash. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index dd9b8089ec..ed15f88350 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -857,7 +857,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): response = await self.client.get_json( destination=server_name, path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), + + urllib.parse.quote(requested_key_id, safe=""), ignore_backoff=True, # we only give the remote server 10s to respond. It should be an # easy request to handle, so if it doesn't reply within 10s, it's diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 820a1a54e2..63628aa6b0 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -469,6 +469,18 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) + def test_keyid_containing_forward_slash(self) -> None: + """We should url-encode any url unsafe chars in key ids. + + Detects https://github.com/matrix-org/synapse/issues/14488. + """ + fetcher = ServerKeyFetcher(self.hs) + self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0)) + + self.http_client.get_json.assert_called_once() + args, kwargs = self.http_client.get_json.call_args + self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato") + class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): -- cgit 1.5.1 From 1526ff389f02d14d0df729bd6ea35836e758c449 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Mon, 21 Nov 2022 16:46:14 +0100 Subject: Faster joins: filter out non local events when a room doesn't have its full state (#14404) Signed-off-by: Mathieu Velten --- changelog.d/14404.misc | 1 + synapse/federation/sender/per_destination_queue.py | 1 + synapse/handlers/federation.py | 15 +++++++---- synapse/visibility.py | 29 +++++++++++++++++++--- tests/test_visibility.py | 10 ++++---- 5 files changed, 43 insertions(+), 13 deletions(-) create mode 100644 changelog.d/14404.misc (limited to 'tests') diff --git a/changelog.d/14404.misc b/changelog.d/14404.misc new file mode 100644 index 0000000000..b9ab525f2b --- /dev/null +++ b/changelog.d/14404.misc @@ -0,0 +1 @@ +Faster joins: filter out non local events when a room doesn't have its full state. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 084c45a95c..3ae5e8634c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -505,6 +505,7 @@ class PerDestinationQueue: new_pdus = await filter_events_for_server( self._storage_controllers, self._destination, + self._server_name, new_pdus, redact=False, ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 188f0956ef..d92582fd5c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -379,6 +379,7 @@ class FederationHandler: filtered_extremities = await filter_events_for_server( self._storage_controllers, self.server_name, + self.server_name, events_to_check, redact=False, check_history_visibility_only=True, @@ -1231,7 +1232,9 @@ class FederationHandler: async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: - await self._event_auth_handler.assert_host_in_room(room_id, origin) + # We allow partially joined rooms since in this case we are filtering out + # non-local events in `filter_events_for_server`. + await self._event_auth_handler.assert_host_in_room(room_id, origin, True) # Synapse asks for 100 events per backfill request. Do not allow more. limit = min(limit, 100) @@ -1252,7 +1255,7 @@ class FederationHandler: ) events = await filter_events_for_server( - self._storage_controllers, origin, events + self._storage_controllers, origin, self.server_name, events ) return events @@ -1283,7 +1286,7 @@ class FederationHandler: await self._event_auth_handler.assert_host_in_room(event.room_id, origin) events = await filter_events_for_server( - self._storage_controllers, origin, [event] + self._storage_controllers, origin, self.server_name, [event] ) event = events[0] return event @@ -1296,7 +1299,9 @@ class FederationHandler: latest_events: List[str], limit: int, ) -> List[EventBase]: - await self._event_auth_handler.assert_host_in_room(room_id, origin) + # We allow partially joined rooms since in this case we are filtering out + # non-local events in `filter_events_for_server`. + await self._event_auth_handler.assert_host_in_room(room_id, origin, True) # Only allow up to 20 events to be retrieved per request. limit = min(limit, 20) @@ -1309,7 +1314,7 @@ class FederationHandler: ) missing_events = await filter_events_for_server( - self._storage_controllers, origin, missing_events + self._storage_controllers, origin, self.server_name, missing_events ) return missing_events diff --git a/synapse/visibility.py b/synapse/visibility.py index 40a9c5b53f..b443857571 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -563,7 +563,8 @@ def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str: async def filter_events_for_server( storage: StorageControllers, - server_name: str, + target_server_name: str, + local_server_name: str, events: List[EventBase], redact: bool = True, check_history_visibility_only: bool = False, @@ -603,7 +604,7 @@ async def filter_events_for_server( # if the server is either in the room or has been invited # into the room. for ev in memberships.values(): - assert get_domain_from_id(ev.state_key) == server_name + assert get_domain_from_id(ev.state_key) == target_server_name memtype = ev.membership if memtype == Membership.JOIN: @@ -622,6 +623,24 @@ async def filter_events_for_server( # to no users having been erased. erased_senders = {} + # Filter out non-local events when we are in the middle of a partial join, since our servers + # list can be out of date and we could leak events to servers not in the room anymore. + # This can also be true for local events but we consider it to be an acceptable risk. + + # We do this check as a first step and before retrieving membership events because + # otherwise a room could be fully joined after we retrieve those, which would then bypass + # this check but would base the filtering on an outdated view of the membership events. + + partial_state_invisible_events = set() + if not check_history_visibility_only: + for e in events: + sender_domain = get_domain_from_id(e.sender) + if ( + sender_domain != local_server_name + and await storage.main.is_partial_state_room(e.room_id) + ): + partial_state_invisible_events.add(e) + # Let's check to see if all the events have a history visibility # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). @@ -636,7 +655,7 @@ async def filter_events_for_server( if event_to_history_vis[e.event_id] not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) ], - server_name, + target_server_name, ) to_return = [] @@ -645,6 +664,10 @@ async def filter_events_for_server( visible = check_event_is_visible( event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) ) + + if e in partial_state_invisible_events: + visible = False + if visible and not erased: to_return.append(e) elif redact: diff --git a/tests/test_visibility.py b/tests/test_visibility.py index c385b2f8d4..d0b9ad5454 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -61,7 +61,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", "hs", events_to_filter ) ) @@ -83,7 +83,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_server( - self._storage_controllers, "remote_hs", [outlier] + self._storage_controllers, "remote_hs", "hs", [outlier] ) ), [outlier], @@ -94,7 +94,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "remote_hs", [outlier, evt] + self._storage_controllers, "remote_hs", "local_hs", [outlier, evt] ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") @@ -106,7 +106,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # be redacted) filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "other_server", [outlier, evt] + self._storage_controllers, "other_server", "local_hs", [outlier, evt] ) ) self.assertEqual(filtered[0], outlier) @@ -141,7 +141,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... and the filtering happens. filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", "local_hs", events_to_filter ) ) -- cgit 1.5.1 From 1799a54a545618782840a60950ef4b64da9ee24d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 07:26:11 -0500 Subject: Batch fetch bundled annotations (#14491) Avoid an n+1 query problem and fetch the bundled aggregations for m.annotation relations in a single query instead of a query per event. This applies similar logic for as was previously done for edits in 8b309adb436c162510ed1402f33b8741d71fc058 (#11660) and threads in b65acead428653b988351ae8d7b22127a22039cd (#11752). --- changelog.d/14491.feature | 1 + synapse/handlers/relations.py | 197 ++++++++++++++++------------ synapse/storage/databases/main/relations.py | 139 ++++++++++++-------- synapse/util/caches/descriptors.py | 2 +- tests/rest/client/test_relations.py | 4 +- 5 files changed, 202 insertions(+), 141 deletions(-) create mode 100644 changelog.d/14491.feature (limited to 'tests') diff --git a/changelog.d/14491.feature b/changelog.d/14491.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14491.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e71dda970..ca94239f61 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,7 +13,16 @@ # limitations under the License. import enum import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, +) import attr @@ -259,48 +268,64 @@ class RelationsHandler: e.msg, ) - async def get_annotations_for_event( - self, - event_id: str, - room_id: str, - limit: int = 5, - ignored_users: FrozenSet[str] = frozenset(), - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + async def get_annotations_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[JsonDict]]: + """Get a list of annotations to the given events, grouped by event type and aggregation key, sorted by count. - This is used e.g. to get the what and how many reactions have happend + This is used e.g. to get the what and how many reactions have happened on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. ignored_users: The users ignored by the requesting user. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ # Get the base results for all users. - full_results = await self._main_store.get_aggregation_groups_for_event( - event_id, room_id, limit + full_results = await self._main_store.get_aggregation_groups_for_events( + event_ids ) + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in full_results.items() + if results + } + # Then subtract off the results for any ignored users. ignored_results = await self._main_store.get_aggregation_groups_for_users( - event_id, room_id, limit, ignored_users + [event_id for event_id, results in full_results.items() if results], + ignored_users, ) - filtered_results = [] - for result in full_results: - key = (result["type"], result["key"]) - if key in ignored_results: - result = result.copy() - result["count"] -= ignored_results[key] - if result["count"] <= 0: - continue - filtered_results.append(result) + filtered_results = {} + for event_id, results in full_results.items(): + # If no annotations, skip. + if not results: + continue + + # If there are not ignored results for this event, copy verbatim. + if event_id not in ignored_results: + filtered_results[event_id] = results + continue + + # Otherwise, subtract out the ignored results. + event_ignored_results = ignored_results[event_id] + for result in results: + key = (result["type"], result["key"]) + if key in event_ignored_results: + # Ensure to not modify the cache. + result = result.copy() + result["count"] -= event_ignored_results[key] + if result["count"] <= 0: + continue + filtered_results.setdefault(event_id, []).append(result) return filtered_results @@ -366,59 +391,62 @@ class RelationsHandler: results = {} for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event = summary - - # Subtract off the count of any ignored users. - for ignored_user in ignored_users: - thread_count -= ignored_results.get((event_id, ignored_user), 0) - - # This is gnarly, but if the latest event is from an ignored user, - # attempt to find one that isn't from an ignored user. - if latest_thread_event.sender in ignored_users: - room_id = latest_thread_event.room_id - - # If the root event is not found, something went wrong, do - # not include a summary of the thread. - event = await self._event_handler.get_event(user, room_id, event_id) - if event is None: - continue + # If no thread, skip. + if not summary: + continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, - ) + thread_count, latest_thread_event = summary - # If all found events are from ignored users, do not include - # a summary of the thread. - if not potential_events: - continue + # Subtract off the count of any ignored users. + for ignored_user in ignored_users: + thread_count -= ignored_results.get((event_id, ignored_user), 0) - # The *last* event returned is the one that is cared about. - event = await self._event_handler.get_event( - user, room_id, potential_events[-1].event_id - ) - # It is unexpected that the event will not exist. - if event is None: - logger.warning( - "Unable to fetch latest event in a thread with event ID: %s", - potential_events[-1].event_id, - ) - continue - latest_thread_event = event - - results[event_id] = _ThreadAggregation( - latest_event=latest_thread_event, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=events_by_id[event_id].sender == user_id - or participated[event_id], + # This is gnarly, but if the latest event is from an ignored user, + # attempt to find one that isn't from an ignored user. + if latest_thread_event.sender in ignored_users: + room_id = latest_thread_event.room_id + + # If the root event is not found, something went wrong, do + # not include a summary of the thread. + event = await self._event_handler.get_event(user, room_id, event_id) + if event is None: + continue + + potential_events, _ = await self.get_relations_for_event( + event_id, + event, + room_id, + RelationTypes.THREAD, + ignored_users, ) + # If all found events are from ignored users, do not include + # a summary of the thread. + if not potential_events: + continue + + # The *last* event returned is the one that is cared about. + event = await self._event_handler.get_event( + user, room_id, potential_events[-1].event_id + ) + # It is unexpected that the event will not exist. + if event is None: + logger.warning( + "Unable to fetch latest event in a thread with event ID: %s", + potential_events[-1].event_id, + ) + continue + latest_thread_event = event + + results[event_id] = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=events_by_id[event_id].sender == user_id + or participated[event_id], + ) + return results @trace @@ -496,17 +524,18 @@ class RelationsHandler: # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any annotations (ie, reactions) to bundle with this event. - annotations = await self.get_annotations_for_event( - event.event_id, event.room_id, ignored_users=ignored_users - ) + # Fetch any annotations (ie, reactions) to bundle with this event. + annotations_by_event_id = await self.get_annotations_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, annotations in annotations_by_event_id.items(): if annotations: - results.setdefault( - event.event_id, BundledAggregations() - ).annotations = {"chunk": annotations} + results.setdefault(event_id, BundledAggregations()).annotations = { + "chunk": annotations + } + # Fetch other relations per event. + for event in events_by_id.values(): # Fetch any references to bundle with this event. references, next_token = await self.get_relations_for_event( event.event_id, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ca431002c8..f96a16956a 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -20,6 +20,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -394,106 +395,136 @@ class RelationsWorkerStore(SQLBaseStore): ) return result is not None - @cached(tree=True) - async def get_aggregation_groups_for_event( - self, event_id: str, room_id: str, limit: int = 5 - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + @cached() + async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" + ) + async def get_aggregation_groups_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[JsonDict]]]: + """Get a list of annotations on the given events, grouped by event type and aggregation key, sorted by count. This is used e.g. to get the what and how many reactions have happend on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ + # The number of entries to return per event ID. + limit = 5 - args = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - limit, - ] + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.ANNOTATION) - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + sql = f""" + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE + {clause} + AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ - def _get_aggregation_groups_for_event_txn( + def _get_aggregation_groups_for_events_txn( txn: LoggingTransaction, - ) -> List[JsonDict]: + ) -> Mapping[str, List[JsonDict]]: txn.execute(sql, args) - return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] + result: Dict[str, List[JsonDict]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + event_results = result.setdefault(event_id, []) + + # Limit the number of results per event ID. + if len(event_results) == limit: + continue + + event_results.append({"type": type, "key": key, "count": count}) + + return result return await self.db_pool.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn ) async def get_aggregation_groups_for_users( - self, - event_id: str, - room_id: str, - limit: int, - users: FrozenSet[str] = frozenset(), - ) -> Dict[Tuple[str, str], int]: + self, event_ids: Collection[str], users: FrozenSet[str] + ) -> Dict[str, Dict[Tuple[str, str], int]]: """Fetch the partial aggregations for an event for specific users. This is used, in conjunction with get_aggregation_groups_for_event, to remove information from the results for ignored users. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. users: The users to fetch information for. Returns: - A map of (event type, aggregation key) to a count of users. + A map of event ID to a map of (event type, aggregation key) to a + count of users. """ if not users: return {} - args: List[Union[str, int]] = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - ] + events_sql, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "sender", users + self.database_engine, "annotation.sender", users ) args.extend(users_args) + args.append(RelationTypes.ANNOTATION) sql = f""" - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE {events_sql} AND {users_sql} AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ def _get_aggregation_groups_for_users_txn( txn: LoggingTransaction, - ) -> Dict[Tuple[str, str], int]: - txn.execute(sql, args + [limit]) + ) -> Dict[str, Dict[Tuple[str, str], int]]: + txn.execute(sql, args) - return {(row[0], row[1]): row[2] for row in txn} + result: Dict[str, Dict[Tuple[str, str], int]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + result.setdefault(event_id, {})[(type, key)] = count + + return result return await self.db_pool.runInteraction( "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 75428d19ba..72227359b9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -503,7 +503,7 @@ def cachedList( is specified as a list that is iterated through to lookup keys in the original cache. A new tuple consisting of the (deduplicated) keys that weren't in the cache gets passed to the original function, which is expected to results - in a map of key to value for each passed value. THe new results are stored in the + in a map of key to value for each passed value. The new results are stored in the original cache. Note that any missing values are cached as None. Args: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index e3d801f7a8..2d2b683548 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) def test_nested_thread(self) -> None: """ -- cgit 1.5.1 From 6d7523ef1484ec56f4a6dffdd2ea3d8736b4cc98 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 09:41:09 -0500 Subject: Batch fetch bundled references (#14508) Avoid an n+1 query problem and fetch the bundled aggregations for m.reference relations in a single query instead of a query per event. This applies similar logic for as was previously done for edits in 8b309adb436c162510ed1402f33b8741d71fc058 (#11660; threads in b65acead428653b988351ae8d7b22127a22039cd (#11752); and annotations in 1799a54a545618782840a60950ef4b64da9ee24d (#14491). --- changelog.d/14508.feature | 1 + synapse/handlers/relations.py | 128 +++++++++++++--------------- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events.py | 4 + synapse/storage/databases/main/relations.py | 74 ++++++++++++++-- tests/rest/client/test_relations.py | 4 +- 6 files changed, 133 insertions(+), 79 deletions(-) create mode 100644 changelog.d/14508.feature (limited to 'tests') diff --git a/changelog.d/14508.feature b/changelog.d/14508.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14508.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ca94239f61..8414be5879 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,16 +13,7 @@ # limitations under the License. import enum import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - FrozenSet, - Iterable, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional import attr @@ -32,7 +23,7 @@ from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamToken, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -181,40 +172,6 @@ class RelationsHandler: return return_value - async def get_relations_for_event( - self, - event_id: str, - event: EventBase, - room_id: str, - relation_type: str, - ignored_users: FrozenSet[str] = frozenset(), - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: - """Get a list of events which relate to an event, ordered by topological ordering. - - Args: - event_id: Fetch events that relate to this event ID. - event: The matching EventBase to event_id. - room_id: The room the event belongs to. - relation_type: The type of relation. - ignored_users: The users ignored by the requesting user. - - Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. - """ - - # Call the underlying storage method, which is cached. - related_events, next_token = await self._main_store.get_relations_for_event( - event_id, event, room_id, relation_type, direction="f" - ) - - # Filter out ignored users and convert to the expected format. - related_events = [ - event for event in related_events if event.sender not in ignored_users - ] - - return related_events, next_token - async def redact_events_related_to( self, requester: Requester, @@ -329,6 +286,46 @@ class RelationsHandler: return filtered_results + async def get_references_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[_RelatedEvent]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to this event ID. + ignored_users: The users ignored by the requesting user. + + Returns: + A map of event IDs to a list related events. + """ + + related_events = await self._main_store.get_references_for_events(event_ids) + + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in related_events.items() + if results + } + + # Filter out ignored users. + results = {} + for event_id, events in related_events.items(): + # If no references, skip. + if not events: + continue + + # Filter ignored users out. + events = [event for event in events if event.sender not in ignored_users] + # If there are no events left, skip this event. + if not events: + continue + + results[event_id] = events + + return results + async def _get_threads_for_events( self, events_by_id: Dict[str, EventBase], @@ -412,14 +409,18 @@ class RelationsHandler: if event is None: continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, + # Attempt to find another event to use as the latest event. + potential_events, _ = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.THREAD, direction="f" ) + # Filter out ignored users. + potential_events = [ + event + for event in potential_events + if event.sender not in ignored_users + ] + # If all found events are from ignored users, do not include # a summary of the thread. if not potential_events: @@ -534,27 +535,16 @@ class RelationsHandler: "chunk": annotations } - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any references to bundle with this event. - references, next_token = await self.get_relations_for_event( - event.event_id, - event, - event.room_id, - RelationTypes.REFERENCE, - ignored_users=ignored_users, - ) + # Fetch any references to bundle with this event. + references_by_event_id = await self.get_references_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, references in references_by_event_id.items(): if references: - aggregations = results.setdefault(event.event_id, BundledAggregations()) - aggregations.references = { + results.setdefault(event_id, BundledAggregations()).references = { "chunk": [{"event_id": ev.event_id} for ev in references] } - if next_token: - aggregations.references["next_batch"] = await next_token.to_string( - self._main_store - ) - # Fetch any edits (but not for redacted events). # # Note that there is no use in limiting edits by ignored users since the @@ -600,7 +590,7 @@ class RelationsHandler: room_id, requester, allow_departed_users=True ) - # Note that ignored users are not passed into get_relations_for_event + # Note that ignored users are not passed into get_threads # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). thread_roots, next_batch = await self._main_store.get_threads( diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ddb7397714..a58668a380 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache( "get_aggregation_groups_for_event", (relates_to,) ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d68f127f9b..0f097a2927 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2049,6 +2049,10 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REFERENCE: + self.store._invalidate_cache_and_stream( + txn, self.store.get_references_for_event, (redacted_relates_to,) + ) if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index f96a16956a..aea96e9d24 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -82,8 +82,6 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str - topological_ordering: Optional[int] - stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -246,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore): txn.execute(sql, where_args + [limit + 1]) events = [] - for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: + topo_orderings: List[int] = [] + stream_orderings: List[int] = [] + for event_id, relation_type, sender, topo_ordering, stream_ordering in cast( + List[Tuple[str, str, str, int, int]], txn + ): # Do not include edits for redacted events as they leak event # content. if not is_redacted or relation_type != RelationTypes.REPLACE: - events.append( - _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) - ) + events.append(_RelatedEvent(event_id, sender)) + topo_orderings.append(topo_ordering) + stream_orderings.append(stream_ordering) # If there are more events, generate the next pagination key from the # last event returned. @@ -261,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore): # Instead of using the last row (which tells us there is more # data), use the last row to be returned. events = events[:limit] + topo_orderings = topo_orderings[:limit] + stream_orderings = stream_orderings[:limit] - topo = events[-1].topological_ordering - token = events[-1].stream_ordering + topo = topo_orderings[-1] + token = stream_orderings[-1] if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. @@ -530,6 +534,60 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn ) + @cached() + async def get_references_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") + async def get_references_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to these event IDs. + + Returns: + A map of event IDs to a list of related event IDs (and their senders). + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REFERENCE) + + sql = f""" + SELECT relates_to_id, ref.event_id, ref.sender + FROM events AS ref + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = ref.room_id + WHERE + {clause} + AND relation_type = ? + ORDER BY ref.topological_ordering, ref.stream_ordering + """ + + def _get_references_for_events_txn( + txn: LoggingTransaction, + ) -> Mapping[str, List[_RelatedEvent]]: + txn.execute(sql, args) + + result: Dict[str, List[_RelatedEvent]] = {} + for relates_to_id, event_id, sender in cast( + List[Tuple[str, str, str]], txn + ): + result.setdefault(relates_to_id, []).append( + _RelatedEvent(event_id, sender) + ) + + return result + + return await self.db_pool.runInteraction( + "_get_references_for_events_txn", _get_references_for_events_txn + ) + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 2d2b683548..b86f341ff5 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) def test_nested_thread(self) -> None: """ -- cgit 1.5.1 From 9cae44f49e6bf4f6b8a20ab11a65da417bb1565f Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 22 Nov 2022 16:46:52 +0000 Subject: Track unconverted device list outbound pokes using a position instead (#14516) When a local device list change is added to `device_lists_changes_in_room`, the `converted_to_destinations` flag is set to `FALSE` and the `_handle_new_device_update_async` background process is started. This background process looks for unconverted rows in `device_lists_changes_in_room`, copies them to `device_lists_outbound_pokes` and updates the flag. To update the `converted_to_destinations` flag, the database performs a `DELETE` and `INSERT` internally, which fragments the table. To avoid this, track unconverted rows using a `(stream ID, room ID)` position instead of the flag. From now on, the `converted_to_destinations` column indicates rows that need converting to outbound pokes, but does not indicate whether the conversion has already taken place. Closes #14037. Signed-off-by: Sean Quah --- changelog.d/14516.misc | 1 + synapse/handlers/device.py | 30 +++++- synapse/storage/database.py | 13 +-- synapse/storage/databases/main/devices.py | 107 +++++++++++++-------- .../73/12refactor_device_list_outbound_pokes.sql | 53 ++++++++++ tests/storage/test_devices.py | 3 +- 6 files changed, 158 insertions(+), 49 deletions(-) create mode 100644 changelog.d/14516.misc create mode 100644 synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql (limited to 'tests') diff --git a/changelog.d/14516.misc b/changelog.d/14516.misc new file mode 100644 index 0000000000..51666c6ffc --- /dev/null +++ b/changelog.d/14516.misc @@ -0,0 +1 @@ +Refactor conversion of device list changes in room to outbound pokes to track unconverted rows using a `(stream ID, room ID)` position instead of updating the `converted_to_destinations` flag on every row. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c597639a7f..da3ddafeae 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -682,13 +682,33 @@ class DeviceHandler(DeviceWorkerHandler): hosts_already_sent_to: Set[str] = set() try: + stream_id, room_id = await self.store.get_device_change_last_converted_pos() + while True: self._handle_new_device_update_new_data = False - rows = await self.store.get_uncoverted_outbound_room_pokes() + max_stream_id = self.store.get_device_stream_token() + rows = await self.store.get_uncoverted_outbound_room_pokes( + stream_id, room_id + ) if not rows: # If the DB returned nothing then there is nothing left to # do, *unless* a new device list update happened during the # DB query. + + # Advance `(stream_id, room_id)`. + # `max_stream_id` comes from *before* the query for unconverted + # rows, which means that any unconverted rows must have a larger + # stream ID. + if max_stream_id > stream_id: + stream_id, room_id = max_stream_id, "" + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + else: + assert max_stream_id == stream_id + # Avoid moving `room_id` backwards. + pass + if self._handle_new_device_update_new_data: continue else: @@ -718,7 +738,6 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=stream_id, hosts=hosts, context=opentracing_context, ) @@ -752,6 +771,12 @@ class DeviceHandler(DeviceWorkerHandler): hosts_already_sent_to.update(hosts) current_stream_id = stream_id + # Advance `(stream_id, room_id)`. + _, _, room_id, stream_id, _ = rows[-1] + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + finally: self._handle_new_device_update_is_processing = False @@ -834,7 +859,6 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=None, hosts=potentially_changed_hosts, context=None, ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0dc44b246c..a14b13aec8 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2075,13 +2075,14 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, ) -> Optional[Dict[str, Any]]: - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) + select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + + if keyvalues: + select_sql += " WHERE %s" % (" AND ".join("%s = ?" % k for k in keyvalues),) + txn.execute(select_sql, list(keyvalues.values())) + else: + txn.execute(select_sql) - txn.execute(select_sql, list(keyvalues.values())) row = txn.fetchone() if not row: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 57230df5ae..37629115ab 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -2008,27 +2008,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def get_uncoverted_outbound_room_pokes( - self, limit: int = 10 + self, start_stream_id: int, start_room_id: str, limit: int = 10 ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: """Get device list changes by room that have not yet been handled and written to `device_lists_outbound_pokes`. + Args: + start_stream_id: Together with `start_room_id`, indicates the position after + which to return device list changes. + start_room_id: Together with `start_stream_id`, indicates the position after + which to return device list changes. + limit: The maximum number of device list changes to return. + Returns: - A list of user ID, device ID, room ID, stream ID and optional opentracing context. + A list of user ID, device ID, room ID, stream ID and optional opentracing + context, in order of ascending (stream ID, room ID). """ sql = """ SELECT user_id, device_id, room_id, stream_id, opentracing_context FROM device_lists_changes_in_room - WHERE NOT converted_to_destinations - ORDER BY stream_id + WHERE + (stream_id, room_id) > (?, ?) AND + stream_id <= ? AND + NOT converted_to_destinations + ORDER BY stream_id ASC, room_id ASC LIMIT ? """ def get_uncoverted_outbound_room_pokes_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: - txn.execute(sql, (limit,)) + txn.execute( + sql, + ( + start_stream_id, + start_room_id, + # Avoid returning rows if there may be uncommitted device list + # changes with smaller stream IDs. + self._device_list_id_gen.get_current_token(), + limit, + ), + ) return [ ( @@ -2050,49 +2071,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, room_id: str, - stream_id: Optional[int], hosts: Collection[str], context: Optional[Dict[str, str]], ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. - - Marks the associated row in `device_lists_changes_in_room` as handled, - if `stream_id` is provided. """ + if not hosts: + return def add_device_list_outbound_pokes_txn( txn: LoggingTransaction, stream_ids: List[int] ) -> None: - if hosts: - self._add_device_outbound_poke_to_stream_txn( - txn, - user_id=user_id, - device_id=device_id, - hosts=hosts, - stream_ids=stream_ids, - context=context, - ) - - if stream_id: - self.db_pool.simple_update_txn( - txn, - table="device_lists_changes_in_room", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "room_id": room_id, - }, - updatevalues={"converted_to_destinations": True}, - ) - - if not hosts: - # If there are no hosts then we don't try and generate stream IDs. - return await self.db_pool.runInteraction( - "add_device_list_outbound_pokes", - add_device_list_outbound_pokes_txn, - [], + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_id=device_id, + hosts=hosts, + stream_ids=stream_ids, + context=context, ) async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: @@ -2156,3 +2153,37 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "get_pending_remote_device_list_updates_for_room", get_pending_remote_device_list_updates_for_room_txn, ) + + async def get_device_change_last_converted_pos(self) -> Tuple[int, str]: + """ + Get the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + + Rows with a strictly greater position where `converted_to_destinations` is + `FALSE` have not been converted. + """ + + row = await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ) + return row["stream_id"], row["room_id"] + + async def set_device_change_last_converted_pos( + self, + stream_id: int, + room_id: str, + ) -> None: + """ + Set the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + """ + + await self.db_pool.simple_update_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + updatevalues={"stream_id": stream_id, "room_id": room_id}, + desc="set_device_change_last_converted_pos", + ) diff --git a/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql new file mode 100644 index 0000000000..93d7fcb79b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql @@ -0,0 +1,53 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Prior to this schema delta, we tracked the set of unconverted rows in +-- `device_lists_changes_in_room` using the `converted_to_destinations` flag. When rows +-- were converted to `device_lists_outbound_pokes`, the `converted_to_destinations` flag +-- would be set. +-- +-- After this schema delta, the `converted_to_destinations` is still populated like +-- before, but the set of unconverted rows is determined by the `stream_id` in the new +-- `device_lists_changes_converted_stream_position` table. +-- +-- If rolled back, Synapse will re-send all device list changes that happened since the +-- schema delta. + +CREATE TABLE IF NOT EXISTS device_lists_changes_converted_stream_position( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + -- The (stream id, room id) of the last row in `device_lists_changes_in_room` that + -- has been converted to `device_lists_outbound_pokes`. Rows with a strictly larger + -- (stream id, room id) where `converted_to_destinations` is `FALSE` have not been + -- converted. + stream_id BIGINT NOT NULL, + -- `room_id` may be an empty string, which compares less than all valid room IDs. + room_id TEXT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO device_lists_changes_converted_stream_position (stream_id, room_id) VALUES ( + ( + SELECT COALESCE( + -- The last converted stream id is the smallest unconverted stream id minus + -- one. + MIN(stream_id) - 1, + -- If there is no unconverted stream id, the last converted stream id is the + -- largest stream id. + -- Otherwise, pick 1, since stream ids start at 2. + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room) + ) FROM device_lists_changes_in_room WHERE NOT converted_to_destinations + ), + '' +); diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f37505b6cf..8e7db2c4ec 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ class DeviceStoreTestCase(HomeserverTestCase): """ for device_id in device_ids: - stream_id = self.get_success( + self.get_success( self.store.add_device_change_to_streams( user_id, [device_id], ["!some:room"] ) @@ -39,7 +39,6 @@ class DeviceStoreTestCase(HomeserverTestCase): user_id=user_id, device_id=device_id, room_id="!some:room", - stream_id=stream_id, hosts=[host], context={}, ) -- cgit 1.5.1 From 6d47b7e32589e816eb766446cc1ff19ea73fc7c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 14:08:04 -0500 Subject: Add a type hint for `get_device_handler()` and fix incorrect types. (#14055) This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places. --- changelog.d/14055.misc | 1 + synapse/handlers/deactivate_account.py | 4 +++ synapse/handlers/device.py | 65 ++++++++++++++++++++++++++-------- synapse/handlers/e2e_keys.py | 61 ++++++++++++++++--------------- synapse/handlers/register.py | 4 +++ synapse/handlers/set_password.py | 6 +++- synapse/handlers/sso.py | 9 +++++ synapse/module_api/__init__.py | 10 +++++- synapse/replication/http/devices.py | 11 ++++-- synapse/rest/admin/__init__.py | 26 ++++++++------ synapse/rest/admin/devices.py | 13 +++++-- synapse/rest/client/devices.py | 17 ++++++--- synapse/rest/client/logout.py | 9 +++-- synapse/server.py | 2 +- tests/handlers/test_device.py | 19 ++++++---- tests/rest/admin/test_device.py | 5 ++- 16 files changed, 185 insertions(+), 77 deletions(-) create mode 100644 changelog.d/14055.misc (limited to 'tests') diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc new file mode 100644 index 0000000000..02980bc528 --- /dev/null +++ b/changelog.d/14055.misc @@ -0,0 +1 @@ +Add missing type hints to `HomeServer`. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 816e1a6d79..d74d135c0c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError +from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -76,6 +77,9 @@ class DeactivateAccountHandler: True if identity server supports removing threepids, otherwise False. """ + # This can only be called on the main process. + assert isinstance(self._device_handler, DeviceHandler) + # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index da3ddafeae..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: + device_list_updater: "DeviceListWorkerUpdater" + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs @@ -76,6 +78,8 @@ class DeviceWorkerHandler: self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self.device_list_updater = DeviceListWorkerUpdater(hs) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -99,6 +103,19 @@ class DeviceWorkerHandler: log_kv(device_map) return devices + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: """Retrieve the given device @@ -127,7 +144,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken - ) -> Collection[str]: + ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. """ @@ -320,6 +337,8 @@ class DeviceWorkerHandler: class DeviceHandler(DeviceWorkerHandler): + device_list_updater: "DeviceListUpdater" + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, [old_device_id]) return device_id - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[Tuple[str, JsonDict]]: - """Retrieve the information for a dehydrated device. - - Args: - user_id: the user whose dehydrated device we are looking for - Returns: - a tuple whose first item is the device ID, and the second item is - the dehydrated device information - """ - return await self.store.get_dehydrated_device(user_id) - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: @@ -882,7 +888,36 @@ def _update_device_from_client_ips( ) -class DeviceListUpdater: +class DeviceListWorkerUpdater: + "Handles incoming device list updates from federation and contacts the main process over replication" + + def __init__(self, hs: "HomeServer"): + from synapse.replication.http.devices import ( + ReplicationUserDevicesResyncRestServlet, + ) + + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[JsonDict]: + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale + if the attempt to resync failed. + Returns: + A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + return await self._user_device_resync_client(user_id=user_id) + + +class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index bf1221f523..5fe102e2f2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -27,9 +27,9 @@ from twisted.internet import defer from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( JsonDict, UserID, @@ -56,27 +56,23 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() - self._edu_updater = SigningKeyEduUpdater(hs, self) - federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker.worker_app is None - if not self._is_master: - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) - else: + is_master = hs.config.worker.worker_app is None + if is_master: + edu_updater = SigningKeyEduUpdater(hs) + # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( EduTypes.SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the @@ -319,14 +315,13 @@ class E2eKeysHandler: # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - if self._is_master: - resync_results = await self.device_handler.device_list_updater.user_device_resync( + resync_results = ( + await self.device_handler.device_list_updater.user_device_resync( user_id ) - else: - resync_results = await self._user_device_resync_client( - user_id=user_id - ) + ) + if resync_results is None: + raise ValueError("Device resync failed") # Add the device keys to the results. user_devices = resync_results["devices"] @@ -605,6 +600,8 @@ class E2eKeysHandler: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) time_now = self.clock.time_msec() @@ -732,6 +729,8 @@ class E2eKeysHandler: user_id: the user uploading the keys keys: the signing keys """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) # if a master key is uploaded, then check it. Otherwise, load the # stored master key, to check signatures on other keys @@ -823,6 +822,9 @@ class E2eKeysHandler: Raises: SynapseError: if the signatures dict is not valid. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -1200,6 +1202,9 @@ class E2eKeysHandler: A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() @@ -1396,11 +1401,14 @@ class SignatureListItem: class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() - self.e2e_keys_handler = e2e_keys_handler + + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler self._remote_edu_linearizer = Linearizer(name="remote_signing_key") @@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater: user_id: the user whose updates we are processing """ - device_handler = self.e2e_keys_handler.device_handler - device_list_updater = device_handler.device_list_updater - async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: @@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = ( - await device_list_updater.process_cross_signing_key_update( - user_id, - master_key, - self_signing_key, - ) + new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + new_device_ids - await device_handler.notify_device_update(user_id, device_ids) + await self._device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1c7a1866..6307fa9c5d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -841,6 +842,9 @@ class RegistrationHandler: refresh_token = None refresh_token_id = None + # This can only run on the main process. + assert isinstance(self.device_handler, DeviceHandler) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 73861bbd40..bd9d0bb34b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -29,7 +30,10 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + # This can only be instantiated on the main process. + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler async def set_password( self, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 749d7e93b0..e1c0bff1b2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -1035,6 +1036,8 @@ class SsoHandler: ) -> None: """Revoke any devices and in-flight logins tied to a provider session. + Can only be called from the main process. + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1042,6 +1045,12 @@ class SsoHandler: expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ + + # It is expected that this is the main process. + assert isinstance( + self._device_handler, DeviceHandler + ), "revoking SSO sessions can only be called on the main process" + # Invalidate any running user-mapping sessions to_delete = [] for session_id, session in self._username_mapping_sessions.items(): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1adc1fd64f..96a661177a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,6 +86,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( @@ -207,6 +208,7 @@ class ModuleApi: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -784,6 +786,8 @@ class ModuleApi: ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user + Can only be called from the main process. + Added in Synapse v0.25.0. Args: @@ -796,6 +800,10 @@ class ModuleApi: Raises: synapse.api.errors.AuthError: the access token is invalid """ + assert isinstance( + self._device_handler, DeviceHandler + ), "invalidate_access_token can only be called on the main process" + # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( self._auth.get_user_by_access_token(access_token) @@ -805,7 +813,7 @@ class ModuleApi: if device_id: # delete the device, which will also delete its access tokens yield defer.ensureDeferred( - self._hs.get_device_handler().delete_devices(user_id, [device_id]) + self._device_handler.delete_devices(user_id, [device_id]) ) else: # no associated device. Just delete the access token. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index c21629def8..7c4941c3d3 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.device_list_updater = hs.get_device_handler().device_list_updater + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c62ea22116..fb73886df0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ + # Admin servlets aren't registered on workers. + if hs.config.worker.worker_app is not None: + return + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeleteDevicesRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) @@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) - # Some servlets only get registered for the main process. - if hs.config.worker.worker_app is None: - SendServerNoticeServlet(hs).register(http_server) - BackgroundUpdateEnabledRestServlet(hs).register(http_server) - BackgroundUpdateRestServlet(hs).register(http_server) - BackgroundUpdateStartJobRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) + SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + # The following resources can only be run on the main process. + if hs.config.worker.worker_app is None: + DeactivateAccountRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d934880102..3b2f2d9abb 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8f3cbd4ea2..69b803f9f8 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): @@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler class PostBody(RequestBodyModel): device_id: StrictStr diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 23dfa4518f..6d34625ad5 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) diff --git a/synapse/server.py b/synapse/server.py index f0a60d0056..5baae2325e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta): ) @cache_in_self - def get_device_handler(self): + def get_device_handler(self) -> DeviceWorkerHandler: if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,7 @@ from typing import Optional from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.util import Clock @@ -32,7 +32,9 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.store = hs.get_datastores().main return hs @@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_is_preserved_if_exists(self) -> None: @@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res2, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_id_is_made_up_if_unspecified(self) -> None: @@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) + assert dev is not None self.assertEqual(dev["display_name"], "display") def test_get_devices_by_user(self) -> None: @@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ) ) - retrieved_device_id, device_data = self.get_success( - self.handler.get_dehydrated_device(user_id=user_id) - ) + result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + assert result is not None + retrieved_device_id, device_data = result self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index d52aee8f92..03f2112b07 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.handlers.device import DeviceHandler from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") -- cgit 1.5.1 From 4ae967cf6308e80b03da749f0cbaed36988e235e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 17:35:54 -0500 Subject: Add missing type hints to test.util.caches (#14529) --- changelog.d/14529.misc | 1 + mypy.ini | 11 +++--- tests/util/caches/test_cached_call.py | 23 ++++++------ tests/util/caches/test_deferred_cache.py | 61 ++++++++++++++++---------------- tests/util/caches/test_descriptors.py | 22 +++++++----- tests/util/caches/test_response_cache.py | 16 ++++----- tests/util/caches/test_ttlcache.py | 8 ++--- 7 files changed, 76 insertions(+), 66 deletions(-) create mode 100644 changelog.d/14529.misc (limited to 'tests') diff --git a/changelog.d/14529.misc b/changelog.d/14529.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14529.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 4cd61e0484..25b3c93748 100644 --- a/mypy.ini +++ b/mypy.ini @@ -59,11 +59,6 @@ exclude = (?x) |tests/server_notices/test_resource_limits_server_notices.py |tests/test_state.py |tests/test_terms_auth.py - |tests/util/caches/test_cached_call.py - |tests/util/caches/test_deferred_cache.py - |tests/util/caches/test_descriptors.py - |tests/util/caches/test_response_cache.py - |tests/util/caches/test_ttlcache.py |tests/util/test_async_helpers.py |tests/util/test_batching_queue.py |tests/util/test_dict_cache.py @@ -133,6 +128,12 @@ disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True +[mypy-tests.util.caches.*] +disallow_untyped_defs = True + +[mypy-tests.util.caches.test_descriptors] +disallow_untyped_defs = False + [mypy-tests.utils] disallow_untyped_defs = True diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py index 80b97167ba..9266f12590 100644 --- a/tests/util/caches/test_cached_call.py +++ b/tests/util/caches/test_cached_call.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import NoReturn from unittest.mock import Mock from twisted.internet import defer @@ -23,14 +24,14 @@ from tests.unittest import TestCase class CachedCallTestCase(TestCase): - def test_get(self): + def test_get(self) -> None: """ Happy-path test case: makes a couple of calls and makes sure they behave correctly """ - d = Deferred() + d: "Deferred[int]" = Deferred() - async def f(): + async def f() -> int: return await d slow_call = Mock(side_effect=f) @@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase): # now fire off a couple of calls completed_results = [] - async def r(): + async def r() -> None: res = await cached_call.get() completed_results.append(res) @@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase): self.assertEqual(r3, 123) slow_call.assert_not_called() - def test_fast_call(self): + def test_fast_call(self) -> None: """ Test the behaviour when the underlying function completes immediately """ - async def f(): + async def f() -> int: return 12 fast_call = Mock(side_effect=f) @@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase): class RetryOnExceptionCachedCallTestCase(TestCase): - def test_get(self): + def test_get(self) -> None: # set up the RetryOnExceptionCachedCall around a function which will fail # (after a while) - d = Deferred() + d: "Deferred[int]" = Deferred() - async def f1(): + async def f1() -> NoReturn: await d raise ValueError("moo") @@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase): # now fire off a couple of calls completed_results = [] - async def r(): + async def r() -> None: try: await cached_call.get() except Exception as e1: @@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase): # to the getter d = Deferred() - async def f2(): + async def f2() -> int: return await d slow_call.reset_mock() diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 02b99b466a..f74d82b1dc 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +from typing import List, Tuple from twisted.internet import defer @@ -22,20 +23,20 @@ from tests.unittest import TestCase class DeferredCacheTestCase(TestCase): - def test_empty(self): - cache = DeferredCache("test") + def test_empty(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") with self.assertRaises(KeyError): cache.get("foo") - def test_hit(self): - cache = DeferredCache("test") + def test_hit(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") cache.prefill("foo", 123) self.assertEqual(self.successResultOf(cache.get("foo")), 123) - def test_hit_deferred(self): - cache = DeferredCache("test") - origin_d = defer.Deferred() + def test_hit_deferred(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") + origin_d: "defer.Deferred[int]" = defer.Deferred() set_d = cache.set("k1", origin_d) # get should return an incomplete deferred @@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase): self.assertFalse(get_d.called) # add a callback that will make sure that the set_d gets called before the get_d - def check1(r): + def check1(r: str) -> str: self.assertTrue(set_d.called) return r @@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase): self.assertEqual(self.successResultOf(set_d), 99) self.assertEqual(self.successResultOf(get_d), 99) - def test_callbacks(self): + def test_callbacks(self) -> None: """Invalidation callbacks are called at the right time""" - cache = DeferredCache("test") + cache: DeferredCache[str, int] = DeferredCache("test") callbacks = set() # start with an entry, with a callback cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) # now replace that entry with a pending result - origin_d = defer.Deferred() + origin_d: "defer.Deferred[int]" = defer.Deferred() set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) # ... and also make a get request @@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase): cache.prefill("k1", 30) self.assertEqual(callbacks, {"set", "get"}) - def test_set_fail(self): - cache = DeferredCache("test") + def test_set_fail(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") callbacks = set() # start with an entry, with a callback cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) # now replace that entry with a pending result - origin_d = defer.Deferred() + origin_d: defer.Deferred = defer.Deferred() set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) # ... and also make a get request @@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase): cache.prefill("k1", 30) self.assertEqual(callbacks, {"prefill", "get2"}) - def test_get_immediate(self): - cache = DeferredCache("test") - d1 = defer.Deferred() + def test_get_immediate(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") + d1: "defer.Deferred[int]" = defer.Deferred() cache.set("key1", d1) # get_immediate should return default @@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase): v = cache.get_immediate("key1", 1) self.assertEqual(v, 2) - def test_invalidate(self): - cache = DeferredCache("test") + def test_invalidate(self) -> None: + cache: DeferredCache[Tuple[str], int] = DeferredCache("test") cache.prefill(("foo",), 123) cache.invalidate(("foo",)) with self.assertRaises(KeyError): cache.get(("foo",)) - def test_invalidate_all(self): - cache = DeferredCache("testcache") + def test_invalidate_all(self) -> None: + cache: DeferredCache[str, str] = DeferredCache("testcache") callback_record = [False, False] - def record_callback(idx): + def record_callback(idx: int) -> None: callback_record[idx] = True # add a couple of pending entries - d1 = defer.Deferred() + d1: "defer.Deferred[str]" = defer.Deferred() cache.set("key1", d1, partial(record_callback, 0)) - d2 = defer.Deferred() + d2: "defer.Deferred[str]" = defer.Deferred() cache.set("key2", d2, partial(record_callback, 1)) # lookup should return pending deferreds @@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase): with self.assertRaises(KeyError): cache.get("key1", None) - def test_eviction(self): - cache = DeferredCache( + def test_eviction(self) -> None: + cache: DeferredCache[int, str] = DeferredCache( "test", max_entries=2, apply_cache_factor_from_config=False ) @@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase): cache.get(2) cache.get(3) - def test_eviction_lru(self): - cache = DeferredCache( + def test_eviction_lru(self) -> None: + cache: DeferredCache[int, str] = DeferredCache( "test", max_entries=2, apply_cache_factor_from_config=False ) @@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase): cache.get(1) cache.get(3) - def test_eviction_iterable(self): - cache = DeferredCache( + def test_eviction_iterable(self) -> None: + cache: DeferredCache[int, List[str]] = DeferredCache( "test", max_entries=3, apply_cache_factor_from_config=False, diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 43475a307f..13f1edd533 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Iterable, Set, Tuple +from typing import Iterable, Set, Tuple, cast from unittest import mock from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.interfaces import IReactorTime from synapse.api.errors import SynapseError from synapse.logging.context import ( @@ -37,8 +38,8 @@ logger = logging.getLogger(__name__) def run_on_reactor(): - d = defer.Deferred() - reactor.callLater(0, d.callback, 0) + d: "Deferred[int]" = defer.Deferred() + cast(IReactorTime, reactor).callLater(0, d.callback, 0) return make_deferred_yieldable(d) @@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase): callbacks: Set[str] = set() # set off an asynchronous request - obj.result = origin_d = defer.Deferred() + origin_d: Deferred = defer.Deferred() + obj.result = origin_d d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) self.assertFalse(d1.called) @@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase): """Check that logcontexts are set and restored correctly when using the cache.""" - complete_lookup = defer.Deferred() + complete_lookup: Deferred = defer.Deferred() class Cls: @descriptors.cached() @@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): - assert current_context().name == "c1" + context = current_context() + assert isinstance(context, LoggingContext) + assert context.name == "c1" # we want this to behave like an asynchronous function await run_on_reactor() - assert current_context().name == "c1" + context = current_context() + assert isinstance(context, LoggingContext) + assert context.name == "c1" return self.mock(args1, arg2) with LoggingContext("c1") as c1: @@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): return self.mock(args1) obj = Cls() - deferred_result = Deferred() + deferred_result: "Deferred[dict]" = Deferred() obj.mock.return_value = deferred_result # start off several concurrent lookups of the same key diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py index 025b73e32f..f09eeecada 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py @@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase): (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock) """ - def setUp(self): + def setUp(self) -> None: self.reactor, self.clock = get_clock() def with_cache(self, name: str, ms: int = 0) -> ResponseCache: @@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase): await self.clock.sleep(1) return o - def test_cache_hit(self): + def test_cache_hit(self) -> None: cache = self.with_cache("keeping_cache", ms=9001) expected_result = "howdy" @@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase): "cache should still have the result", ) - def test_cache_miss(self): + def test_cache_miss(self) -> None: cache = self.with_cache("trashing_cache", ms=0) expected_result = "howdy" @@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase): ) self.assertCountEqual([], cache.keys(), "cache should not have the result now") - def test_cache_expire(self): + def test_cache_expire(self) -> None: cache = self.with_cache("short_cache", ms=1000) expected_result = "howdy" @@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase): self.reactor.pump((2,)) self.assertCountEqual([], cache.keys(), "cache should not have the result now") - def test_cache_wait_hit(self): + def test_cache_wait_hit(self) -> None: cache = self.with_cache("neutral_cache") expected_result = "howdy" @@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase): self.assertEqual(expected_result, self.successResultOf(wrap_d)) - def test_cache_wait_expire(self): + def test_cache_wait_expire(self) -> None: cache = self.with_cache("medium_cache", ms=3000) expected_result = "howdy" @@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase): self.assertCountEqual([], cache.keys(), "cache should not have the result now") @parameterized.expand([(True,), (False,)]) - def test_cache_context_nocache(self, should_cache: bool): + def test_cache_context_nocache(self, should_cache: bool) -> None: """If the callback clears the should_cache bit, the result should not be cached""" cache = self.with_cache("medium_cache", ms=3000) @@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase): call_count = 0 - async def non_caching(o: str, cache_context: ResponseCacheContext[int]): + async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str: nonlocal call_count call_count += 1 await self.clock.sleep(1) diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py index fe8314057d..679d1eb36b 100644 --- a/tests/util/caches/test_ttlcache.py +++ b/tests/util/caches/test_ttlcache.py @@ -20,11 +20,11 @@ from tests import unittest class CacheTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.mock_timer = Mock(side_effect=lambda: 100.0) - self.cache = TTLCache("test_cache", self.mock_timer) + self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer) - def test_get(self): + def test_get(self) -> None: """simple set/get tests""" self.cache.set("one", "1", 10) self.cache.set("two", "2", 20) @@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase): self.assertEqual(self.cache._metrics.hits, 4) self.assertEqual(self.cache._metrics.misses, 5) - def test_expiry(self): + def test_expiry(self) -> None: self.cache.set("one", "1", 10) self.cache.set("two", "2", 20) self.cache.set("three", "3", 30) -- cgit 1.5.1 From 9af2be192a759c22d189b72cc0a7580cd9de8a37 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 24 Nov 2022 09:09:17 +0000 Subject: Remove legacy Prometheus metrics names. They were deprecated in Synapse v1.69.0 and disabled by default in Synapse v1.71.0. (#14538) --- changelog.d/14538.removal | 1 + docs/upgrade.md | 22 ++ docs/usage/configuration/config_documentation.md | 25 -- synapse/app/_base.py | 16 +- synapse/app/generic_worker.py | 1 - synapse/app/homeserver.py | 1 - synapse/config/metrics.py | 2 - synapse/metrics/__init__.py | 7 +- synapse/metrics/_legacy_exposition.py | 288 ----------------------- synapse/metrics/_twisted_exposition.py | 38 +++ tests/storage/test_event_metrics.py | 7 +- 11 files changed, 70 insertions(+), 338 deletions(-) create mode 100644 changelog.d/14538.removal delete mode 100644 synapse/metrics/_legacy_exposition.py create mode 100644 synapse/metrics/_twisted_exposition.py (limited to 'tests') diff --git a/changelog.d/14538.removal b/changelog.d/14538.removal new file mode 100644 index 0000000000..d2035ce82a --- /dev/null +++ b/changelog.d/14538.removal @@ -0,0 +1 @@ +Remove legacy Prometheus metrics names. They were deprecated in Synapse v1.69.0 and disabled by default in Synapse v1.71.0. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 2aa353e496..4fe9e4f02e 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,28 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.73.0 + +## Legacy Prometheus metric names have now been removed + +Synapse v1.69.0 included the deprecation of legacy Prometheus metric names +and offered an option to disable them. +Synapse v1.71.0 disabled legacy Prometheus metric names by default. + +This version, v1.73.0, removes those legacy Prometheus metric names entirely. +This also means that the `enable_legacy_metrics` configuration option has been +removed; it will no longer be possible to re-enable the legacy metric names. + +If you use metrics and have not yet updated your Grafana dashboard(s), +Prometheus console(s) or alerting rule(s), please consider doing so when upgrading +to this version. +Note that the included Grafana dashboard was updated in v1.72.0 to correct some +metric names which were missed when legacy metrics were disabled by default. + +See [v1.69.0: Deprecation of legacy Prometheus metric names](#deprecation-of-legacy-prometheus-metric-names) +for more context. + + # Upgrading to v1.72.0 ## Dropping support for PostgreSQL 10 diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index f5937dd902..fae2771fad 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2437,31 +2437,6 @@ Example configuration: enable_metrics: true ``` --- -### `enable_legacy_metrics` - -Set to `true` to publish both legacy and non-legacy Prometheus metric names, -or to `false` to only publish non-legacy Prometheus metric names. -Defaults to `false`. Has no effect if `enable_metrics` is `false`. -**In Synapse v1.67.0 up to and including Synapse v1.70.1, this defaulted to `true`.** - -Legacy metric names include: -- metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules; -- counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard. - -These legacy metric names are unconventional and not compliant with OpenMetrics standards. -They are included for backwards compatibility. - -Example configuration: -```yaml -enable_legacy_metrics: false -``` - -See https://github.com/matrix-org/synapse/issues/11106 for context. - -*Since v1.67.0.* - -**Will be removed in v1.73.0.** ---- ### `sentry` Use this option to enable sentry integration. Provide the DSN assigned to you by sentry diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 41d2732ef9..a5aa2185a2 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -266,26 +266,18 @@ def register_start( reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics( - bind_addresses: Iterable[str], port: int, enable_legacy_metric_names: bool -) -> None: +def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: """ Start Prometheus metrics server. """ from prometheus_client import start_http_server as start_http_server_prometheus - from synapse.metrics import ( - RegistryProxy, - start_http_server as start_http_server_legacy, - ) + from synapse.metrics import RegistryProxy for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) - if enable_legacy_metric_names: - start_http_server_legacy(port, addr=host, registry=RegistryProxy) - else: - _set_prometheus_client_use_created_metrics(False) - start_http_server_prometheus(port, addr=host, registry=RegistryProxy) + _set_prometheus_client_use_created_metrics(False) + start_http_server_prometheus(port, addr=host, registry=RegistryProxy) def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 74909b7d4a..46dc731696 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -320,7 +320,6 @@ class GenericWorkerServer(HomeServer): _base.listen_metrics( listener.bind_addresses, listener.port, - enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, ) else: logger.warning("Unsupported listener type: %s", listener.type) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 4f4fee4782..b9be558c7e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -265,7 +265,6 @@ class SynapseHomeServer(HomeServer): _base.listen_metrics( listener.bind_addresses, listener.port, - enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, ) else: # this shouldn't happen, as the listener type should have been checked diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 6034a0346e..8c1c9bd12d 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -43,8 +43,6 @@ class MetricsConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_metrics = config.get("enable_metrics", False) - self.enable_legacy_metrics = config.get("enable_legacy_metrics", False) - self.report_stats = config.get("report_stats", None) self.report_stats_endpoint = config.get( "report_stats_endpoint", "https://matrix.org/report-usage-stats/push" diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index c3d3daf877..b01372565d 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -47,11 +47,7 @@ from twisted.python.threadpool import ThreadPool # This module is imported for its side effects; flake8 needn't warn that it's unused. import synapse.metrics._reactor_metrics # noqa: F401 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager -from synapse.metrics._legacy_exposition import ( - MetricsResource, - generate_latest, - start_http_server, -) +from synapse.metrics._twisted_exposition import MetricsResource, generate_latest from synapse.metrics._types import Collector from synapse.util import SYNAPSE_VERSION @@ -474,7 +470,6 @@ __all__ = [ "Collector", "MetricsResource", "generate_latest", - "start_http_server", "LaterGauge", "InFlightGauge", "GaugeBucketCollector", diff --git a/synapse/metrics/_legacy_exposition.py b/synapse/metrics/_legacy_exposition.py deleted file mode 100644 index 1459f9d224..0000000000 --- a/synapse/metrics/_legacy_exposition.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2015-2019 Prometheus Python Client Developers -# Copyright 2019 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This code is based off `prometheus_client/exposition.py` from version 0.7.1. - -Due to the renaming of metrics in prometheus_client 0.4.0, this customised -vendoring of the code will emit both the old versions that Synapse dashboards -expect, and the newer "best practice" version of the up-to-date official client. -""" -import logging -import math -import threading -from http.server import BaseHTTPRequestHandler, HTTPServer -from socketserver import ThreadingMixIn -from typing import Any, Dict, List, Type, Union -from urllib.parse import parse_qs, urlparse - -from prometheus_client import REGISTRY, CollectorRegistry -from prometheus_client.core import Sample - -from twisted.web.resource import Resource -from twisted.web.server import Request - -logger = logging.getLogger(__name__) -CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" - - -def floatToGoString(d: Union[int, float]) -> str: - d = float(d) - if d == math.inf: - return "+Inf" - elif d == -math.inf: - return "-Inf" - elif math.isnan(d): - return "NaN" - else: - s = repr(d) - dot = s.find(".") - # Go switches to exponents sooner than Python. - # We only need to care about positive values for le/quantile. - if d > 0 and dot > 6: - mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.") - return f"{mantissa}e+0{dot - 1}" - return s - - -def sample_line(line: Sample, name: str) -> str: - if line.labels: - labelstr = "{{{0}}}".format( - ",".join( - [ - '{}="{}"'.format( - k, - v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), - ) - for k, v in sorted(line.labels.items()) - ] - ) - ) - else: - labelstr = "" - timestamp = "" - if line.timestamp is not None: - # Convert to milliseconds. - timestamp = f" {int(float(line.timestamp) * 1000):d}" - return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) - - -# Mapping from new metric names to legacy metric names. -# We translate these back to their old names when exposing them through our -# legacy vendored exporter. -# Only this legacy exposition module applies these name changes. -LEGACY_METRIC_NAMES = { - "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits", - "synapse_util_caches_cache_size": "synapse_util_caches_cache:size", - "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size", - "synapse_util_caches_cache": "synapse_util_caches_cache:total", - "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size", - "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits", - "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size", - "synapse_util_caches_response_cache": "synapse_util_caches_response_cache:total", - "synapse_federation_client_sent_pdu_destinations": "synapse_federation_client_sent_pdu_destinations:total", - "synapse_federation_client_sent_pdu_destinations_count": "synapse_federation_client_sent_pdu_destinations:count", - "synapse_admin_mau_current": "synapse_admin_mau:current", - "synapse_admin_mau_max": "synapse_admin_mau:max", - "synapse_admin_mau_registered_reserved_users": "synapse_admin_mau:registered_reserved_users", -} - - -def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: - """ - Generate metrics in legacy format. Modern metrics are generated directly - by prometheus-client. - """ - - output = [] - - for metric in registry.collect(): - if not metric.samples: - # No samples, don't bother. - continue - - # Translate to legacy metric name if it has one. - mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name) - mnewname = metric.name - mtype = metric.type - - # OpenMetrics -> Prometheus - if mtype == "counter": - mnewname = mnewname + "_total" - elif mtype == "info": - mtype = "gauge" - mnewname = mnewname + "_info" - elif mtype == "stateset": - mtype = "gauge" - elif mtype == "gaugehistogram": - mtype = "histogram" - elif mtype == "unknown": - mtype = "untyped" - - # Output in the old format for compatibility. - if emit_help: - output.append( - "# HELP {} {}\n".format( - mname, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mname} {mtype}\n") - - om_samples: Dict[str, List[str]] = {} - for s in metric.samples: - for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == mname + suffix: - # OpenMetrics specific sample, put in a gauge at the end. - # (these come from gaugehistograms which don't get renamed, - # so no need to faff with mnewname) - om_samples.setdefault(suffix, []).append(sample_line(s, s.name)) - break - else: - newname = s.name.replace(mnewname, mname) - if ":" in newname and newname.endswith("_total"): - newname = newname[: -len("_total")] - output.append(sample_line(s, newname)) - - for suffix, lines in sorted(om_samples.items()): - if emit_help: - output.append( - "# HELP {}{} {}\n".format( - mname, - suffix, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mname}{suffix} gauge\n") - output.extend(lines) - - # Get rid of the weird colon things while we're at it - if mtype == "counter": - mnewname = mnewname.replace(":total", "") - mnewname = mnewname.replace(":", "_") - - if mname == mnewname: - continue - - # Also output in the new format, if it's different. - if emit_help: - output.append( - "# HELP {} {}\n".format( - mnewname, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mnewname} {mtype}\n") - - for s in metric.samples: - # Get rid of the OpenMetrics specific samples (we should already have - # dealt with them above anyway.) - for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == mname + suffix: - break - else: - sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name) - output.append( - sample_line(s, sample_name.replace(":total", "").replace(":", "_")) - ) - - return "".join(output).encode("utf-8") - - -class MetricsHandler(BaseHTTPRequestHandler): - """HTTP handler that gives metrics from ``REGISTRY``.""" - - registry = REGISTRY - - def do_GET(self) -> None: - registry = self.registry - params = parse_qs(urlparse(self.path).query) - - if "help" in params: - emit_help = True - else: - emit_help = False - - try: - output = generate_latest(registry, emit_help=emit_help) - except Exception: - self.send_error(500, "error generating metric output") - raise - try: - self.send_response(200) - self.send_header("Content-Type", CONTENT_TYPE_LATEST) - self.send_header("Content-Length", str(len(output))) - self.end_headers() - self.wfile.write(output) - except BrokenPipeError as e: - logger.warning( - "BrokenPipeError when serving metrics (%s). Did Prometheus restart?", e - ) - - def log_message(self, format: str, *args: Any) -> None: - """Log nothing.""" - - @classmethod - def factory(cls, registry: CollectorRegistry) -> Type: - """Returns a dynamic MetricsHandler class tied - to the passed registry. - """ - # This implementation relies on MetricsHandler.registry - # (defined above and defaulted to REGISTRY). - - # As we have unicode_literals, we need to create a str() - # object for type(). - cls_name = str(cls.__name__) - MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry}) - return MyMetricsHandler - - -class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): - """Thread per request HTTP server.""" - - # Make worker threads "fire and forget". Beginning with Python 3.7 this - # prevents a memory leak because ``ThreadingMixIn`` starts to gather all - # non-daemon threads in a list in order to join on them at server close. - # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the - # same as Python 3.7's ``ThreadingHTTPServer``. - daemon_threads = True - - -def start_http_server( - port: int, addr: str = "", registry: CollectorRegistry = REGISTRY -) -> None: - """Starts an HTTP server for prometheus metrics as a daemon thread""" - CustomMetricsHandler = MetricsHandler.factory(registry) - httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) - t = threading.Thread(target=httpd.serve_forever) - t.daemon = True - t.start() - - -class MetricsResource(Resource): - """ - Twisted ``Resource`` that serves prometheus metrics. - """ - - isLeaf = True - - def __init__(self, registry: CollectorRegistry = REGISTRY): - self.registry = registry - - def render_GET(self, request: Request) -> bytes: - request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) - response = generate_latest(self.registry) - request.setHeader(b"Content-Length", str(len(response))) - return response diff --git a/synapse/metrics/_twisted_exposition.py b/synapse/metrics/_twisted_exposition.py new file mode 100644 index 0000000000..0abcd14953 --- /dev/null +++ b/synapse/metrics/_twisted_exposition.py @@ -0,0 +1,38 @@ +# Copyright 2015-2019 Prometheus Python Client Developers +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from prometheus_client import REGISTRY, CollectorRegistry, generate_latest + +from twisted.web.resource import Resource +from twisted.web.server import Request + +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" + + +class MetricsResource(Resource): + """ + Twisted ``Resource`` that serves prometheus metrics. + """ + + isLeaf = True + + def __init__(self, registry: CollectorRegistry = REGISTRY): + self.registry = registry + + def render_GET(self, request: Request) -> bytes: + request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) + response = generate_latest(self.registry) + request.setHeader(b"Content-Length", str(len(response))) + return response diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 088fbb247b..6f1135eef4 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from prometheus_client import generate_latest -from synapse.metrics import REGISTRY, generate_latest +from synapse.metrics import REGISTRY from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -53,8 +54,8 @@ class ExtremStatisticsTestCase(HomeserverTestCase): items = list( filter( - lambda x: b"synapse_forward_extremities_" in x, - generate_latest(REGISTRY, emit_help=False).split(b"\n"), + lambda x: b"synapse_forward_extremities_" in x and b"# HELP" not in x, + generate_latest(REGISTRY).split(b"\n"), ) ) -- cgit 1.5.1 From f6c74d1cb2ed966802b01a2b037f09ce7a842c18 Mon Sep 17 00:00:00 2001 From: Benjamin Kampmann Date: Thu, 24 Nov 2022 09:10:51 +0000 Subject: Implement message forward pagination from start when no from is given, fixes #12383 (#14149) Fixes https://github.com/matrix-org/synapse/issues/12383 --- changelog.d/14149.bugfix | 1 + synapse/handlers/pagination.py | 6 ++++++ synapse/streams/events.py | 13 +++++++++++++ tests/rest/admin/test_room.py | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+) create mode 100644 changelog.d/14149.bugfix (limited to 'tests') diff --git a/changelog.d/14149.bugfix b/changelog.d/14149.bugfix new file mode 100644 index 0000000000..b31c658266 --- /dev/null +++ b/changelog.d/14149.bugfix @@ -0,0 +1 @@ +Fix #12383: paginate room messages from the start if no from is given. Contributed by @gnunicorn . \ No newline at end of file diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a4ca9cb8b4..c572508a02 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -448,6 +448,12 @@ class PaginationHandler: if pagin_config.from_token: from_token = pagin_config.from_token + elif pagin_config.direction == "f": + from_token = ( + await self.hs.get_event_sources().get_start_token_for_pagination( + room_id + ) + ) else: from_token = ( await self.hs.get_event_sources().get_current_token_for_pagination( diff --git a/synapse/streams/events.py b/synapse/streams/events.py index f331e1af16..619eb7f601 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -73,6 +73,19 @@ class EventSources: ) return token + @trace + async def get_start_token_for_pagination(self, room_id: str) -> StreamToken: + """Get the start token for a given room to be used to paginate + events. + + The returned token does not have the current values for fields other + than `room`, since they are not used during pagination. + + Returns: + The start token for pagination. + """ + return StreamToken.START + @trace async def get_current_token_for_pagination(self, room_id: str) -> StreamToken: """Get the current token for a given room to be used to paginate diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d156be82b0..e0f5d54aba 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1857,6 +1857,46 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertIn("end", channel.json_body) + def test_room_messages_backward(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + latest_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?dir=b" % (self.room_id,), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 6, [event["content"] for event in chunk]) + + # in backwards, this is the first event + self.assertEqual(chunk[0]["event_id"], latest_event_id) + + def test_room_messages_forward(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + latest_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?dir=f" % (self.room_id,), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 6, [event["content"] for event in chunk]) + + # in forward, this is the last event + self.assertEqual(chunk[5]["event_id"], latest_event_id) + def test_room_messages_purge(self) -> None: """Test room messages can be retrieved by an admin that isn't in the room.""" store = self.hs.get_datastores().main -- cgit 1.5.1 From 09de2aecb05cb46e0513396e2675b24c8beedb68 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Fri, 25 Nov 2022 19:16:50 +0400 Subject: Add support for handling avatar with SSO login (#13917) This commit adds support for handling a provided avatar picture URL when logging in via SSO. Signed-off-by: Ashish Kumar Fixes #9357. --- changelog.d/13917.feature | 1 + docs/usage/configuration/config_documentation.md | 9 +- mypy.ini | 4 +- synapse/handlers/oidc.py | 7 ++ synapse/handlers/sso.py | 111 +++++++++++++++++ tests/handlers/test_sso.py | 145 +++++++++++++++++++++++ 6 files changed, 275 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13917.feature create mode 100644 tests/handlers/test_sso.py (limited to 'tests') diff --git a/changelog.d/13917.feature b/changelog.d/13917.feature new file mode 100644 index 0000000000..4eb942ab38 --- /dev/null +++ b/changelog.d/13917.feature @@ -0,0 +1 @@ +Adds support for handling avatar in SSO login. Contributed by @ashfame. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index fae2771fad..749af12aac 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2968,10 +2968,17 @@ Options for each entry include: For the default provider, the following settings are available: - * subject_claim: name of the claim containing a unique identifier + * `subject_claim`: name of the claim containing a unique identifier for the user. Defaults to 'sub', which OpenID Connect compliant providers should provide. + * `picture_claim`: name of the claim containing an url for the user's profile picture. + Defaults to 'picture', which OpenID Connect compliant providers should provide + and has to refer to a direct image file such as PNG, JPEG, or GIF image file. + + Currently only supported in monolithic (single-process) server configurations + where the media repository runs within the Synapse process. + * `localpart_template`: Jinja2 template for the localpart of the MXID. If this is not set, the user will be prompted to choose their own username (see the documentation for the `sso_auth_account_details.html` diff --git a/mypy.ini b/mypy.ini index 25b3c93748..0b6e7df267 100644 --- a/mypy.ini +++ b/mypy.ini @@ -119,6 +119,9 @@ disallow_untyped_defs = True [mypy-tests.storage.test_profile] disallow_untyped_defs = True +[mypy-tests.handlers.test_sso] +disallow_untyped_defs = True + [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True @@ -137,7 +140,6 @@ disallow_untyped_defs = False [mypy-tests.utils] disallow_untyped_defs = True - ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. ;; The `typeshed` project maintains stubs here: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 41c675f408..03de6a4ba6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1435,6 +1435,7 @@ class UserAttributeDict(TypedDict): localpart: Optional[str] confirm_localpart: bool display_name: Optional[str] + picture: Optional[str] # may be omitted by older `OidcMappingProviders` emails: List[str] @@ -1520,6 +1521,7 @@ env.filters.update( @attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: subject_claim: str + picture_claim: str localpart_template: Optional[Template] display_name_template: Optional[Template] email_template: Optional[Template] @@ -1539,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @staticmethod def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") + picture_claim = config.get("picture_claim", "picture") def parse_template_config(option_name: str) -> Optional[Template]: if option_name not in config: @@ -1572,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): return JinjaOidcMappingConfig( subject_claim=subject_claim, + picture_claim=picture_claim, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, @@ -1611,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) + picture = userinfo.get("picture") + return UserAttributeDict( localpart=localpart, display_name=display_name, emails=emails, + picture=picture, confirm_localpart=self._config.confirm_localpart, ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e1c0bff1b2..44e70fc4b8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import hashlib +import io import logging from typing import ( TYPE_CHECKING, @@ -138,6 +140,7 @@ class UserAttributes: localpart: Optional[str] confirm_localpart: bool = False display_name: Optional[str] = None + picture: Optional[str] = None emails: Collection[str] = attr.Factory(list) @@ -196,6 +199,10 @@ class SsoHandler: self._error_template = hs.config.sso.sso_error_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() + self._media_repo = ( + hs.get_media_repository() if hs.config.media.can_load_media_repo else None + ) + self._http_client = hs.get_proxied_blacklisted_http_client() # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. @@ -495,6 +502,8 @@ class SsoHandler: await self._profile_handler.set_displayname( user_id_obj, requester, attributes.display_name, True ) + if attributes.picture: + await self.set_avatar(user_id, attributes.picture) await self._auth_handler.complete_sso_login( user_id, @@ -703,8 +712,110 @@ class SsoHandler: await self._store.record_user_external_id( auth_provider_id, remote_user_id, registered_user_id ) + + # Set avatar, if available + if attributes.picture: + await self.set_avatar(registered_user_id, attributes.picture) + return registered_user_id + async def set_avatar(self, user_id: str, picture_https_url: str) -> bool: + """Set avatar of the user. + + This downloads the image file from the URL provided, stores that in + the media repository and then sets the avatar on the user's profile. + + It can detect if the same image is being saved again and bails early by storing + the hash of the file in the `upload_name` of the avatar image. + + Currently, it only supports server configurations which run the media repository + within the same process. + + It silently fails and logs a warning by raising an exception and catching it + internally if: + * it is unable to fetch the image itself (non 200 status code) or + * the image supplied is bigger than max allowed size or + * the image type is not one of the allowed image types. + + Args: + user_id: matrix user ID in the form @localpart:domain as a string. + + picture_https_url: HTTPS url for the picture image file. + + Returns: `True` if the user's avatar has been successfully set to the image at + `picture_https_url`. + """ + if self._media_repo is None: + logger.info( + "failed to set user avatar because out-of-process media repositories " + "are not supported yet " + ) + return False + + try: + uid = UserID.from_string(user_id) + + def is_allowed_mime_type(content_type: str) -> bool: + if ( + self._profile_handler.allowed_avatar_mimetypes + and content_type + not in self._profile_handler.allowed_avatar_mimetypes + ): + return False + return True + + # download picture, enforcing size limit & mime type check + picture = io.BytesIO() + + content_length, headers, uri, code = await self._http_client.get_file( + url=picture_https_url, + output_stream=picture, + max_size=self._profile_handler.max_avatar_size, + is_allowed_content_type=is_allowed_mime_type, + ) + + if code != 200: + raise Exception( + "GET request to download sso avatar image returned {}".format(code) + ) + + # upload name includes hash of the image file's content so that we can + # easily check if it requires an update or not, the next time user logs in + upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest() + + # bail if user already has the same avatar + profile = await self._profile_handler.get_profile(user_id) + if profile["avatar_url"] is not None: + server_name = profile["avatar_url"].split("/")[-2] + media_id = profile["avatar_url"].split("/")[-1] + if server_name == self._server_name: + media = await self._media_repo.store.get_local_media(media_id) + if media is not None and upload_name == media["upload_name"]: + logger.info("skipping saving the user avatar") + return True + + # store it in media repository + avatar_mxc_url = await self._media_repo.create_content( + media_type=headers[b"Content-Type"][0].decode("utf-8"), + upload_name=upload_name, + content=picture, + content_length=content_length, + auth_user=uid, + ) + + # save it as user avatar + await self._profile_handler.set_avatar_url( + uid, + create_requester(uid), + str(avatar_mxc_url), + ) + + logger.info("successfully saved the user avatar") + return True + except Exception: + logger.warning("failed to save the user avatar") + return False + async def complete_sso_ui_auth_request( self, auth_provider_id: str, diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py new file mode 100644 index 0000000000..137deab138 --- /dev/null +++ b/tests/handlers/test_sso.py @@ -0,0 +1,145 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from http import HTTPStatus +from typing import BinaryIO, Callable, Dict, List, Optional, Tuple +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.http_headers import Headers + +from synapse.api.errors import Codes, SynapseError +from synapse.http.client import RawHeaders +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.test_utils import SMALL_PNG, FakeResponse + + +class TestSSOHandler(unittest.HomeserverTestCase): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock(spec=["get_file"]) + self.http_client.get_file.side_effect = mock_get_file + self.http_client.user_agent = b"Synapse Test" + hs = self.setup_test_homeserver( + proxied_blacklisted_http_client=self.http_client + ) + return hs + + async def test_set_avatar(self) -> None: + """Tests successfully setting the avatar of a newly created user""" + handler = self.hs.get_sso_handler() + + # Create a new user to set avatar for + reg_handler = self.hs.get_registration_handler() + user_id = self.get_success(reg_handler.register_user(approved=True)) + + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # Ensure avatar is set on this newly created user, + # so no need to compare for the exact image + profile_handler = self.hs.get_profile_handler() + profile = self.get_success(profile_handler.get_profile(user_id)) + self.assertIsNot(profile["avatar_url"], None) + + @unittest.override_config({"max_avatar_size": 1}) + async def test_set_avatar_too_big_image(self) -> None: + """Tests that saving an avatar fails when it is too big""" + handler = self.hs.get_sso_handler() + + # any random user works since image check is supposed to fail + user_id = "@sso-user:test" + + self.assertFalse( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + @unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]}) + async def test_set_avatar_incorrect_mime_type(self) -> None: + """Tests that saving an avatar fails when its mime type is not allowed""" + handler = self.hs.get_sso_handler() + + # any random user works since image check is supposed to fail + user_id = "@sso-user:test" + + self.assertFalse( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + async def test_skip_saving_avatar_when_not_changed(self) -> None: + """Tests whether saving of avatar correctly skips if the avatar hasn't + changed""" + handler = self.hs.get_sso_handler() + + # Create a new user to set avatar for + reg_handler = self.hs.get_registration_handler() + user_id = self.get_success(reg_handler.register_user(approved=True)) + + # set avatar for the first time, should be a success + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # get avatar picture for comparison after another attempt + profile_handler = self.hs.get_profile_handler() + profile = self.get_success(profile_handler.get_profile(user_id)) + url_to_match = profile["avatar_url"] + + # set same avatar for the second time, should be a success + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # compare avatar picture's url from previous step + profile = self.get_success(profile_handler.get_profile(user_id)) + self.assertEqual(profile["avatar_url"], url_to_match) + + +async def mock_get_file( + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + is_allowed_content_type: Optional[Callable[[str], bool]] = None, +) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: + + fake_response = FakeResponse(code=404) + if url == "http://my.server/me.png": + fake_response = FakeResponse( + code=200, + headers=Headers( + {"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]} + ), + body=SMALL_PNG, + ) + + if max_size is not None and max_size < len(SMALL_PNG): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + "Requested file is too large > %r bytes" % (max_size,), + Codes.TOO_LARGE, + ) + + if is_allowed_content_type and not is_allowed_content_type("image/png"): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + ( + "Requested file's content type not allowed for this operation: %s" + % "image/png" + ), + ) + + output_stream.write(fake_response.body) + + return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200 -- cgit 1.5.1 From d748bbc8f8268d2e8457374d529adafb20b9f5f4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Nov 2022 09:40:17 -0500 Subject: Include thread information when sending receipts over federation. (#14466) Include the thread_id field when sending read receipts over federation. This might result in the same user having multiple read receipts per-room, meaning multiple EDUs must be sent to encapsulate those receipts. This restructures the PerDestinationQueue APIs to support multiple receipt EDUs, queue_read_receipt now becomes linear time in the number of queued threaded receipts in the room for the given user, it is expected this is a small number since receipt EDUs are sent as filler in transactions. --- changelog.d/14466.bugfix | 1 + synapse/federation/sender/per_destination_queue.py | 183 ++++++++++++++------- synapse/handlers/receipts.py | 1 - tests/federation/test_federation_sender.py | 77 +++++++++ 4 files changed, 198 insertions(+), 64 deletions(-) create mode 100644 changelog.d/14466.bugfix (limited to 'tests') diff --git a/changelog.d/14466.bugfix b/changelog.d/14466.bugfix new file mode 100644 index 0000000000..82f6e6b68e --- /dev/null +++ b/changelog.d/14466.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 where a receipt's thread ID was not sent over federation. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 3ae5e8634c..5af2784f1e 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -35,7 +35,7 @@ from synapse.logging import issue9533_logger from synapse.logging.opentracing import SynapseTags, set_tag from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import ReadReceipt +from synapse.types import JsonDict, ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.visibility import filter_events_for_server @@ -136,8 +136,11 @@ class PerDestinationQueue: # destination self._pending_presence: Dict[str, UserPresenceState] = {} - # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {} + # List of room_id -> receipt_type -> user_id -> receipt_dict, + # + # Each receipt can only have a single receipt per + # (room ID, receipt type, user ID, thread ID) tuple. + self._pending_receipt_edus: List[Dict[str, Dict[str, Dict[str, dict]]]] = [] self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. @@ -202,17 +205,53 @@ class PerDestinationQueue: Args: receipt: receipt to be queued """ - self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( - receipt.receipt_type, {} - )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} + serialized_receipt: JsonDict = { + "event_ids": receipt.event_ids, + "data": receipt.data, + } + if receipt.thread_id is not None: + serialized_receipt["data"]["thread_id"] = receipt.thread_id + + # Find which EDU to add this receipt to. There's three situations depending + # on the (room ID, receipt type, user, thread ID) tuple: + # + # 1. If it fully matches, clobber the information. + # 2. If it is missing, add the information. + # 3. If the subset tuple of (room ID, receipt type, user) matches, check + # the next EDU (or add a new EDU). + for edu in self._pending_receipt_edus: + receipt_content = edu.setdefault(receipt.room_id, {}).setdefault( + receipt.receipt_type, {} + ) + # If this room ID, receipt type, user ID is not in this EDU, OR if + # the full tuple matches, use the current EDU. + if ( + receipt.user_id not in receipt_content + or receipt_content[receipt.user_id].get("thread_id") + == receipt.thread_id + ): + receipt_content[receipt.user_id] = serialized_receipt + break + + # If no matching EDU was found, create a new one. + else: + self._pending_receipt_edus.append( + { + receipt.room_id: { + receipt.receipt_type: {receipt.user_id: serialized_receipt} + } + } + ) def flush_read_receipts_for_room(self, room_id: str) -> None: - # if we don't have any read-receipts for this room, it may be that we've already - # sent them out, so we don't need to flush. - if room_id not in self._pending_rrs: - return - self._rrs_pending_flush = True - self.attempt_new_transaction() + # If there are any pending receipts for this room then force-flush them + # in a new transaction. + for edu in self._pending_receipt_edus: + if room_id in edu: + self._rrs_pending_flush = True + self.attempt_new_transaction() + # No use in checking remaining EDUs if the room was found. + break def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu @@ -351,7 +390,7 @@ class PerDestinationQueue: self._pending_edus = [] self._pending_edus_keyed = {} self._pending_presence = {} - self._pending_rrs = {} + self._pending_receipt_edus = [] self._start_catching_up() except FederationDeniedError as e: @@ -543,22 +582,27 @@ class PerDestinationQueue: self._destination, last_successful_stream_ordering ) - def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: - if not self._pending_rrs: + def _get_receipt_edus(self, force_flush: bool, limit: int) -> Iterable[Edu]: + if not self._pending_receipt_edus: return if not force_flush and not self._rrs_pending_flush: # not yet time for this lot return - edu = Edu( - origin=self._server_name, - destination=self._destination, - edu_type=EduTypes.RECEIPT, - content=self._pending_rrs, - ) - self._pending_rrs = {} - self._rrs_pending_flush = False - yield edu + # Send at most limit EDUs for receipts. + for content in self._pending_receipt_edus[:limit]: + yield Edu( + origin=self._server_name, + destination=self._destination, + edu_type=EduTypes.RECEIPT, + content=content, + ) + self._pending_receipt_edus = self._pending_receipt_edus[limit:] + + # If there are still pending read-receipts, don't reset the pending flush + # flag. + if not self._pending_receipt_edus: + self._rrs_pending_flush = False def _pop_pending_edus(self, limit: int) -> List[Edu]: pending_edus = self._pending_edus @@ -645,27 +689,61 @@ class _TransactionQueueManager: async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: # First we calculate the EDUs we want to send, if any. - # We start by fetching device related EDUs, i.e device updates and to - # device messages. We have to keep 2 free slots for presence and rr_edus. - device_edu_limit = MAX_EDUS_PER_TRANSACTION - 2 + # There's a maximum number of EDUs that can be sent with a transaction, + # generally device updates and to-device messages get priority, but we + # want to ensure that there's room for some other EDUs as well. + # + # This is done by: + # + # * Add a presence EDU, if one exists. + # * Add up-to a small limit of read receipt EDUs. + # * Add to-device EDUs, but leave some space for device list updates. + # * Add device list updates EDUs. + # * If there's any remaining room, add other EDUs. + pending_edus = [] + + # Add presence EDU. + if self.queue._pending_presence: + pending_edus.append( + Edu( + origin=self.queue._server_name, + destination=self.queue._destination, + edu_type=EduTypes.PRESENCE, + content={ + "push": [ + format_user_presence_state( + presence, self.queue._clock.time_msec() + ) + for presence in self.queue._pending_presence.values() + ] + }, + ) + ) + self.queue._pending_presence = {} - # We prioritize to-device messages so that existing encryption channels + # Add read receipt EDUs. + pending_edus.extend(self.queue._get_receipt_edus(force_flush=False, limit=5)) + edu_limit = MAX_EDUS_PER_TRANSACTION - len(pending_edus) + + # Next, prioritize to-device messages so that existing encryption channels # work. We also keep a few slots spare (by reducing the limit) so that # we can still trickle out some device list updates. ( to_device_edus, device_stream_id, - ) = await self.queue._get_to_device_message_edus(device_edu_limit - 10) + ) = await self.queue._get_to_device_message_edus(edu_limit - 10) if to_device_edus: self._device_stream_id = device_stream_id else: self.queue._last_device_stream_id = device_stream_id - device_edu_limit -= len(to_device_edus) + pending_edus.extend(to_device_edus) + edu_limit -= len(to_device_edus) + # Add device list update EDUs. device_update_edus, dev_list_id = await self.queue._get_device_update_edus( - device_edu_limit + edu_limit ) if device_update_edus: @@ -673,40 +751,17 @@ class _TransactionQueueManager: else: self.queue._last_device_list_stream_id = dev_list_id - pending_edus = device_update_edus + to_device_edus - - # Now add the read receipt EDU. - pending_edus.extend(self.queue._get_rr_edus(force_flush=False)) - - # And presence EDU. - if self.queue._pending_presence: - pending_edus.append( - Edu( - origin=self.queue._server_name, - destination=self.queue._destination, - edu_type=EduTypes.PRESENCE, - content={ - "push": [ - format_user_presence_state( - presence, self.queue._clock.time_msec() - ) - for presence in self.queue._pending_presence.values() - ] - }, - ) - ) - self.queue._pending_presence = {} + pending_edus.extend(device_update_edus) + edu_limit -= len(device_update_edus) # Finally add any other types of EDUs if there is room. - pending_edus.extend( - self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) - ) - while ( - len(pending_edus) < MAX_EDUS_PER_TRANSACTION - and self.queue._pending_edus_keyed - ): + other_edus = self.queue._pop_pending_edus(edu_limit) + pending_edus.extend(other_edus) + edu_limit -= len(other_edus) + while edu_limit > 0 and self.queue._pending_edus_keyed: _, val = self.queue._pending_edus_keyed.popitem() pending_edus.append(val) + edu_limit -= 1 # Now we look for any PDUs to send, by getting up to 50 PDUs from the # queue @@ -717,8 +772,10 @@ class _TransactionQueueManager: # if we've decided to send a transaction anyway, and we have room, we # may as well send any pending RRs - if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: - pending_edus.extend(self.queue._get_rr_edus(force_flush=True)) + if edu_limit: + pending_edus.extend( + self.queue._get_receipt_edus(force_flush=True, limit=edu_limit) + ) if self._pdus: self._last_stream_ordering = self._pdus[ diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index ac01582442..6a4fed1156 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -92,7 +92,6 @@ class ReceiptsHandler: continue # Check if these receipts apply to a thread. - thread_id = None data = user_values.get("data", {}) thread_id = data.get("thread_id") # If the thread ID is invalid, consider it missing. diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index f1e357764f..01f147418b 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -83,6 +83,83 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ], ) + @override_config({"send_federation": True}) + def test_send_receipts_thread(self): + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) + mock_send_transaction.return_value = make_awaitable({}) + + # Create receipts for: + # + # * The same room / user on multiple threads. + # * A different user in the same room. + sender = self.hs.get_federation_sender() + for user, thread in ( + ("alice", None), + ("alice", "thread"), + ("bob", None), + ("bob", "diff-thread"), + ): + receipt = ReadReceipt( + "room_id", + "m.read", + user, + ["event_id"], + thread_id=thread, + data={"ts": 1234}, + ) + self.successResultOf( + defer.ensureDeferred(sender.send_read_receipt(receipt)) + ) + + self.pump() + + # expect a call to send_transaction with two EDUs to separate threads. + mock_send_transaction.assert_called_once() + json_cb = mock_send_transaction.call_args[0][1] + data = json_cb() + # Note that the ordering of the EDUs doesn't matter. + self.assertCountEqual( + data["edus"], + [ + { + "edu_type": EduTypes.RECEIPT, + "content": { + "room_id": { + "m.read": { + "alice": { + "event_ids": ["event_id"], + "data": {"ts": 1234, "thread_id": "thread"}, + }, + "bob": { + "event_ids": ["event_id"], + "data": {"ts": 1234, "thread_id": "diff-thread"}, + }, + } + } + }, + }, + { + "edu_type": EduTypes.RECEIPT, + "content": { + "room_id": { + "m.read": { + "alice": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, + }, + "bob": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, + }, + } + } + }, + }, + ], + ) + @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but -- cgit 1.5.1 From 1183c372fa9da01b2667f1b83dab958dad432c68 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 28 Nov 2022 11:17:29 -0500 Subject: Use `device_one_time_keys_count` to match MSC3202 (#14565) * Use `device_one_time_keys_count` to match MSC3202 Rename the `device_one_time_key_counts` key in responses to `device_one_time_keys_count` to match the name specified by MSC3202. Also change related variable/class names for consistency. Signed-off-by: Andrew Ferrazzutti * Update changelog.d/14565.misc * Revert name change for `one_time_key_counts` key as this is a different key altogether from `device_one_time_keys_count`, which is used for `/sync` instead of appservice transactions. Signed-off-by: Andrew Ferrazzutti --- changelog.d/14565.misc | 1 + synapse/appservice/__init__.py | 10 +++++----- synapse/appservice/api.py | 11 +++++++---- synapse/appservice/scheduler.py | 16 ++++++++-------- synapse/handlers/sync.py | 6 +++--- synapse/storage/databases/main/appservice.py | 10 +++++----- synapse/storage/databases/main/end_to_end_keys.py | 8 ++++---- tests/appservice/test_scheduler.py | 6 +++--- tests/handlers/test_appservice.py | 4 ++-- 9 files changed, 38 insertions(+), 34 deletions(-) create mode 100644 changelog.d/14565.misc (limited to 'tests') diff --git a/changelog.d/14565.misc b/changelog.d/14565.misc new file mode 100644 index 0000000000..19a62b036c --- /dev/null +++ b/changelog.d/14565.misc @@ -0,0 +1 @@ +In application service transactions that include the experimental `org.matrix.msc3202.device_one_time_key_counts` key, include a duplicate key of `org.matrix.msc3202.device_one_time_keys_count` to match the name proposed by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/blob/travis/msc/otk-dl-appservice/proposals/3202-encrypted-appservices.md). diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 500bdde3a9..bf4e6c629b 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -32,9 +32,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Type for the `device_one_time_key_counts` field in an appservice transaction +# Type for the `device_one_time_keys_count` field in an appservice transaction # user ID -> {device ID -> {algorithm -> count}} -TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] +TransactionOneTimeKeysCount = Dict[str, Dict[str, Dict[str, int]]] # Type for the `device_unused_fallback_key_types` field in an appservice transaction # user ID -> {device ID -> [algorithm]} @@ -376,7 +376,7 @@ class AppServiceTransaction: events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ): @@ -385,7 +385,7 @@ class AppServiceTransaction: self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages - self.one_time_key_counts = one_time_key_counts + self.one_time_keys_count = one_time_keys_count self.unused_fallback_keys = unused_fallback_keys self.device_list_summary = device_list_summary @@ -402,7 +402,7 @@ class AppServiceTransaction: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, - one_time_key_counts=self.one_time_key_counts, + one_time_keys_count=self.one_time_keys_count, unused_fallback_keys=self.unused_fallback_keys, device_list_summary=self.device_list_summary, txn_id=self.id, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 60774b240d..edafd433cd 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.appservice import ( ApplicationService, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.events import EventBase @@ -262,7 +262,7 @@ class ApplicationServiceApi(SimpleHttpClient): events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, txn_id: Optional[int] = None, @@ -310,10 +310,13 @@ class ApplicationServiceApi(SimpleHttpClient): # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: - if one_time_key_counts: + if one_time_keys_count: body[ "org.matrix.msc3202.device_one_time_key_counts" - ] = one_time_key_counts + ] = one_time_keys_count + body[ + "org.matrix.msc3202.device_one_time_keys_count" + ] = one_time_keys_count if unused_fallback_keys: body[ "org.matrix.msc3202.device_unused_fallback_key_types" diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 430ffbcd1f..7b562795a3 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -64,7 +64,7 @@ from typing import ( from synapse.appservice import ( ApplicationService, ApplicationServiceState, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.appservice.api import ApplicationServiceApi @@ -258,7 +258,7 @@ class _ServiceQueuer: ): return - one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None if ( @@ -269,7 +269,7 @@ class _ServiceQueuer: # for the users which are mentioned in this transaction, # as well as the appservice's sender. ( - one_time_key_counts, + one_time_keys_count, unused_fallback_keys, ) = await self._compute_msc3202_otk_counts_and_fallback_keys( service, events, ephemeral, to_device_messages_to_send @@ -281,7 +281,7 @@ class _ServiceQueuer: events, ephemeral, to_device_messages_to_send, - one_time_key_counts, + one_time_keys_count, unused_fallback_keys, device_list_summary, ) @@ -296,7 +296,7 @@ class _ServiceQueuer: events: Iterable[EventBase], ephemerals: Iterable[JsonDict], to_device_messages: Iterable[JsonDict], - ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: """ Given a list of the events, ephemeral messages and to-device messages, - first computes a list of application services users that may have @@ -367,7 +367,7 @@ class _TransactionController: events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, - one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: @@ -380,7 +380,7 @@ class _TransactionController: events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. @@ -397,7 +397,7 @@ class _TransactionController: events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], - one_time_key_counts=one_time_key_counts or {}, + one_time_keys_count=one_time_keys_count or {}, unused_fallback_keys=unused_fallback_keys or {}, device_list_summary=device_list_summary or DeviceListUpdates(), ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 259456b55d..c8858b22dd 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1426,14 +1426,14 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_key_counts: JsonDict = {} + one_time_keys_count: JsonDict = {} unused_fallback_key_types: List[str] = [] if device_id: # TODO: We should have a way to let clients differentiate between the states of: # * no change in OTK count since the provided since token # * the server has zero OTKs left for this device # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 - one_time_key_counts = await self.store.count_e2e_one_time_keys( + one_time_keys_count = await self.store.count_e2e_one_time_keys( user_id, device_id ) unused_fallback_key_types = ( @@ -1463,7 +1463,7 @@ class SyncHandler: archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, - device_one_time_keys_count=one_time_key_counts, + device_one_time_keys_count=one_time_keys_count, device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 25da0c56c5..c2c8018ee2 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,7 +20,7 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices @@ -260,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: @@ -273,7 +273,7 @@ class ApplicationServiceTransactionWorkerStore( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. @@ -299,7 +299,7 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, - one_time_key_counts=one_time_key_counts, + one_time_keys_count=one_time_keys_count, unused_fallback_keys=unused_fallback_keys, device_list_summary=device_list_summary, ) @@ -379,7 +379,7 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index cf33e73e2b..643c47d608 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -33,7 +33,7 @@ from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.logging.opentracing import log_kv, set_tag, trace @@ -514,7 +514,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def count_bulk_e2e_one_time_keys_for_as( self, user_ids: Collection[str] - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: """ Counts, in bulk, the one-time keys for all the users specified. Intended to be used by application services for populating OTK counts in @@ -528,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker def _count_bulk_e2e_one_time_keys_txn( txn: LoggingTransaction, - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: user_in_where_clause, user_parameters = make_in_list_sql_clause( self.database_engine, "user_id", user_ids ) @@ -541,7 +541,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ txn.execute(sql, user_parameters) - result: TransactionOneTimeKeyCounts = {} + result: TransactionOneTimeKeysCount = {} for user_id, device_id, algorithm, count in txn: # We deliberately construct empty dictionaries for diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 0b22afdc75..0a1ae83a2b 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -69,7 +69,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -96,7 +96,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -125,7 +125,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 144e49d0fd..9ed26d87a7 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -25,7 +25,7 @@ import synapse.storage from synapse.api.constants import EduTypes, EventTypes from synapse.appservice import ( ApplicationService, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.handlers.appservice import ApplicationServicesHandler @@ -1123,7 +1123,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): # Capture what was sent as an AS transaction. self.send_mock.assert_called() last_args, _last_kwargs = self.send_mock.call_args - otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS] + otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS] unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ self.ARG_FALLBACK_KEYS ] -- cgit 1.5.1 From 8f10c8b054fc970838be9ae6f1f5aea95f166c98 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 28 Nov 2022 15:54:18 -0600 Subject: Move MSC3030 `/timestamp_to_event` endpoint to stable v1 location (#14471) Fix https://github.com/matrix-org/synapse/issues/14390 - Client API: `/_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir=` -> `/_matrix/client/v1/rooms//timestamp_to_event?ts=&dir=` - Federation API: `/_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir=` -> `/_matrix/federation/v1/timestamp_to_event/?ts=&dir=` Complement test changes: https://github.com/matrix-org/complement/pull/559 --- changelog.d/14471.feature | 1 + docker/complement/conf/workers-shared-extra.yaml.j2 | 2 -- docker/configure_workers_and_start.py | 2 ++ docs/workers.md | 2 ++ scripts-dev/complement.sh | 6 +++--- synapse/config/experimental.py | 3 --- synapse/federation/federation_client.py | 12 +++++++++++- synapse/federation/transport/client.py | 5 ++--- synapse/federation/transport/server/__init__.py | 8 -------- synapse/federation/transport/server/federation.py | 3 +-- synapse/rest/client/room.py | 10 +++------- synapse/rest/client/versions.py | 2 -- tests/rest/client/test_rooms.py | 7 +------ 13 files changed, 26 insertions(+), 37 deletions(-) create mode 100644 changelog.d/14471.feature (limited to 'tests') diff --git a/changelog.d/14471.feature b/changelog.d/14471.feature new file mode 100644 index 0000000000..a0e0c74f1a --- /dev/null +++ b/changelog.d/14471.feature @@ -0,0 +1 @@ +Move MSC3030 `/timestamp_to_event` endpoints to stable `v1` location (`/_matrix/client/v1/rooms//timestamp_to_event?ts=&dir=`, `/_matrix/federation/v1/timestamp_to_event/?ts=&dir=`). diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 883a87159c..ca640c343b 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -100,8 +100,6 @@ experimental_features: # client-side support for partial state in /send_join responses faster_joins: true {% endif %} - # Enable jump to date endpoint - msc3030_enabled: true # Filtering /messages by relation type. msc3874_enabled: true diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index c1e1544536..58c62f2231 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -140,6 +140,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event", "^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms", "^/_matrix/client/(api/v1|r0|v3|unstable/.*)/rooms/.*/aliases", + "^/_matrix/client/v1/rooms/.*/timestamp_to_event$", "^/_matrix/client/(api/v1|r0|v3|unstable)/search", ], "shared_extra_conf": {}, @@ -163,6 +164,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/federation/(v1|v2)/invite/", "^/_matrix/federation/(v1|v2)/query_auth/", "^/_matrix/federation/(v1|v2)/event_auth/", + "^/_matrix/federation/v1/timestamp_to_event/", "^/_matrix/federation/(v1|v2)/exchange_third_party_invite/", "^/_matrix/federation/(v1|v2)/user/devices/", "^/_matrix/federation/(v1|v2)/get_groups_publicised$", diff --git a/docs/workers.md b/docs/workers.md index 27e54c5846..2b65acb5ed 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -191,6 +191,7 @@ information. ^/_matrix/federation/(v1|v2)/send_leave/ ^/_matrix/federation/(v1|v2)/invite/ ^/_matrix/federation/v1/event_auth/ + ^/_matrix/federation/v1/timestamp_to_event/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ ^/_matrix/key/v2/query @@ -218,6 +219,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ + ^/_matrix/client/v1/rooms/.*/timestamp_to_event$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ # Encryption requests diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 803c6ce92d..7744b47097 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -162,9 +162,9 @@ else # We only test faster room joins on monoliths, because they are purposefully # being developed without worker support to start with. # - # The tests for importing historical messages (MSC2716) and jump to date (MSC3030) - # also only pass with monoliths, currently. - test_tags="$test_tags,faster_joins,msc2716,msc3030" + # The tests for importing historical messages (MSC2716) also only pass with monoliths, + # currently. + test_tags="$test_tags,faster_joins,msc2716" fi diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d4b71d1673..a503abf364 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -53,9 +53,6 @@ class ExperimentalConfig(Config): # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) - # MSC3030 (Jump to date API endpoint) - self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) - # MSC2409 (this setting only relates to optionally sending to-device messages). # Presence, typing and read receipt EDUs are already sent to application services that # have opted in to receive them. If enabled, this adds to-device messages to that list. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c4c0bc7315..8bccc9c60d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1691,9 +1691,19 @@ class FederationClient(FederationBase): # to return events on *both* sides of the timestamp to # help reconcile the gap faster. _timestamp_to_event_from_destination, + # Since this endpoint is new, we should try other servers before giving up. + # We can safely remove this in a year (remove after 2023-11-16). + failover_on_unknown_endpoint=True, ) return timestamp_to_event_response - except SynapseError: + except SynapseError as e: + logger.warn( + "timestamp_to_event(room_id=%s, timestamp=%s, direction=%s): encountered error when trying to fetch from destinations: %s", + room_id, + timestamp, + direction, + e, + ) return None async def _timestamp_to_event_from_destination( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index a3cfc701cd..77f1f39cac 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -185,9 +185,8 @@ class TransportLayerClient: Raises: Various exceptions when the request fails """ - path = _create_path( - FEDERATION_UNSTABLE_PREFIX, - "/org.matrix.msc3030/timestamp_to_event/%s", + path = _create_v1_path( + "/timestamp_to_event/%s", room_id, ) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 50623cd385..2725f53cf6 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -25,7 +25,6 @@ from synapse.federation.transport.server._base import ( from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, FederationAccountStatusServlet, - FederationTimestampLookupServlet, ) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( @@ -291,13 +290,6 @@ def register_servlets( ) for servletclass in SERVLET_GROUPS[servlet_group]: - # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled - if ( - servletclass == FederationTimestampLookupServlet - and not hs.config.experimental.msc3030_enabled - ): - continue - # Only allow the `/account_status` servlet if msc3720 is enabled if ( servletclass == FederationAccountStatusServlet diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 205fd16daa..53e77b4bb6 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -218,14 +218,13 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet): `dir` can be `f` or `b` to indicate forwards and backwards in time from the given timestamp. - GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir= + GET /_matrix/federation/v1/timestamp_to_event/?ts=&dir= { "event_id": ... } """ PATH = "/timestamp_to_event/(?P[^/]*)/?" - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030" async def on_GET( self, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 91cb791139..636cc62877 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1284,17 +1284,14 @@ class TimestampLookupRestServlet(RestServlet): `dir` can be `f` or `b` to indicate forwards and backwards in time from the given timestamp. - GET /_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir= + GET /_matrix/client/v1/rooms//timestamp_to_event?ts=&dir= { "event_id": ... } """ PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc3030" - "/rooms/(?P[^/]*)/timestamp_to_event$" - ), + re.compile("^/_matrix/client/v1/rooms/(?P[^/]*)/timestamp_to_event$"), ) def __init__(self, hs: "HomeServer"): @@ -1421,8 +1418,7 @@ def register_servlets( RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) - if hs.config.experimental.msc3030_enabled: - TimestampLookupRestServlet(hs).register(http_server) + TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if not is_worker: diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 180a11ef88..3c0a90010b 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,8 +101,6 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3827.stable": True, # Adds support for importing historical messages as per MSC2716 "org.matrix.msc2716": self.config.experimental.msc2716_enabled, - # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 - "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Support for thread read receipts & notification counts. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e919e089cb..b4daace556 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -3546,11 +3546,6 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self) -> JsonDict: - config = super().default_config() - config["experimental_features"] = {"msc3030_enabled": True} - return config - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._storage_controllers = self.hs.get_storage_controllers() @@ -3592,7 +3587,7 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", + f"/_matrix/client/v1/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", access_token=self.room_owner_tok, ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) -- cgit 1.5.1 From 3da645032722fbf09c1e5efbc51d8c5c78d8a2cd Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 28 Nov 2022 16:29:53 -0700 Subject: Initial support for MSC3931: Room version push rule feature flags (#14520) * Add support for MSC3931: Room Version Supports push rule condition * Create experimental flag for future work, and use it to gate MSC3931 * Changelog entry --- changelog.d/14520.feature | 1 + rust/src/push/evaluator.rs | 26 ++++++++++++++++++++++++++ rust/src/push/mod.rs | 16 ++++++++++++++++ stubs/synapse/synapse_rust/push.pyi | 2 ++ synapse/api/room_versions.py | 21 ++++++++++++++++++++- synapse/config/experimental.py | 3 +++ synapse/push/bulk_push_rule_evaluator.py | 6 ++++++ tests/push/test_push_rule_evaluator.py | 2 ++ 8 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14520.feature (limited to 'tests') diff --git a/changelog.d/14520.feature b/changelog.d/14520.feature new file mode 100644 index 0000000000..210acaa8ee --- /dev/null +++ b/changelog.d/14520.feature @@ -0,0 +1 @@ +Add unstable support for an Extensible Events room version (`org.matrix.msc1767.10`) via [MSC1767](https://github.com/matrix-org/matrix-spec-proposals/pull/1767), [MSC3931](https://github.com/matrix-org/matrix-spec-proposals/pull/3931), [MSC3932](https://github.com/matrix-org/matrix-spec-proposals/pull/3932), and [MSC3933](https://github.com/matrix-org/matrix-spec-proposals/pull/3933). \ No newline at end of file diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index cedd42c54d..e8e3d604ee 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -29,6 +29,10 @@ use super::{ lazy_static! { /// Used to parse the `is` clause in the room member count condition. static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex"); + + /// Used to determine which MSC3931 room version feature flags are actually known to + /// the push evaluator. + static ref KNOWN_RVER_FLAGS: Vec = vec![]; } /// Allows running a set of push rules against a particular event. @@ -57,6 +61,13 @@ pub struct PushRuleEvaluator { /// If msc3664, push rules for related events, is enabled. related_event_match_enabled: bool, + + /// If MSC3931 is applicable, the feature flags for the room version. + room_version_feature_flags: Vec, + + /// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same + /// flag as MSC1767 (extensible events core). + msc3931_enabled: bool, } #[pymethods] @@ -70,6 +81,8 @@ impl PushRuleEvaluator { notification_power_levels: BTreeMap, related_events_flattened: BTreeMap>, related_event_match_enabled: bool, + room_version_feature_flags: Vec, + msc3931_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -84,6 +97,8 @@ impl PushRuleEvaluator { sender_power_level, related_events_flattened, related_event_match_enabled, + room_version_feature_flags, + msc3931_enabled, }) } @@ -204,6 +219,15 @@ impl PushRuleEvaluator { false } } + KnownCondition::RoomVersionSupports { feature } => { + if !self.msc3931_enabled { + false + } else { + let flag = feature.to_string(); + KNOWN_RVER_FLAGS.contains(&flag) + && self.room_version_feature_flags.contains(&flag) + } + } }; Ok(result) @@ -362,6 +386,8 @@ fn push_rule_evaluator() { BTreeMap::new(), BTreeMap::new(), true, + vec![], + true, ) .unwrap(); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index d57800aa4a..eef39f6472 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -277,6 +277,10 @@ pub enum KnownCondition { SenderNotificationPermission { key: Cow<'static, str>, }, + #[serde(rename = "org.matrix.msc3931.room_version_supports")] + RoomVersionSupports { + feature: Cow<'static, str>, + }, } impl IntoPy for Condition { @@ -491,6 +495,18 @@ fn test_deserialize_unstable_msc3664_condition() { )); } +#[test] +fn test_deserialize_unstable_msc3931_condition() { + let json = + r#"{"kind":"org.matrix.msc3931.room_version_supports","feature":"org.example.feature"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index ceade65ef9..cbeb49663c 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -41,6 +41,8 @@ class PushRuleEvaluator: notification_power_levels: Mapping[str, int], related_events_flattened: Mapping[str, Mapping[str, str]], related_event_match_enabled: bool, + room_version_feature_flags: list[str], + msc3931_enabled: bool, ): ... def run( self, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index e37acb0f1e..1bd1ef3e2b 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional +from typing import Callable, Dict, List, Optional import attr @@ -91,6 +91,12 @@ class RoomVersion: msc3787_knock_restricted_join_rule: bool # MSC3667: Enforce integer power levels msc3667_int_only_power_levels: bool + # MSC3931: Adds a push rule condition for "room version feature flags", making + # some push rules room version dependent. Note that adding a flag to this list + # is not enough to mark it "supported": the push rule evaluator also needs to + # support the flag. Unknown flags are ignored by the evaluator, making conditions + # fail if used. + msc3931_push_features: List[str] class RoomVersions: @@ -111,6 +117,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V2 = RoomVersion( "2", @@ -129,6 +136,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V3 = RoomVersion( "3", @@ -147,6 +155,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V4 = RoomVersion( "4", @@ -165,6 +174,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V5 = RoomVersion( "5", @@ -183,6 +193,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V6 = RoomVersion( "6", @@ -201,6 +212,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -219,6 +231,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V7 = RoomVersion( "7", @@ -237,6 +250,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V8 = RoomVersion( "8", @@ -255,6 +269,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V9 = RoomVersion( "9", @@ -273,6 +288,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) MSC3787 = RoomVersion( "org.matrix.msc3787", @@ -291,6 +307,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V10 = RoomVersion( "10", @@ -309,6 +326,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=True, + msc3931_push_features=[], ) MSC2716v4 = RoomVersion( "org.matrix.msc2716v4", @@ -327,6 +345,7 @@ class RoomVersions: msc2716_redactions=True, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index a503abf364..b3f51fc57d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): # MSC3912: Relation-based redactions. self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) + + # MSC1767 and friends: Extensible Events + self.msc1767_enabled: bool = experimental.get("msc1767_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 75b7e126ca..9cc3da6d91 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -338,6 +338,10 @@ class BulkPushRuleEvaluator: for user_id, level in notification_levels.items(): notification_levels[user_id] = int(level) + room_version_features = event.room_version.msc3931_push_features + if not room_version_features: + room_version_features = [] + evaluator = PushRuleEvaluator( _flatten_dict(event), room_member_count, @@ -345,6 +349,8 @@ class BulkPushRuleEvaluator: notification_levels, related_events, self._related_event_match_enabled, + room_version_features, + self.hs.config.experimental.msc1767_enabled, # MSC3931 flag ) users = rules_by_user.keys() diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index fe7c145840..5ababe6a39 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -62,6 +62,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): power_levels.get("notifications", {}), {} if related_events is None else related_events, True, + event.room_version.msc3931_push_features, + True, ) def test_display_name(self) -> None: -- cgit 1.5.1 From c7e29ca277cf60bfdc488b93f4321b046fa6b46f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Nov 2022 10:36:41 +0000 Subject: POC delete stale non-e2e devices for users (#14038) This should help reduce the number of devices e.g. simple bots the repeatedly login rack up. We only delete non-e2e devices as they should be safe to delete, whereas if we delete e2e devices for a user we may accidentally break their ability to receive e2e keys for a message. Co-authored-by: Patrick Cloke Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14038.misc | 1 + synapse/handlers/device.py | 13 +++++- synapse/storage/databases/main/devices.py | 67 ++++++++++++++++++++++++++++++- tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 5 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14038.misc (limited to 'tests') diff --git a/changelog.d/14038.misc b/changelog.d/14038.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/14038.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b1e55e1b9e..7c4dd8cf5a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -421,6 +421,9 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) + # Prune the user's device list if they already have a lot of devices. + await self._prune_too_many_devices(user_id) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -452,6 +455,14 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") + async def _prune_too_many_devices(self, user_id: str) -> None: + """Delete any excess old devices this user may have.""" + device_ids = await self.store.check_too_many_devices_for_user(user_id) + if not device_ids: + return + + await self.delete_devices(user_id, device_ids) + async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. @@ -481,7 +492,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 534f7fc04a..1e83c62753 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1533,6 +1533,70 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def check_too_many_devices_for_user(self, user_id: str) -> Collection[str]: + """Check if the user has a lot of devices, and if so return the set of + devices we can prune. + + This does *not* return hidden devices or devices with E2E keys. + """ + + num_devices = await self.db_pool.simple_select_one_onecol( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcol="COALESCE(COUNT(*), 0)", + desc="count_devices", + ) + + # We let users have up to ten devices without pruning. + if num_devices <= 10: + return () + + # We prune everything older than N days. + max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 + + if num_devices > 50: + # If the user has more than 50 devices, then we chose a last seen + # that ensures we keep at most 50 devices. + sql = """ + SELECT last_seen FROM devices + WHERE + user_id = ? + AND NOT hidden + AND last_seen IS NOT NULL + AND key_json IS NULL + ORDER BY last_seen DESC + LIMIT 1 + OFFSET 50 + """ + + rows = await self.db_pool.execute( + "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) + ) + if rows: + max_last_seen = max(rows[0][0], max_last_seen) + + # Now fetch the devices to delete. + sql = """ + SELECT DISTINCT device_id FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen < ? + AND key_json IS NULL + """ + + def check_too_many_devices_for_user_txn( + txn: LoggingTransaction, + ) -> Collection[str]: + txn.execute(sql, (user_id, max_last_seen)) + return {device_id for device_id, in txn} + + return await self.db_pool.runInteraction( + "check_too_many_devices_for_user", + check_too_many_devices_for_user_txn, + ) + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1591,6 +1655,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, + "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1636,7 +1701,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Deletes several devices. Args: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ce7525e29c..a456bffd63 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": None, + "last_seen_ts": 1000000, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 49ad3c1324..a9af1babed 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,6 +169,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) + last_seen = self.clock.time_msec() + if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -189,7 +191,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": None, + "last_seen": last_seen, }, ], ) -- cgit 1.5.1 From c29e2c630624beb0b5557aa0f7ccdcedbe62def1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 29 Nov 2022 17:48:48 +0000 Subject: Revert "POC delete stale non-e2e devices for users (#14038)" (#14582) --- changelog.d/14582.bugfix | 1 + synapse/handlers/device.py | 13 +----- synapse/storage/databases/main/devices.py | 68 +------------------------------ tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 5 files changed, 5 insertions(+), 83 deletions(-) create mode 100644 changelog.d/14582.bugfix (limited to 'tests') diff --git a/changelog.d/14582.bugfix b/changelog.d/14582.bugfix new file mode 100644 index 0000000000..caad468e70 --- /dev/null +++ b/changelog.d/14582.bugfix @@ -0,0 +1 @@ +Fix a regression in Synapse 1.73.0rc1 where Synapse's main process would stop responding to HTTP requests when a user with a large number of devices logs in. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7c4dd8cf5a..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -421,9 +421,6 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) - # Prune the user's device list if they already have a lot of devices. - await self._prune_too_many_devices(user_id) - if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -455,14 +452,6 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") - async def _prune_too_many_devices(self, user_id: str) -> None: - """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id) - if not device_ids: - return - - await self.delete_devices(user_id, device_ids) - async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. @@ -492,7 +481,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 0378035cff..534f7fc04a 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1533,71 +1533,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user(self, user_id: str) -> Collection[str]: - """Check if the user has a lot of devices, and if so return the set of - devices we can prune. - - This does *not* return hidden devices or devices with E2E keys. - """ - - num_devices = await self.db_pool.simple_select_one_onecol( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcol="COALESCE(COUNT(*), 0)", - desc="count_devices", - ) - - # We let users have up to ten devices without pruning. - if num_devices <= 10: - return () - - # We prune everything older than N days. - max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 - - if num_devices > 50: - # If the user has more than 50 devices, then we chose a last seen - # that ensures we keep at most 50 devices. - sql = """ - SELECT last_seen FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen IS NOT NULL - AND key_json IS NULL - ORDER BY last_seen DESC - LIMIT 1 - OFFSET 50 - """ - - rows = await self.db_pool.execute( - "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) - ) - if rows: - max_last_seen = max(rows[0][0], max_last_seen) - - # Now fetch the devices to delete. - sql = """ - SELECT DISTINCT device_id FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen < ? - AND key_json IS NULL - """ - - def check_too_many_devices_for_user_txn( - txn: LoggingTransaction, - ) -> Collection[str]: - txn.execute(sql, (user_id, max_last_seen)) - return {device_id for device_id, in txn} - - return await self.db_pool.runInteraction( - "check_too_many_devices_for_user", - check_too_many_devices_for_user_txn, - ) - class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1656,7 +1591,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, - "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1702,7 +1636,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a456bffd63..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": 1000000, + "last_seen_ts": None, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index a9af1babed..49ad3c1324 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,8 +169,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) - last_seen = self.clock.time_msec() - if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -191,7 +189,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": last_seen, + "last_seen": None, }, ], ) -- cgit 1.5.1 From ecb6fe9d9cf8375b760eb727be0e1dec3612e026 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 30 Nov 2022 11:59:57 +0000 Subject: Stop using deprecated `keyIds` param on /key/v2/server (#14525) Fixes #14523. --- changelog.d/14490.feature | 1 + changelog.d/14490.misc | 1 - changelog.d/14525.feature | 1 + synapse/crypto/keyring.py | 107 +++++++++++--------------- tests/crypto/test_keyring.py | 14 +--- tests/rest/key/v2/test_remote_key_resource.py | 5 +- 6 files changed, 47 insertions(+), 82 deletions(-) create mode 100644 changelog.d/14490.feature delete mode 100644 changelog.d/14490.misc create mode 100644 changelog.d/14525.feature (limited to 'tests') diff --git a/changelog.d/14490.feature b/changelog.d/14490.feature new file mode 100644 index 0000000000..c7cb571294 --- /dev/null +++ b/changelog.d/14490.feature @@ -0,0 +1 @@ +Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`. diff --git a/changelog.d/14490.misc b/changelog.d/14490.misc deleted file mode 100644 index c0a4daa885..0000000000 --- a/changelog.d/14490.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 0.9 where it would fail to fetch server keys whose IDs contain a forward slash. diff --git a/changelog.d/14525.feature b/changelog.d/14525.feature new file mode 100644 index 0000000000..c7cb571294 --- /dev/null +++ b/changelog.d/14525.feature @@ -0,0 +1 @@ +Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index ed15f88350..69310d9035 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -14,7 +14,6 @@ import abc import logging -import urllib from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple import attr @@ -813,31 +812,27 @@ class ServerKeyFetcher(BaseV2KeyFetcher): results = {} - async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None: + async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None: server_name = key_to_fetch_item.server_name - key_ids = key_to_fetch_item.key_ids try: - keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) + keys = await self.get_server_verify_keys_v2_direct(server_name) results[server_name] = keys except KeyLookupError as e: - logger.warning( - "Error looking up keys %s from %s: %s", key_ids, server_name, e - ) + logger.warning("Error looking up keys from %s: %s", server_name, e) except Exception: - logger.exception("Error getting keys %s from %s", key_ids, server_name) + logger.exception("Error getting keys from %s", server_name) - await yieldable_gather_results(get_key, keys_to_fetch) + await yieldable_gather_results(get_keys, keys_to_fetch) return results - async def get_server_verify_key_v2_direct( - self, server_name: str, key_ids: Iterable[str] + async def get_server_verify_keys_v2_direct( + self, server_name: str ) -> Dict[str, FetchKeyResult]: """ Args: - server_name: - key_ids: + server_name: Server to request keys from Returns: Map from key ID to lookup result @@ -845,57 +840,41 @@ class ServerKeyFetcher(BaseV2KeyFetcher): Raises: KeyLookupError if there was a problem making the lookup """ - keys: Dict[str, FetchKeyResult] = {} - - for requested_key_id in key_ids: - # we may have found this key as a side-effect of asking for another. - if requested_key_id in keys: - continue - - time_now_ms = self.clock.time_msec() - try: - response = await self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id, safe=""), - ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an - # easy request to handle, so if it doesn't reply within 10s, it's - # probably not going to. - # - # Furthermore, when we are acting as a notary server, we cannot - # wait all day for all of the origin servers, as the requesting - # server will otherwise time out before we can respond. - # - # (Note that get_json may make 4 attempts, so this can still take - # almost 45 seconds to fetch the headers, plus up to another 60s to - # read the response). - timeout=10000, - ) - except (NotRetryingDestination, RequestSendFailed) as e: - # these both have str() representations which we can't really improve - # upon - raise KeyLookupError(str(e)) - except HttpResponseException as e: - raise KeyLookupError("Remote server returned an error: %s" % (e,)) - - assert isinstance(response, dict) - if response["server_name"] != server_name: - raise KeyLookupError( - "Expected a response for server %r not %r" - % (server_name, response["server_name"]) - ) - - response_keys = await self.process_v2_response( - from_server=server_name, - response_json=response, - time_added_ms=time_now_ms, + time_now_ms = self.clock.time_msec() + try: + response = await self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server", + ignore_backoff=True, + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). + timeout=10000, ) - await self.store.store_server_verify_keys( - server_name, - time_now_ms, - ((server_name, key_id, key) for key_id, key in response_keys.items()), + except (NotRetryingDestination, RequestSendFailed) as e: + # these both have str() representations which we can't really improve + # upon + raise KeyLookupError(str(e)) + except HttpResponseException as e: + raise KeyLookupError("Remote server returned an error: %s" % (e,)) + + assert isinstance(response, dict) + if response["server_name"] != server_name: + raise KeyLookupError( + "Expected a response for server %r not %r" + % (server_name, response["server_name"]) ) - keys.update(response_keys) - return keys + return await self.process_v2_response( + from_server=server_name, + response_json=response, + time_added_ms=time_now_ms, + ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 63628aa6b0..f7c309cad0 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -433,7 +433,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) - self.assertEqual(path, "/_matrix/key/v2/server/key1") + self.assertEqual(path, "/_matrix/key/v2/server") return response self.http_client.get_json.side_effect = get_json @@ -469,18 +469,6 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) - def test_keyid_containing_forward_slash(self) -> None: - """We should url-encode any url unsafe chars in key ids. - - Detects https://github.com/matrix-org/synapse/issues/14488. - """ - fetcher = ServerKeyFetcher(self.hs) - self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0)) - - self.http_client.get_json.assert_called_once() - args, kwargs = self.http_client.get_json.call_args - self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato") - class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 7f1fba1086..2bb6e27d94 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import urllib.parse from io import BytesIO, StringIO from typing import Any, Dict, Optional, Union from unittest.mock import Mock @@ -65,9 +64,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) - self.assertEqual( - path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),) - ) + self.assertEqual(path, "/_matrix/key/v2/server") response = { "server_name": server_name, -- cgit 1.5.1 From 4569eda94423a10abb69e0f4d5f37eb723ed764b Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Wed, 30 Nov 2022 13:39:47 +0100 Subject: Use servers list approx to send read receipts when in partial state (#14549) Signed-off-by: Mathieu Velten --- changelog.d/14549.misc | 1 + synapse/federation/sender/__init__.py | 2 +- tests/federation/test_federation_sender.py | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14549.misc (limited to 'tests') diff --git a/changelog.d/14549.misc b/changelog.d/14549.misc new file mode 100644 index 0000000000..d9d863dd20 --- /dev/null +++ b/changelog.d/14549.misc @@ -0,0 +1 @@ +Faster joins: use servers list approximation to send read receipts when in partial state instead of waiting for the full state of the room. \ No newline at end of file diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index fc1d8c88a7..30ebd62883 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -647,7 +647,7 @@ class FederationSender(AbstractFederationSender): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains_set = await self._storage_controllers.state.get_current_hosts_in_room( + domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) domains = [ diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 01f147418b..cbc99d30b9 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -38,6 +38,10 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): return_value=make_awaitable({"test", "host2"}) ) + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( + hs.get_storage_controllers().state.get_current_hosts_in_room + ) + return hs @override_config({"send_federation": True}) -- cgit 1.5.1 From e8bce8999f21d30affc459755e304a1f4732165c Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 30 Nov 2022 13:45:06 +0000 Subject: Aggregate unread notif count query for badge count calculation (#14255) Fetch the unread notification counts used by the badge counts in push notifications for all rooms at once (instead of fetching them per room). --- changelog.d/14255.misc | 1 + synapse/push/push_tools.py | 28 ++-- .../storage/databases/main/event_push_actions.py | 149 +++++++++++++++++++++ tests/storage/test_event_push_actions.py | 47 +++++-- 4 files changed, 198 insertions(+), 27 deletions(-) create mode 100644 changelog.d/14255.misc (limited to 'tests') diff --git a/changelog.d/14255.misc b/changelog.d/14255.misc new file mode 100644 index 0000000000..39924659c7 --- /dev/null +++ b/changelog.d/14255.misc @@ -0,0 +1 @@ +Optimise push badge count calculations. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index edeba27a45..7ee07e4bee 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,7 +17,6 @@ from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.util.async_helpers import concurrently_execute async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: @@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge = len(invites) - room_notifs = [] - - async def get_room_unread_count(room_id: str) -> None: - room_notifs.append( - await store.get_unread_event_push_actions_by_room_for_user( - room_id, - user_id, - ) - ) - - await concurrently_execute(get_room_unread_count, joins, 10) - - for notifs in room_notifs: - # Combine the counts from all the threads. - notify_count = notifs.main_timeline.notify_count + sum( - n.notify_count for n in notifs.threads.values() - ) + room_to_count = await store.get_unread_counts_by_room_for_user(user_id) + for room_id, notify_count in room_to_count.items(): + # room_to_count may include rooms which the user has left, + # ignore those. + if room_id not in joins: + continue if notify_count == 0: continue @@ -51,8 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - # return one badge count per conversation badge += 1 else: - # increment the badge count by the number of unread messages in the room + # Increase badge by number of notifications in room + # NOTE: this includes threaded and unthreaded notifications. badge += notify_count + return badge diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b283ab0f9c..7ebe34f773 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -74,6 +74,7 @@ receipt. """ import logging +from collections import defaultdict from typing import ( TYPE_CHECKING, Collection, @@ -95,6 +96,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + PostgresEngine, ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore @@ -463,6 +465,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return result + async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]: + """Get the notification count by room for a user. Only considers notifications, + not highlight or unread counts, and threads are currently aggregated under their room. + + This function is intentionally not cached because it is called to calculate the + unread badge for push notifications and thus the result is expected to change. + + Note that this function assumes the user is a member of the room. Because + summary rows are not removed when a user leaves a room, the caller must + filter out those results from the result. + + Returns: + A map of room ID to notification counts for the given user. + """ + return await self.db_pool.runInteraction( + "get_unread_counts_by_room_for_user", + self._get_unread_counts_by_room_for_user_txn, + user_id, + ) + + def _get_unread_counts_by_room_for_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Dict[str, int]: + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + ) + args.extend([user_id, user_id]) + + receipts_cte = f""" + WITH all_receipts AS ( + SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + {receipt_types_clause} + AND user_id = ? + GROUP BY room_id, thread_id + ) + """ + + receipts_joins = """ + LEFT JOIN ( + SELECT room_id, thread_id, + max_receipt_stream_ordering AS threaded_receipt_stream_ordering + FROM all_receipts + WHERE thread_id IS NOT NULL + ) AS threaded_receipts USING (room_id, thread_id) + LEFT JOIN ( + SELECT room_id, thread_id, + max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering + FROM all_receipts + WHERE thread_id IS NULL + ) AS unthreaded_receipts USING (room_id) + """ + + # First get summary counts by room / thread for the user. We use the max receipt + # stream ordering of both threaded & unthreaded receipts to compare against the + # summary table. + # + # PostgreSQL and SQLite differ in comparing scalar numerics. + if isinstance(self.database_engine, PostgresEngine): + # GREATEST ignores NULLs. + max_clause = """GREATEST( + threaded_receipt_stream_ordering, + unthreaded_receipt_stream_ordering + )""" + else: + # MAX returns NULL if any are NULL, so COALESCE to 0 first. + max_clause = """MAX( + COALESCE(threaded_receipt_stream_ordering, 0), + COALESCE(unthreaded_receipt_stream_ordering, 0) + )""" + + sql = f""" + {receipts_cte} + SELECT eps.room_id, eps.thread_id, notif_count + FROM event_push_summary AS eps + {receipts_joins} + WHERE user_id = ? + AND notif_count != 0 + AND ( + (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause}) + OR last_receipt_stream_ordering = {max_clause} + ) + """ + txn.execute(sql, args) + + seen_thread_ids = set() + room_to_count: Dict[str, int] = defaultdict(int) + + for room_id, thread_id, notif_count in txn: + room_to_count[room_id] += notif_count + seen_thread_ids.add(thread_id) + + # Now get any event push actions that haven't been rotated using the same OR + # join and filter by receipt and event push summary rotated up to stream ordering. + sql = f""" + {receipts_cte} + SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count + FROM event_push_actions AS epa + {receipts_joins} + WHERE user_id = ? + AND epa.notif = 1 + AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering) + AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) + AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) + GROUP BY epa.room_id, epa.thread_id + """ + txn.execute(sql, args) + + for room_id, thread_id, notif_count in txn: + # Note: only count push actions we have valid summaries for with up to date receipt. + if thread_id not in seen_thread_ids: + continue + room_to_count[room_id] += notif_count + + thread_id_clause, thread_ids_args = make_in_list_sql_clause( + self.database_engine, "epa.thread_id", seen_thread_ids + ) + + # Finally re-check event_push_actions for any rooms not in the summary, ignoring + # the rotated up-to position. This handles the case where a read receipt has arrived + # but not been rotated meaning the summary table is out of date, so we go back to + # the push actions table. + sql = f""" + {receipts_cte} + SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count + FROM event_push_actions AS epa + {receipts_joins} + WHERE user_id = ? + AND NOT {thread_id_clause} + AND epa.notif = 1 + AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) + AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) + GROUP BY epa.room_id + """ + + args.extend(thread_ids_args) + txn.execute(sql, args) + + for room_id, notif_count in txn: + room_to_count[room_id] += notif_count + + return room_to_count + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index ee48920f84..5fa8bd2d98 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -156,7 +156,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): last_event_id: str - def _assert_counts(noitf_count: int, highlight_count: int) -> None: + def _assert_counts(notif_count: int, highlight_count: int) -> None: counts = self.get_success( self.store.db_pool.runInteraction( "get-unread-counts", @@ -168,13 +168,22 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual( counts.main_timeline, NotifCounts( - notify_count=noitf_count, + notify_count=notif_count, unread_count=0, highlight_count=highlight_count, ), ) self.assertEqual(counts.threads, {}) + aggregate_counts = self.get_success( + self.store.db_pool.runInteraction( + "get-aggregate-unread-counts", + self.store._get_unread_counts_by_room_for_user_txn, + user_id, + ) + ) + self.assertEqual(aggregate_counts[room_id], notif_count) + def _create_event(highlight: bool = False) -> str: result = self.helper.send_event( room_id, @@ -283,7 +292,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): last_event_id: str def _assert_counts( - noitf_count: int, + notif_count: int, highlight_count: int, thread_notif_count: int, thread_highlight_count: int, @@ -299,7 +308,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual( counts.main_timeline, NotifCounts( - notify_count=noitf_count, + notify_count=notif_count, unread_count=0, highlight_count=highlight_count, ), @@ -318,6 +327,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): else: self.assertEqual(counts.threads, {}) + aggregate_counts = self.get_success( + self.store.db_pool.runInteraction( + "get-aggregate-unread-counts", + self.store._get_unread_counts_by_room_for_user_txn, + user_id, + ) + ) + self.assertEqual( + aggregate_counts[room_id], notif_count + thread_notif_count + ) + def _create_event( highlight: bool = False, thread_id: Optional[str] = None ) -> str: @@ -454,7 +474,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): last_event_id: str def _assert_counts( - noitf_count: int, + notif_count: int, highlight_count: int, thread_notif_count: int, thread_highlight_count: int, @@ -470,7 +490,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual( counts.main_timeline, NotifCounts( - notify_count=noitf_count, + notify_count=notif_count, unread_count=0, highlight_count=highlight_count, ), @@ -489,6 +509,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): else: self.assertEqual(counts.threads, {}) + aggregate_counts = self.get_success( + self.store.db_pool.runInteraction( + "get-aggregate-unread-counts", + self.store._get_unread_counts_by_room_for_user_txn, + user_id, + ) + ) + self.assertEqual( + aggregate_counts[room_id], notif_count + thread_notif_count + ) + def _create_event( highlight: bool = False, thread_id: Optional[str] = None ) -> str: @@ -646,7 +677,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) return result["event_id"] - def _assert_counts(noitf_count: int, thread_notif_count: int) -> None: + def _assert_counts(notif_count: int, thread_notif_count: int) -> None: counts = self.get_success( self.store.db_pool.runInteraction( "get-unread-counts", @@ -658,7 +689,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual( counts.main_timeline, NotifCounts( - notify_count=noitf_count, unread_count=0, highlight_count=0 + notify_count=notif_count, unread_count=0, highlight_count=0 ), ) if thread_notif_count: -- cgit 1.5.1 From 854a6884d81c95297bf93badcddc00a4cab93418 Mon Sep 17 00:00:00 2001 From: realtyem Date: Thu, 1 Dec 2022 06:38:27 -0600 Subject: Modernize unit tests configuration settings for workers. (#14568) Use the newer foo_instances configuration instead of the deprecated flags to enable specific features (e.g. start_pushers). --- changelog.d/14568.misc | 1 + tests/events/test_presence_router.py | 16 ++++-- tests/federation/test_federation_catch_up.py | 21 +++++--- tests/federation/test_federation_sender.py | 27 +++++++++-- tests/handlers/test_presence.py | 3 +- tests/handlers/test_typing.py | 6 ++- tests/handlers/test_user_directory.py | 7 ++- tests/module_api/test_api.py | 3 +- tests/push/test_email.py | 1 - tests/push/test_http.py | 7 +-- tests/replication/_base.py | 2 +- tests/replication/tcp/streams/test_federation.py | 5 +- tests/replication/test_auth.py | 4 +- tests/replication/test_client_reader_shard.py | 14 +++--- tests/replication/test_federation_ack.py | 5 +- tests/replication/test_federation_sender_shard.py | 59 ++++++++++++++--------- tests/replication/test_pusher_shard.py | 15 ++---- tests/utils.py | 8 +-- 18 files changed, 123 insertions(+), 81 deletions(-) create mode 100644 changelog.d/14568.misc (limited to 'tests') diff --git a/changelog.d/14568.misc b/changelog.d/14568.misc new file mode 100644 index 0000000000..99973de1c1 --- /dev/null +++ b/changelog.d/14568.misc @@ -0,0 +1 @@ +Modernize unit tests configuration related to workers. diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 685a9a6d52..b703e4472e 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -126,6 +126,13 @@ class PresenceRouterTestModule: class PresenceRouterTestCase(FederatingHomeserverTestCase): + """ + Test cases using a custom PresenceRouter + + By default in test cases, federation sending is disabled. This class re-enables it + for the main process by setting `federation_sender_instances` to None. + """ + servlets = [ admin.register_servlets, login.register_servlets, @@ -150,6 +157,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): self.sync_handler = self.hs.get_sync_handler() self.module_api = homeserver.get_module_api() + def default_config(self) -> JsonDict: + config = super().default_config() + config["federation_sender_instances"] = None + return config + @override_config( { "presence": { @@ -162,7 +174,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } }, - "send_federation": True, } ) def test_receiving_all_presence_legacy(self): @@ -180,7 +191,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, }, ], - "send_federation": True, } ) def test_receiving_all_presence(self): @@ -290,7 +300,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } }, - "send_federation": True, } ) def test_send_local_online_presence_to_with_module_legacy(self): @@ -310,7 +319,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, }, ], - "send_federation": True, } ) def test_send_local_online_presence_to_with_module(self): diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 2873b4d430..b8fee72898 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -7,13 +7,21 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin from synapse.rest.client import login, room +from synapse.types import JsonDict from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable -from tests.unittest import FederatingHomeserverTestCase, override_config +from tests.unittest import FederatingHomeserverTestCase class FederationCatchUpTestCases(FederatingHomeserverTestCase): + """ + Tests cases of catching up over federation. + + By default for test cases federation sending is disabled. This Test class has it + re-enabled for the main process. + """ + servlets = [ admin.register_servlets, room.register_servlets, @@ -42,6 +50,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.record_transaction ) + def default_config(self) -> JsonDict: + config = super().default_config() + config["federation_sender_instances"] = None + return config + async def record_transaction(self, txn, json_cb): if self.is_online: data = json_cb() @@ -79,7 +92,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): )[0] return {"event_id": event_id, "stream_ordering": stream_ordering} - @override_config({"send_federation": True}) def test_catch_up_destination_rooms_tracking(self): """ Tests that we populate the `destination_rooms` table as needed. @@ -105,7 +117,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.assertEqual(row_2["event_id"], event_id_2) self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1) - @override_config({"send_federation": True}) def test_catch_up_last_successful_stream_ordering_tracking(self): """ Tests that we populate the `destination_rooms` table as needed. @@ -163,7 +174,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): "Send succeeded but not marked as last_successful_stream_ordering", ) - @override_config({"send_federation": True}) # critical to federate def test_catch_up_from_blank_state(self): """ Runs an overall test of federation catch-up from scratch. @@ -260,7 +270,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): return per_dest_queue, results_list - @override_config({"send_federation": True}) def test_catch_up_loop(self): """ Tests the behaviour of _catch_up_transmission_loop. @@ -325,7 +334,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): event_5.internal_metadata.stream_ordering, ) - @override_config({"send_federation": True}) def test_catch_up_on_synapse_startup(self): """ Tests the behaviour of get_catch_up_outstanding_destinations and @@ -424,7 +432,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - all destinations are woken exactly once; they appear once in woken. self.assertCountEqual(woken, server_names[:-1]) - @override_config({"send_federation": True}) def test_not_latest_event(self): """Test that we send the latest event in the room even if its not ours.""" diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index cbc99d30b9..8692d8190f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -25,10 +25,17 @@ from synapse.rest.client import login from synapse.types import JsonDict, ReadReceipt from tests.test_utils import make_awaitable -from tests.unittest import HomeserverTestCase, override_config +from tests.unittest import HomeserverTestCase class FederationSenderReceiptsTestCases(HomeserverTestCase): + """ + Test federation sending to update receipts. + + By default for test cases federation sending is disabled. This Test class has it + re-enabled for the main process. + """ + def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), @@ -44,7 +51,11 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): return hs - @override_config({"send_federation": True}) + def default_config(self) -> JsonDict: + config = super().default_config() + config["federation_sender_instances"] = None + return config + def test_send_receipts(self): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction @@ -87,7 +98,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ], ) - @override_config({"send_federation": True}) def test_send_receipts_thread(self): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction @@ -164,7 +174,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ], ) - @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" @@ -251,6 +260,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): class FederationSenderDevicesTestCases(HomeserverTestCase): + """ + Test federation sending to update devices. + + By default for test cases federation sending is disabled. This Test class has it + re-enabled for the main process. + """ + servlets = [ admin.register_servlets, login.register_servlets, @@ -265,7 +281,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def default_config(self): c = super().default_config() - c["send_federation"] = True + # Enable federation sending on the main process. + c["federation_sender_instances"] = None return c def prepare(self, reactor, clock, hs): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c5981ff965..584e7b8971 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -992,7 +992,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def default_config(self): config = super().default_config() - config["send_federation"] = True + # Enable federation sending on the main process. + config["federation_sender_instances"] = None return config def prepare(self, reactor, clock, hs): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 9c821b3042..efbb5a8dbb 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -200,7 +200,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) - @override_config({"send_federation": True}) + # Enable federation sending on the main process. + @override_config({"federation_sender_instances": None}) def test_started_typing_remote_send(self) -> None: self.room_members = [U_APPLE, U_ONION] @@ -305,7 +306,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0], []) self.assertEqual(events[1], 0) - @override_config({"send_federation": True}) + # Enable federation sending on the main process. + @override_config({"federation_sender_instances": None}) def test_stopped_typing(self) -> None: self.room_members = [U_APPLE, U_BANANA, U_ONION] diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 9e39cd97e5..75fc5a17a4 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -56,7 +56,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["update_user_directory"] = True + # Re-enables updating the user directory, as that function is needed below. + config["update_user_directory_from_worker"] = None self.appservice = ApplicationService( token="i_am_an_app_service", @@ -1045,7 +1046,9 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["update_user_directory"] = True + # Re-enables updating the user directory, as that function is needed below. It + # will be force disabled later + config["update_user_directory_from_worker"] = None hs = self.setup_test_homeserver(config=config) self.config = hs.config diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 058ca57e55..b0f3f4374d 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -336,7 +336,8 @@ class ModuleApiTestCase(HomeserverTestCase): # Test sending local online presence to users from the main process _test_sending_local_online_presence_to_local_user(self, test_with_workers=False) - @override_config({"send_federation": True}) + # Enable federation sending on the main process. + @override_config({"federation_sender_instances": None}) def test_send_local_online_presence_to_federation(self): """Tests that send_local_presence_to_users sends local online presence to remote users.""" # Create a user who will send presence updates diff --git a/tests/push/test_email.py b/tests/push/test_email.py index fd14568f55..57b2f0536e 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -66,7 +66,6 @@ class EmailPusherTests(HomeserverTestCase): "riot_base_url": None, } config["public_baseurl"] = "http://aaa" - config["start_pushers"] = True hs = self.setup_test_homeserver(config=config) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index b383b8401f..afaafe79aa 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -41,11 +41,6 @@ class HTTPPusherTests(HomeserverTestCase): user_id = True hijack_auth = False - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["start_pushers"] = True - return config - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.push_attempts: List[Tuple[Deferred, str, dict]] = [] diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 3029a16dda..6a7174b333 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -307,7 +307,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): stream to the master HS. Args: - worker_app: Type of worker, e.g. `synapse.app.federation_sender`. + worker_app: Type of worker, e.g. `synapse.app.generic_worker`. extra_config: Any extra config to use for this instances. **kwargs: Options that get passed to `self.setup_test_homeserver`, useful to e.g. pass some mocks for things like `federation_http_client` diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py index ffec06a0d6..bcb82c9c80 100644 --- a/tests/replication/tcp/streams/test_federation.py +++ b/tests/replication/tcp/streams/test_federation.py @@ -22,9 +22,8 @@ class FederationStreamTestCase(BaseStreamTestCase): def _get_worker_hs_config(self) -> dict: # enable federation sending on the worker config = super()._get_worker_hs_config() - # TODO: make it so we don't need both of these - config["send_federation"] = False - config["worker_app"] = "synapse.app.federation_sender" + config["worker_name"] = "federation_sender1" + config["federation_sender_instances"] = ["federation_sender1"] return config def test_catchup(self): diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index 43a16bb141..5d7a89e0c7 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -38,7 +38,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): def _get_worker_hs_config(self) -> dict: config = self.default_config() - config["worker_app"] = "synapse.app.client_reader" + config["worker_app"] = "synapse.app.generic_worker" config["worker_replication_host"] = "testserv" config["worker_replication_http_port"] = "8765" @@ -53,7 +53,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): 4. Return the final request. """ - worker_hs = self.make_worker_hs("synapse.app.client_reader") + worker_hs = self.make_worker_hs("synapse.app.generic_worker") site = self._hs_to_site[worker_hs] channel_1 = make_request( diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 995097d72c..eb5b376534 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -22,20 +22,20 @@ logger = logging.getLogger(__name__) class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): - """Test using one or more client readers for registration.""" + """Test using one or more generic workers for registration.""" servlets = [register.register_servlets] def _get_worker_hs_config(self) -> dict: config = self.default_config() - config["worker_app"] = "synapse.app.client_reader" + config["worker_app"] = "synapse.app.generic_worker" config["worker_replication_host"] = "testserv" config["worker_replication_http_port"] = "8765" return config def test_register_single_worker(self): - """Test that registration works when using a single client reader worker.""" - worker_hs = self.make_worker_hs("synapse.app.client_reader") + """Test that registration works when using a single generic worker.""" + worker_hs = self.make_worker_hs("synapse.app.generic_worker") site = self._hs_to_site[worker_hs] channel_1 = make_request( @@ -64,9 +64,9 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(channel_2.json_body["user_id"], "@user:test") def test_register_multi_worker(self): - """Test that registration works when using multiple client reader workers.""" - worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") - worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") + """Test that registration works when using multiple generic workers.""" + worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker") + worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker") site_1 = self._hs_to_site[worker_hs_1] channel_1 = make_request( diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 26b8bd512a..63b1dd40b5 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -25,8 +25,9 @@ from tests.unittest import HomeserverTestCase class FederationAckTestCase(HomeserverTestCase): def default_config(self) -> dict: config = super().default_config() - config["worker_app"] = "synapse.app.federation_sender" - config["send_federation"] = False + config["worker_app"] = "synapse.app.generic_worker" + config["worker_name"] = "federation_sender1" + config["federation_sender_instances"] = ["federation_sender1"] return config def make_homeserver(self, reactor, clock): diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 6104a55aa1..c28073b8f7 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -27,17 +27,19 @@ logger = logging.getLogger(__name__) class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): + """ + Various tests for federation sending on workers. + + Federation sending is disabled by default, it will be enabled in each test by + updating 'federation_sender_instances'. + """ + servlets = [ login.register_servlets, register_servlets_for_client_rest_resource, room.register_servlets, ] - def default_config(self): - conf = super().default_config() - conf["send_federation"] = False - return conf - def test_send_event_single_sender(self): """Test that using a single federation sender worker correctly sends a new event. @@ -46,8 +48,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", - {"send_federation": False}, + "synapse.app.generic_worker", + { + "worker_name": "federation_sender1", + "federation_sender_instances": ["federation_sender1"], + }, federation_http_client=mock_client, ) @@ -73,11 +78,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client1 = Mock(spec=["put_json"]) mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", + "synapse.app.generic_worker", { - "send_federation": True, - "worker_name": "sender1", - "federation_sender_instances": ["sender1", "sender2"], + "worker_name": "federation_sender1", + "federation_sender_instances": [ + "federation_sender1", + "federation_sender2", + ], }, federation_http_client=mock_client1, ) @@ -85,11 +92,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client2 = Mock(spec=["put_json"]) mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", + "synapse.app.generic_worker", { - "send_federation": True, - "worker_name": "sender2", - "federation_sender_instances": ["sender1", "sender2"], + "worker_name": "federation_sender2", + "federation_sender_instances": [ + "federation_sender1", + "federation_sender2", + ], }, federation_http_client=mock_client2, ) @@ -136,11 +145,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client1 = Mock(spec=["put_json"]) mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", + "synapse.app.generic_worker", { - "send_federation": True, - "worker_name": "sender1", - "federation_sender_instances": ["sender1", "sender2"], + "worker_name": "federation_sender1", + "federation_sender_instances": [ + "federation_sender1", + "federation_sender2", + ], }, federation_http_client=mock_client1, ) @@ -148,11 +159,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): mock_client2 = Mock(spec=["put_json"]) mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( - "synapse.app.federation_sender", + "synapse.app.generic_worker", { - "send_federation": True, - "worker_name": "sender2", - "federation_sender_instances": ["sender1", "sender2"], + "worker_name": "federation_sender2", + "federation_sender_instances": [ + "federation_sender1", + "federation_sender2", + ], }, federation_http_client=mock_client2, ) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 59fea93e49..ca18ad6553 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -38,11 +38,6 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.other_user_id = self.register_user("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass") - def default_config(self): - conf = super().default_config() - conf["start_pushers"] = False - return conf - def _create_pusher_and_send_msg(self, localpart): # Create a user that will get push notifications user_id = self.register_user(localpart, "pass") @@ -92,8 +87,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): ) self.make_worker_hs( - "synapse.app.pusher", - {"start_pushers": False}, + "synapse.app.generic_worker", + {"worker_name": "pusher1", "pusher_instances": ["pusher1"]}, proxied_blacklisted_http_client=http_client_mock, ) @@ -122,9 +117,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): ) self.make_worker_hs( - "synapse.app.pusher", + "synapse.app.generic_worker", { - "start_pushers": True, "worker_name": "pusher1", "pusher_instances": ["pusher1", "pusher2"], }, @@ -137,9 +131,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): ) self.make_worker_hs( - "synapse.app.pusher", + "synapse.app.generic_worker", { - "start_pushers": True, "worker_name": "pusher2", "pusher_instances": ["pusher1", "pusher2"], }, diff --git a/tests/utils.py b/tests/utils.py index 045a8b5fa7..d76bf9716a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -125,7 +125,8 @@ def default_config( """ config_dict = { "server_name": name, - "send_federation": False, + # Setting this to an empty list turns off federation sending. + "federation_sender_instances": [], "media_store_path": "media", # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy @@ -183,8 +184,9 @@ def default_config( # rooms will fail. "default_room_version": DEFAULT_ROOM_VERSION, # disable user directory updates, because they get done in the - # background, which upsets the test runner. - "update_user_directory": False, + # background, which upsets the test runner. Setting this to an + # (obviously) fake worker name disables updating the user directory. + "update_user_directory_from_worker": "does_not_exist_worker_name", "caches": {"global_factor": 1, "sync_response_cache_duration": 0}, "listeners": [{"port": 0, "type": "http"}], } -- cgit 1.5.1 From 71f3e53ad010ba8c219f1076d40915b985760ed9 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Thu, 1 Dec 2022 13:46:24 +0000 Subject: Add `push.enabled` option to disable push notification calculation (#14551) * Add initial option * changelog * Some more linting --- changelog.d/14551.feature | 1 + docs/usage/configuration/config_documentation.md | 5 +++ synapse/config/push.py | 1 + synapse/push/bulk_push_rule_evaluator.py | 3 ++ tests/push/test_bulk_push_rule_evaluator.py | 45 ++++++++++++++++++++++-- 5 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14551.feature (limited to 'tests') diff --git a/changelog.d/14551.feature b/changelog.d/14551.feature new file mode 100644 index 0000000000..43b91d2e57 --- /dev/null +++ b/changelog.d/14551.feature @@ -0,0 +1 @@ +Add new `push.enabled` config option to allow opting out of push notification calculation. \ No newline at end of file diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 749af12aac..b9bde8f47e 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3355,6 +3355,10 @@ Configuration settings related to push notifications This setting defines options for push notifications. This option has a number of sub-options. They are as follows: +* `enable_push`: Enables or disables push notification calculation. Note, disabling this will also + stop unread counts being calculated for rooms. This mode of operation is intended + for homeservers which may only have bots or appservice users connected, or are otherwise + not interested in push/unread counters. This is enabled by default. * `include_content`: Clients requesting push notifications can either have the body of the message sent in the notification poke along with other details like the sender, or just the event ID and room ID (`event_id_only`). @@ -3375,6 +3379,7 @@ This option has a number of sub-options. They are as follows: Example configuration: ```yaml push: + enable_push: true include_content: false group_unread_count_by_room: false ``` diff --git a/synapse/config/push.py b/synapse/config/push.py index 979b128eae..3b5378e6ea 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -26,6 +26,7 @@ class PushConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: push_config = config.get("push") or {} self.push_include_content = push_config.get("include_content", True) + self.enable_push = push_config.get("enabled", True) self.push_group_unread_count_by_room = push_config.get( "group_unread_count_by_room", True ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d6b377860f..9ed35d8461 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -106,6 +106,7 @@ class BulkPushRuleEvaluator: self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() + self.should_calculate_push_rules = self.hs.config.push.enable_push self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled @@ -269,6 +270,8 @@ class BulkPushRuleEvaluator: for each event, check if the message should increment the unread count, and insert the results into the event_push_actions_staging table. """ + if not self.should_calculate_push_rules: + return # For batched events the power level events may not have been persisted yet, # so we pass in the batched events. Thus if the event cannot be found in the # database we can check in the batch. diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 594e7937a8..1cd453248e 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -6,10 +6,11 @@ from synapse.rest import admin from synapse.rest.client import login, register, room from synapse.types import create_requester -from tests import unittest +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase, override_config -class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): +class TestBulkPushRuleEvaluator(HomeserverTestCase): servlets = [ admin.register_servlets_for_client_rest_resource, @@ -72,3 +73,43 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) + + @override_config({"push": {"enabled": False}}) + def test_action_for_event_by_user_disabled_by_config(self) -> None: + """Ensure that push rules are not calculated when disabled in the config""" + # Create a new user and room. + alice = self.register_user("alice", "pass") + token = self.login(alice, "pass") + + room_id = self.helper.create_room_as( + alice, room_version=RoomVersions.V9.identifier, tok=token + ) + + # Alter the power levels in that room to include stringy and floaty levels. + # We need to suppress the validation logic or else it will reject these dodgy + # values. (Presumably this validation was not always present.) + event_creation_handler = self.hs.get_event_creation_handler() + requester = create_requester(alice) + + # Create a new message event, and try to evaluate it under the dodgy + # power level event. + event, context = self.get_success( + event_creation_handler.create_event( + requester, + { + "type": "m.room.message", + "room_id": room_id, + "content": { + "msgtype": "m.text", + "body": "helo", + }, + "sender": alice, + }, + ) + ) + + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] + # should not raise + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) + bulk_evaluator._action_for_event_by_user.assert_not_called() -- cgit 1.5.1 From acea4d7a2ff61b5beda420b54a8451088060a8cd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 2 Dec 2022 12:58:56 -0500 Subject: Add missing types to tests.util. (#14597) Removes files under tests.util from the ignored by list, then fully types all tests/util/*.py files. --- changelog.d/14597.misc | 1 + mypy.ini | 13 +--- tests/util/test_async_helpers.py | 118 ++++++++++++++++++--------------- tests/util/test_batching_queue.py | 30 +++++---- tests/util/test_check_dependencies.py | 29 ++++++-- tests/util/test_dict_cache.py | 20 +++--- tests/util/test_expiring_cache.py | 26 +++++--- tests/util/test_file_consumer.py | 103 ++++++++++++++++------------ tests/util/test_itertools.py | 24 +++---- tests/util/test_logcontext.py | 86 +++++++++++++++--------- tests/util/test_logformatter.py | 2 +- tests/util/test_lrucache.py | 80 +++++++++++----------- tests/util/test_macaroons.py | 8 +-- tests/util/test_ratelimitutils.py | 15 +++-- tests/util/test_retryutils.py | 4 +- tests/util/test_rwlock.py | 14 ++-- tests/util/test_stream_change_cache.py | 14 ++-- tests/util/test_stringutils.py | 4 +- tests/util/test_threepids.py | 16 ++--- tests/util/test_treecache.py | 14 ++-- tests/util/test_wheel_timer.py | 16 ++--- 21 files changed, 361 insertions(+), 276 deletions(-) create mode 100644 changelog.d/14597.misc (limited to 'tests') diff --git a/changelog.d/14597.misc b/changelog.d/14597.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14597.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 0b6e7df267..c3fbd1a955 100644 --- a/mypy.ini +++ b/mypy.ini @@ -59,16 +59,6 @@ exclude = (?x) |tests/server_notices/test_resource_limits_server_notices.py |tests/test_state.py |tests/test_terms_auth.py - |tests/util/test_async_helpers.py - |tests/util/test_batching_queue.py - |tests/util/test_dict_cache.py - |tests/util/test_expiring_cache.py - |tests/util/test_file_consumer.py - |tests/util/test_linearizer.py - |tests/util/test_logcontext.py - |tests/util/test_lrucache.py - |tests/util/test_rwlock.py - |tests/util/test_wheel_timer.py )$ [mypy-synapse.federation.transport.client] @@ -137,6 +127,9 @@ disallow_untyped_defs = True [mypy-tests.util.caches.test_descriptors] disallow_untyped_defs = False +[mypy-tests.util.*] +disallow_untyped_defs = True + [mypy-tests.utils] disallow_untyped_defs = True diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 9d5010bf92..91cac9822a 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import traceback +from typing import Generator, List, NoReturn, Optional from parameterized import parameterized_class @@ -41,8 +42,8 @@ from tests.unittest import TestCase class ObservableDeferredTest(TestCase): - def test_succeed(self): - origin_d = Deferred() + def test_succeed(self) -> None: + origin_d: "Deferred[int]" = Deferred() observable = ObservableDeferred(origin_d) observer1 = observable.observe() @@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase): self.assertFalse(observer2.called) # check the first observer is called first - def check_called_first(res): + def check_called_first(res: int) -> int: self.assertFalse(observer2.called) return res observer1.addBoth(check_called_first) # store the results - results = [None, None] + results: List[Optional[ObservableDeferred[int]]] = [None, None] - def check_val(res, idx): + def check_val( + res: ObservableDeferred[int], idx: int + ) -> ObservableDeferred[int]: results[idx] = res return res @@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase): self.assertEqual(results[0], 123, "observer 1 callback result") self.assertEqual(results[1], 123, "observer 2 callback result") - def test_failure(self): - origin_d = Deferred() + def test_failure(self) -> None: + origin_d: Deferred = Deferred() observable = ObservableDeferred(origin_d, consumeErrors=True) observer1 = observable.observe() @@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase): self.assertFalse(observer2.called) # check the first observer is called first - def check_called_first(res): + def check_called_first(res: int) -> int: self.assertFalse(observer2.called) return res observer1.addBoth(check_called_first) # store the results - results = [None, None] + results: List[Optional[ObservableDeferred[str]]] = [None, None] - def check_val(res, idx): + def check_val(res: ObservableDeferred[str], idx: int) -> None: results[idx] = res return None @@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase): raise Exception("gah!") except Exception as e: origin_d.errback(e) + assert results[0] is not None self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") + assert results[1] is not None self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") - def test_cancellation(self): + def test_cancellation(self) -> None: """Test that cancelling an observer does not affect other observers.""" origin_d: "Deferred[int]" = Deferred() observable = ObservableDeferred(origin_d, consumeErrors=True) @@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase): class TimeoutDeferredTest(TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = Clock() - def test_times_out(self): + def test_times_out(self) -> None: """Basic test case that checks that the original deferred is cancelled and that the timing-out deferred is errbacked """ - cancelled = [False] + cancelled = False - def canceller(_d): - cancelled[0] = True + def canceller(_d: Deferred) -> None: + nonlocal cancelled + cancelled = True - non_completing_d = Deferred(canceller) + non_completing_d: Deferred = Deferred(canceller) timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) self.assertNoResult(timing_out_d) - self.assertFalse(cancelled[0], "deferred was cancelled prematurely") + self.assertFalse(cancelled, "deferred was cancelled prematurely") self.clock.pump((1.0,)) - self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") + self.assertTrue(cancelled, "deferred was not cancelled by timeout") self.failureResultOf(timing_out_d, defer.TimeoutError) - def test_times_out_when_canceller_throws(self): + def test_times_out_when_canceller_throws(self) -> None: """Test that we have successfully worked around https://twistedmatrix.com/trac/ticket/9534""" - def canceller(_d): + def canceller(_d: Deferred) -> None: raise Exception("can't cancel this deferred") - non_completing_d = Deferred(canceller) + non_completing_d: Deferred = Deferred(canceller) timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) self.assertNoResult(timing_out_d) @@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase): self.failureResultOf(timing_out_d, defer.TimeoutError) - def test_logcontext_is_preserved_on_cancellation(self): - blocking_was_cancelled = [False] + def test_logcontext_is_preserved_on_cancellation(self) -> None: + blocking_was_cancelled = False @defer.inlineCallbacks - def blocking(): - non_completing_d = Deferred() + def blocking() -> Generator["Deferred[object]", object, None]: + nonlocal blocking_was_cancelled + + non_completing_d: Deferred = Deferred() with PreserveLoggingContext(): try: yield non_completing_d except CancelledError: - blocking_was_cancelled[0] = True + blocking_was_cancelled = True raise with LoggingContext("one") as context_one: # the errbacks should be run in the test logcontext - def errback(res, deferred_name): + def errback(res: Failure, deferred_name: str) -> Failure: self.assertIs( current_context(), context_one, @@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase): self.clock.pump((1.0,)) self.assertTrue( - blocking_was_cancelled[0], "non-completing deferred was not cancelled" + blocking_was_cancelled, "non-completing deferred was not cancelled" ) self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(current_context(), context_one) @@ -220,13 +228,13 @@ class _TestException(Exception): class ConcurrentlyExecuteTest(TestCase): - def test_limits_runners(self): + def test_limits_runners(self) -> None: """If we have more tasks than runners, we should get the limit of runners""" started = 0 waiters = [] processed = [] - async def callback(v): + async def callback(v: int) -> None: # when we first enter, bump the start count nonlocal started started += 1 @@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase): processed.append(v) # wait for the goahead before returning - d2 = Deferred() + d2: "Deferred[int]" = Deferred() waiters.append(d2) await d2 @@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase): self.assertCountEqual(processed, [1, 2, 3, 4, 5]) self.successResultOf(d2) - def test_preserves_stacktraces(self): + def test_preserves_stacktraces(self) -> None: """Test that the stacktrace from an exception thrown in the callback is preserved""" - d1 = Deferred() + d1: "Deferred[int]" = Deferred() - async def callback(v): + async def callback(v: int) -> None: # alas, this doesn't work at all without an await here await d1 raise _TestException("bah") - async def caller(): + async def caller() -> None: try: await concurrently_execute(callback, [1], 2) except _TestException as e: @@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase): d1.callback(0) self.successResultOf(d2) - def test_preserves_stacktraces_on_preformed_failure(self): + def test_preserves_stacktraces_on_preformed_failure(self) -> None: """Test that the stacktrace on a Failure returned by the callback is preserved""" - d1 = Deferred() + d1: "Deferred[int]" = Deferred() f = Failure(_TestException("bah")) - async def callback(v): + async def callback(v: int) -> None: # alas, this doesn't work at all without an await here await d1 await defer.fail(f) - async def caller(): + async def caller() -> None: try: await concurrently_execute(callback, [1], 2) except _TestException as e: @@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase): else: raise ValueError(f"Unsupported wrapper type: {self.wrapper}") - def test_succeed(self): + def test_succeed(self) -> None: """Test that the new `Deferred` receives the result.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = self.wrap_deferred(deferred) @@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase): self.assertTrue(wrapper_deferred.called) self.assertEqual("success", self.successResultOf(wrapper_deferred)) - def test_failure(self): + def test_failure(self) -> None: """Test that the new `Deferred` receives the `Failure`.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = self.wrap_deferred(deferred) @@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase): class StopCancellationTests(TestCase): """Tests for the `stop_cancellation` function.""" - def test_cancellation(self): + def test_cancellation(self) -> None: """Test that cancellation of the new `Deferred` leaves the original running.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = stop_cancellation(deferred) @@ -384,7 +392,7 @@ class StopCancellationTests(TestCase): class DelayCancellationTests(TestCase): """Tests for the `delay_cancellation` function.""" - def test_deferred_cancellation(self): + def test_deferred_cancellation(self) -> None: """Test that cancellation of the new `Deferred` waits for the original.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = delay_cancellation(deferred) @@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase): # Now that the original `Deferred` has failed, we should get a `CancelledError`. self.failureResultOf(wrapper_deferred, CancelledError) - def test_coroutine_cancellation(self): + def test_coroutine_cancellation(self) -> None: """Test that cancellation of the new `Deferred` waits for the original.""" blocking_deferred: "Deferred[None]" = Deferred() completion_deferred: "Deferred[None]" = Deferred() - async def task(): + async def task() -> NoReturn: await blocking_deferred completion_deferred.callback(None) # Raise an exception. Twisted should consume it, otherwise unwanted @@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase): # Now that the original coroutine has failed, we should get a `CancelledError`. self.failureResultOf(wrapper_deferred, CancelledError) - def test_suppresses_second_cancellation(self): + def test_suppresses_second_cancellation(self) -> None: """Test that a second cancellation is suppressed. Identical to `test_cancellation` except the new `Deferred` is cancelled twice. @@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase): # Now that the original `Deferred` has failed, we should get a `CancelledError`. self.failureResultOf(wrapper_deferred, CancelledError) - def test_propagates_cancelled_error(self): + def test_propagates_cancelled_error(self) -> None: """Test that a `CancelledError` from the original `Deferred` gets propagated.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = delay_cancellation(deferred) @@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase): self.assertTrue(wrapper_deferred.called) self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) - def test_preserves_logcontext(self): + def test_preserves_logcontext(self) -> None: """Test that logging contexts are preserved.""" blocking_d: "Deferred[None]" = Deferred() - async def inner(): + async def inner() -> None: await make_deferred_yieldable(blocking_d) - async def outer(): + async def outer() -> None: with LoggingContext("c") as c: try: await delay_cancellation(inner()) @@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase): class AwakenableSleeperTests(TestCase): "Tests AwakenableSleeper" - def test_sleep(self): + def test_sleep(self) -> None: reactor, _ = get_clock() sleeper = AwakenableSleeper(reactor) @@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase): reactor.advance(0.6) self.assertTrue(d.called) - def test_explicit_wake(self): + def test_explicit_wake(self) -> None: reactor, _ = get_clock() sleeper = AwakenableSleeper(reactor) @@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase): reactor.advance(0.6) - def test_multiple_sleepers_timeout(self): + def test_multiple_sleepers_timeout(self) -> None: reactor, _ = get_clock() sleeper = AwakenableSleeper(reactor) @@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase): reactor.advance(0.6) self.assertTrue(d2.called) - def test_multiple_sleepers_wake(self): + def test_multiple_sleepers_wake(self) -> None: reactor, _ = get_clock() sleeper = AwakenableSleeper(reactor) diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py index 07be57d72c..94ef91f645 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple + +from prometheus_client import Gauge + from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable @@ -26,7 +30,7 @@ from tests.unittest import TestCase class BatchingQueueTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.clock, hs_clock = get_clock() # We ensure that we remove any existing metrics for "test_queue". @@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase): except KeyError: pass - self._pending_calls = [] - self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue) + self._pending_calls: List[Tuple[List[str], defer.Deferred]] = [] + self.queue: BatchingQueue[str, str] = BatchingQueue( + "test_queue", hs_clock, self._process_queue + ) - async def _process_queue(self, values): - d = defer.Deferred() + async def _process_queue(self, values: List[str]) -> str: + d: "defer.Deferred[str]" = defer.Deferred() self._pending_calls.append((values, d)) return await make_deferred_yieldable(d) - def _get_sample_with_name(self, metric, name) -> int: + def _get_sample_with_name(self, metric: Gauge, name: str) -> float: """For a prometheus metric get the value of the sample that has a matching "name" label. """ - for sample in metric.collect()[0].samples: + for sample in next(iter(metric.collect())).samples: if sample.labels.get("name") == name: return sample.value self.fail("Found no matching sample") - def _assert_metrics(self, queued, keys, in_flight): + def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None: """Assert that the metrics are correct""" sample = self._get_sample_with_name(number_queued, self.queue._name) @@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase): "number_in_flight", ) - def test_simple(self): + def test_simple(self) -> None: """Tests the basic case of calling `add_to_queue` once and having `_process_queue` return. """ @@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase): self._assert_metrics(queued=0, keys=0, in_flight=0) - def test_batching(self): + def test_batching(self) -> None: """Test that multiple calls at the same time get batched up into one call to `_process_queue`. """ @@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self.successResultOf(queue_d2), "bar") self._assert_metrics(queued=0, keys=0, in_flight=0) - def test_queuing(self): + def test_queuing(self) -> None: """Test that we queue up requests while a `_process_queue` is being called. """ @@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self.successResultOf(queue_d3), "bar2") self._assert_metrics(queued=0, keys=0, in_flight=0) - def test_different_keys(self): + def test_different_keys(self) -> None: """Test that calls to different keys get processed in parallel.""" self.assertFalse(self._pending_calls) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index 6913de24b9..aa20fe6780 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -1,5 +1,20 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from contextlib import contextmanager -from typing import Generator, Optional +from os import PathLike +from typing import Generator, Optional, Union from unittest.mock import patch from synapse.util.check_dependencies import ( @@ -12,17 +27,17 @@ from tests.unittest import TestCase class DummyDistribution(metadata.Distribution): - def __init__(self, version: object): + def __init__(self, version: str): self._version = version @property - def version(self): + def version(self) -> str: return self._version - def locate_file(self, path): + def locate_file(self, path: Union[str, PathLike]) -> PathLike: raise NotImplementedError() - def read_text(self, filename): + def read_text(self, filename: str) -> None: raise NotImplementedError() @@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2") old_release_candidate = DummyDistribution("0.1.2rc3") new = DummyDistribution("1.2.3") new_release_candidate = DummyDistribution("1.2.3rc4") -distribution_with_no_version = DummyDistribution(None) +distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type] # could probably use stdlib TestCase --- no need for twisted here @@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase): If `distribution = None`, we pretend that the package is not installed. """ - def mock_distribution(name: str): + def mock_distribution(name: str) -> DummyDistribution: if distribution is None: raise metadata.PackageNotFoundError else: diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index e8b6246ab5..acb251bfea 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -19,10 +19,12 @@ from tests import unittest class DictCacheTestCase(unittest.TestCase): - def setUp(self): - self.cache = DictionaryCache("foobar", max_entries=10) + def setUp(self) -> None: + self.cache: DictionaryCache[str, str, str] = DictionaryCache( + "foobar", max_entries=10 + ) - def test_simple_cache_hit_full(self): + def test_simple_cache_hit_full(self) -> None: key = "test_simple_cache_hit_full" v = self.cache.get(key) @@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase): c = self.cache.get(key) self.assertEqual(test_value, c.value) - def test_simple_cache_hit_partial(self): + def test_simple_cache_hit_partial(self) -> None: key = "test_simple_cache_hit_partial" seq = self.cache.sequence @@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase): c = self.cache.get(key, ["test"]) self.assertEqual(test_value, c.value) - def test_simple_cache_miss_partial(self): + def test_simple_cache_miss_partial(self) -> None: key = "test_simple_cache_miss_partial" seq = self.cache.sequence @@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase): c = self.cache.get(key, ["test2"]) self.assertEqual({}, c.value) - def test_simple_cache_hit_miss_partial(self): + def test_simple_cache_hit_miss_partial(self) -> None: key = "test_simple_cache_hit_miss_partial" seq = self.cache.sequence @@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase): c = self.cache.get(key, ["test2"]) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) - def test_multi_insert(self): + def test_multi_insert(self) -> None: key = "test_simple_cache_hit_miss_partial" seq = self.cache.sequence @@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase): ) self.assertEqual(c.full, False) - def test_invalidation(self): + def test_invalidation(self) -> None: """Test that the partial dict and full dicts get invalidated separately. """ @@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase): # entry for "a" warm. for i in range(20): self.cache.get(key, ["a"]) - self.cache.update(seq, f"key{i}", {1: 2}) + self.cache.update(seq, f"key{i}", {"1": "2"}) # We should have evicted the full dict... r = self.cache.get(key) diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 7f60aae5ba..9cf920daf8 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, cast +from synapse.util import Clock from synapse.util.caches.expiringcache import ExpiringCache from tests.utils import MockClock @@ -21,17 +23,21 @@ from .. import unittest class ExpiringCacheTestCase(unittest.HomeserverTestCase): - def test_get_set(self): + def test_get_set(self) -> None: clock = MockClock() - cache = ExpiringCache("test", clock, max_len=1) + cache: ExpiringCache[str, str] = ExpiringCache( + "test", cast(Clock, clock), max_len=1 + ) cache["key"] = "value" self.assertEqual(cache.get("key"), "value") self.assertEqual(cache["key"], "value") - def test_eviction(self): + def test_eviction(self) -> None: clock = MockClock() - cache = ExpiringCache("test", clock, max_len=2) + cache: ExpiringCache[str, str] = ExpiringCache( + "test", cast(Clock, clock), max_len=2 + ) cache["key"] = "value" cache["key2"] = "value2" @@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key2"), "value2") self.assertEqual(cache.get("key3"), "value3") - def test_iterable_eviction(self): + def test_iterable_eviction(self) -> None: clock = MockClock() - cache = ExpiringCache("test", clock, max_len=5, iterable=True) + cache: ExpiringCache[str, List[int]] = ExpiringCache( + "test", cast(Clock, clock), max_len=5, iterable=True + ) cache["key"] = [1] cache["key2"] = [2, 3] @@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key3"), [4, 5]) self.assertEqual(cache.get("key4"), [6, 7]) - def test_time_eviction(self): + def test_time_eviction(self) -> None: clock = MockClock() - cache = ExpiringCache("test", clock, expiry_ms=1000) + cache: ExpiringCache[str, int] = ExpiringCache( + "test", cast(Clock, clock), expiry_ms=1000 + ) cache["key"] = 1 clock.advance_time(0.5) diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index 3bb4695405..4f3c983c15 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -12,22 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. - import threading -from io import StringIO +from io import BytesIO +from typing import BinaryIO, Generator, Optional, cast from unittest.mock import NonCallableMock -from twisted.internet import defer, reactor +from zope.interface import implementer + +from twisted.internet import defer, reactor as _reactor +from twisted.internet.interfaces import IPullProducer +from synapse.types import ISynapseReactor from synapse.util.file_consumer import BackgroundFileConsumer from tests import unittest +reactor = cast(ISynapseReactor, _reactor) + class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks - def test_pull_consumer(self): - string_file = StringIO() + def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]: + string_file = BytesIO() consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: @@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase): yield producer.register_with_consumer(consumer) - yield producer.write_and_wait("Foo") + yield producer.write_and_wait(b"Foo") - self.assertEqual(string_file.getvalue(), "Foo") + self.assertEqual(string_file.getvalue(), b"Foo") - yield producer.write_and_wait("Bar") + yield producer.write_and_wait(b"Bar") - self.assertEqual(string_file.getvalue(), "FooBar") + self.assertEqual(string_file.getvalue(), b"FooBar") finally: consumer.unregisterProducer() - yield consumer.wait() + yield consumer.wait() # type: ignore[misc] self.assertTrue(string_file.closed) @defer.inlineCallbacks - def test_push_consumer(self): - string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file, reactor=reactor) + def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]: + string_file = BlockingBytesWrite() + consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor) try: producer = NonCallableMock(spec_set=[]) consumer.registerProducer(producer, True) - consumer.write("Foo") - yield string_file.wait_for_n_writes(1) + consumer.write(b"Foo") + yield string_file.wait_for_n_writes(1) # type: ignore[misc] - self.assertEqual(string_file.buffer, "Foo") + self.assertEqual(string_file.buffer, b"Foo") - consumer.write("Bar") - yield string_file.wait_for_n_writes(2) + consumer.write(b"Bar") + yield string_file.wait_for_n_writes(2) # type: ignore[misc] - self.assertEqual(string_file.buffer, "FooBar") + self.assertEqual(string_file.buffer, b"FooBar") finally: consumer.unregisterProducer() - yield consumer.wait() + yield consumer.wait() # type: ignore[misc] self.assertTrue(string_file.closed) @defer.inlineCallbacks - def test_push_producer_feedback(self): - string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file, reactor=reactor) + def test_push_producer_feedback( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + string_file = BlockingBytesWrite() + consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor) try: producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) - resume_deferred = defer.Deferred() + resume_deferred: defer.Deferred = defer.Deferred() producer.resumeProducing.side_effect = lambda: resume_deferred.callback( None ) @@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase): number_writes = 0 with string_file.write_lock: for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): - consumer.write("Foo") + consumer.write(b"Foo") number_writes += 1 producer.pauseProducing.assert_called_once() - yield string_file.wait_for_n_writes(number_writes) + yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc] yield resume_deferred producer.resumeProducing.assert_called_once() finally: consumer.unregisterProducer() - yield consumer.wait() + yield consumer.wait() # type: ignore[misc] self.assertTrue(string_file.closed) +@implementer(IPullProducer) class DummyPullProducer: - def __init__(self): - self.consumer = None - self.deferred = defer.Deferred() + def __init__(self) -> None: + self.consumer: Optional[BackgroundFileConsumer] = None + self.deferred: "defer.Deferred[object]" = defer.Deferred() - def resumeProducing(self): + def resumeProducing(self) -> None: d = self.deferred self.deferred = defer.Deferred() d.callback(None) - def write_and_wait(self, bytes): + def stopProducing(self) -> None: + raise RuntimeError("Unexpected call") + + def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]": + assert self.consumer is not None d = self.deferred - self.consumer.write(bytes) + self.consumer.write(write_bytes) return d - def register_with_consumer(self, consumer): + def register_with_consumer( + self, consumer: BackgroundFileConsumer + ) -> "defer.Deferred[object]": d = self.deferred self.consumer = consumer self.consumer.registerProducer(self, False) return d -class BlockingStringWrite: - def __init__(self): - self.buffer = "" +class BlockingBytesWrite: + def __init__(self) -> None: + self.buffer = b"" self.closed = False self.write_lock = threading.Lock() - self._notify_write_deferred = None + self._notify_write_deferred: Optional[defer.Deferred] = None self._number_of_writes = 0 - def write(self, bytes): + def write(self, write_bytes: bytes) -> None: with self.write_lock: - self.buffer += bytes + self.buffer += write_bytes self._number_of_writes += 1 reactor.callFromThread(self._notify_write) - def close(self): + def close(self) -> None: self.closed = True - def _notify_write(self): + def _notify_write(self) -> None: "Called by write to indicate a write happened" with self.write_lock: if not self._notify_write_deferred: @@ -161,7 +176,9 @@ class BlockingStringWrite: d.callback(None) @defer.inlineCallbacks - def wait_for_n_writes(self, n): + def wait_for_n_writes( + self, n: int + ) -> Generator["defer.Deferred[object]", object, None]: "Wait for n writes to have happened" while True: with self.write_lock: diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 3c0ddd4f18..406c16cdcf 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -19,7 +19,7 @@ from tests.unittest import TestCase class ChunkSeqTests(TestCase): - def test_short_seq(self): + def test_short_seq(self) -> None: parts = chunk_seq("123", 8) self.assertEqual( @@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase): ["123"], ) - def test_long_seq(self): + def test_long_seq(self) -> None: parts = chunk_seq("abcdefghijklmnop", 8) self.assertEqual( @@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase): ["abcdefgh", "ijklmnop"], ) - def test_uneven_parts(self): + def test_uneven_parts(self) -> None: parts = chunk_seq("abcdefghijklmnop", 5) self.assertEqual( @@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase): ["abcde", "fghij", "klmno", "p"], ) - def test_empty_input(self): + def test_empty_input(self) -> None: parts: Iterable[Sequence] = chunk_seq([], 5) self.assertEqual( @@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase): class SortTopologically(TestCase): - def test_empty(self): + def test_empty(self) -> None: "Test that an empty graph works correctly" graph: Dict[int, List[int]] = {} self.assertEqual(list(sorted_topologically([], graph)), []) - def test_handle_empty_graph(self): + def test_handle_empty_graph(self) -> None: "Test that a graph where a node doesn't have an entry is treated as empty" graph: Dict[int, List[int]] = {} @@ -67,7 +67,7 @@ class SortTopologically(TestCase): # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) - def test_disconnected(self): + def test_disconnected(self) -> None: "Test that a graph with no edges work" graph: Dict[int, List[int]] = {1: [], 2: []} @@ -75,20 +75,20 @@ class SortTopologically(TestCase): # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) - def test_linear(self): + def test_linear(self) -> None: "Test that a simple `4 -> 3 -> 2 -> 1` graph works" graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) - def test_subset(self): + def test_subset(self) -> None: "Test that only sorting a subset of the graph works" graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) - def test_fork(self): + def test_fork(self) -> None: "Test that a forked graph works" graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]} @@ -96,13 +96,13 @@ class SortTopologically(TestCase): # always get the same one. self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) - def test_duplicates(self): + def test_duplicates(self) -> None: "Test that a graph with duplicate edges work" graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) - def test_multiple_paths(self): + def test_multiple_paths(self) -> None: "Test that a graph with multiple paths between two nodes work" graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 2ad321e184..d64c162e1d 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -1,5 +1,21 @@ +# Copyright 2014-2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Generator, cast + import twisted.python.failure -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as _reactor from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -10,25 +26,30 @@ from synapse.logging.context import ( nested_logging_context, run_in_background, ) +from synapse.types import ISynapseReactor from synapse.util import Clock from .. import unittest +reactor = cast(ISynapseReactor, _reactor) + class LoggingContextTestCase(unittest.TestCase): - def _check_test_key(self, value): - self.assertEqual(current_context().name, value) + def _check_test_key(self, value: str) -> None: + context = current_context() + assert isinstance(context, LoggingContext) + self.assertEqual(context.name, value) - def test_with_context(self): + def test_with_context(self) -> None: with LoggingContext("test"): self._check_test_key("test") @defer.inlineCallbacks - def test_sleep(self): + def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]: clock = Clock(reactor) @defer.inlineCallbacks - def competing_callback(): + def competing_callback() -> Generator["defer.Deferred[object]", object, None]: with LoggingContext("competing"): yield clock.sleep(0) self._check_test_key("competing") @@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase): yield clock.sleep(0) self._check_test_key("one") - def _test_run_in_background(self, function): + def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred: sentinel_context = current_context() - callback_completed = [False] + callback_completed = False with LoggingContext("one"): # fire off function, but don't wait on it. d2 = run_in_background(function) - def cb(res): - callback_completed[0] = True + def cb(res: object) -> object: + nonlocal callback_completed + callback_completed = True return res d2.addCallback(cb) @@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase): # the logcontext is left in a sane state. d2 = defer.Deferred() - def check_logcontext(): - if not callback_completed[0]: + def check_logcontext() -> None: + if not callback_completed: reactor.callLater(0.01, check_logcontext) return @@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase): # test is done once d2 finishes return d2 - def test_run_in_background_with_blocking_fn(self): + def test_run_in_background_with_blocking_fn(self) -> defer.Deferred: @defer.inlineCallbacks - def blocking_function(): + def blocking_function() -> Generator["defer.Deferred[object]", object, None]: yield Clock(reactor).sleep(0) return self._test_run_in_background(blocking_function) - def test_run_in_background_with_non_blocking_fn(self): + def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred: @defer.inlineCallbacks - def nonblocking_function(): + def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]: with PreserveLoggingContext(): yield defer.succeed(None) return self._test_run_in_background(nonblocking_function) - def test_run_in_background_with_chained_deferred(self): + def test_run_in_background_with_chained_deferred(self) -> defer.Deferred: # a function which returns a deferred which looks like it has been # called, but is actually paused - def testfunc(): + def testfunc() -> defer.Deferred: return make_deferred_yieldable(_chained_deferred_function()) return self._test_run_in_background(testfunc) - def test_run_in_background_with_coroutine(self): - async def testfunc(): + def test_run_in_background_with_coroutine(self) -> defer.Deferred: + async def testfunc() -> None: self._check_test_key("one") d = Clock(reactor).sleep(0) self.assertIs(current_context(), SENTINEL_CONTEXT) @@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase): return self._test_run_in_background(testfunc) - def test_run_in_background_with_nonblocking_coroutine(self): - async def testfunc(): + def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred: + async def testfunc() -> None: self._check_test_key("one") return self._test_run_in_background(testfunc) @defer.inlineCallbacks - def test_make_deferred_yieldable(self): + def test_make_deferred_yieldable( + self, + ) -> Generator["defer.Deferred[object]", object, None]: # a function which returns an incomplete deferred, but doesn't follow # the synapse rules. - def blocking_function(): - d = defer.Deferred() + def blocking_function() -> defer.Deferred: + d: defer.Deferred = defer.Deferred() reactor.callLater(0, d.callback, None) return d @@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") @defer.inlineCallbacks - def test_make_deferred_yieldable_with_chained_deferreds(self): + def test_make_deferred_yieldable_with_chained_deferreds( + self, + ) -> Generator["defer.Deferred[object]", object, None]: sentinel_context = current_context() with LoggingContext("one"): @@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase): # now it should be restored self._check_test_key("one") - def test_nested_logging_context(self): + def test_nested_logging_context(self) -> None: with LoggingContext("foo"): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.name, "foo-bar") @@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase): # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on # its callback list, so won't yet call any other new callbacks. -def _chained_deferred_function(): +def _chained_deferred_function() -> defer.Deferred: d = defer.succeed(None) - def cb(res): - d2 = defer.Deferred() + def cb(res: object) -> defer.Deferred: + d2: defer.Deferred = defer.Deferred() reactor.callLater(0, d2.callback, res) return d2 diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py index a2e08281e6..0dee69a6fe 100644 --- a/tests/util/test_logformatter.py +++ b/tests/util/test_logformatter.py @@ -23,7 +23,7 @@ class TestException(Exception): class LogFormatterTestCase(unittest.TestCase): - def test_formatter(self): + def test_formatter(self) -> None: formatter = LogFormatter() try: diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 67173a4f5b..1fc5a473f0 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -13,10 +13,11 @@ # limitations under the License. -from typing import List +from typing import List, Tuple from unittest.mock import Mock, patch from synapse.metrics.jemalloc import JemallocStats +from synapse.types import JsonDict from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries from synapse.util.caches.treecache import TreeCache @@ -25,14 +26,14 @@ from tests.unittest import override_config class LruCacheTestCase(unittest.HomeserverTestCase): - def test_get_set(self): - cache = LruCache(1) + def test_get_set(self) -> None: + cache: LruCache[str, str] = LruCache(1) cache["key"] = "value" self.assertEqual(cache.get("key"), "value") self.assertEqual(cache["key"], "value") - def test_eviction(self): - cache = LruCache(2) + def test_eviction(self) -> None: + cache: LruCache[int, int] = LruCache(2) cache[1] = 1 cache[2] = 2 @@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get(2), 2) self.assertEqual(cache.get(3), 3) - def test_setdefault(self): - cache = LruCache(1) + def test_setdefault(self) -> None: + cache: LruCache[str, int] = LruCache(1) self.assertEqual(cache.setdefault("key", 1), 1) self.assertEqual(cache.get("key"), 1) self.assertEqual(cache.setdefault("key", 2), 1) @@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase): cache["key"] = 2 # Make sure overriding works. self.assertEqual(cache.get("key"), 2) - def test_pop(self): - cache = LruCache(1) + def test_pop(self) -> None: + cache: LruCache[str, int] = LruCache(1) cache["key"] = 1 self.assertEqual(cache.pop("key"), 1) self.assertEqual(cache.pop("key"), None) - def test_del_multi(self): - cache = LruCache(4, cache_type=TreeCache) + def test_del_multi(self) -> None: + # The type here isn't quite correct as they don't handle TreeCache well. + cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" cache[("vehicles", "car")] = "vroom" @@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get(("animal", "cat")), "mew") self.assertEqual(cache.get(("vehicles", "car")), "vroom") - cache.del_multi(("animal",)) + cache.del_multi(("animal",)) # type: ignore[arg-type] self.assertEqual(len(cache), 2) self.assertEqual(cache.get(("animal", "cat")), None) self.assertEqual(cache.get(("animal", "dog")), None) @@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get(("vehicles", "train")), "chuff") # Man from del_multi say "Yes". - def test_clear(self): - cache = LruCache(1) + def test_clear(self) -> None: + cache: LruCache[str, int] = LruCache(1) cache["key"] = 1 cache.clear() self.assertEqual(len(cache), 0) @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) - def test_special_size(self): - cache = LruCache(10, "mycache") + def test_special_size(self) -> None: + cache: LruCache = LruCache(10, "mycache") self.assertEqual(cache.max_size, 100) class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): - def test_get(self): + def test_get(self) -> None: m = Mock() - cache = LruCache(1) + cache: LruCache[str, str] = LruCache(1) cache.set("key", "value") self.assertFalse(m.called) @@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key", "value") self.assertEqual(m.call_count, 1) - def test_multi_get(self): + def test_multi_get(self) -> None: m = Mock() - cache = LruCache(1) + cache: LruCache[str, str] = LruCache(1) cache.set("key", "value") self.assertFalse(m.called) @@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key", "value") self.assertEqual(m.call_count, 1) - def test_set(self): + def test_set(self) -> None: m = Mock() - cache = LruCache(1) + cache: LruCache[str, str] = LruCache(1) cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) @@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key", "value") self.assertEqual(m.call_count, 1) - def test_pop(self): + def test_pop(self) -> None: m = Mock() - cache = LruCache(1) + cache: LruCache[str, str] = LruCache(1) cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) @@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.pop("key") self.assertEqual(m.call_count, 1) - def test_del_multi(self): + def test_del_multi(self) -> None: m1 = Mock() m2 = Mock() m3 = Mock() m4 = Mock() - cache = LruCache(4, cache_type=TreeCache) + # The type here isn't quite correct as they don't handle TreeCache well. + cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache) cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "2"), "value", callbacks=[m2]) @@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertEqual(m3.call_count, 0) self.assertEqual(m4.call_count, 0) - cache.del_multi(("a",)) + cache.del_multi(("a",)) # type: ignore[arg-type] self.assertEqual(m1.call_count, 1) self.assertEqual(m2.call_count, 1) self.assertEqual(m3.call_count, 0) self.assertEqual(m4.call_count, 0) - def test_clear(self): + def test_clear(self) -> None: m1 = Mock() m2 = Mock() - cache = LruCache(5) + cache: LruCache[str, str] = LruCache(5) cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) @@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertEqual(m1.call_count, 1) self.assertEqual(m2.call_count, 1) - def test_eviction(self): + def test_eviction(self) -> None: m1 = Mock(name="m1") m2 = Mock(name="m2") m3 = Mock(name="m3") - cache = LruCache(2) + cache: LruCache[str, str] = LruCache(2) cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) @@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): class LruCacheSizedTestCase(unittest.HomeserverTestCase): - def test_evict(self): - cache = LruCache(5, size_callback=len) + def test_evict(self) -> None: + cache: LruCache[str, List[int]] = LruCache(5, size_callback=len) cache["key1"] = [0] cache["key2"] = [1, 2] cache["key3"] = [3] @@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): cache["key1"] = [] self.assertEqual(len(cache), 0) + assert isinstance(cache.cache, dict) cache.cache["key1"].drop_from_cache() self.assertIsNone( cache.pop("key1"), "Cache entry should have been evicted but wasn't" @@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): class TimeEvictionTestCase(unittest.HomeserverTestCase): """Test that time based eviction works correctly.""" - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("caches", {})["expiry_time"] = "30m" return config - def test_evict(self): + def test_evict(self) -> None: setup_expire_lru_cache_entries(self.hs) - cache = LruCache(5, clock=self.hs.get_clock()) + cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock()) # Check that we evict entries we haven't accessed for 30 minutes. cache["key1"] = 1 @@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase): } ) @patch("synapse.util.caches.lrucache.get_jemalloc_stats") - def test_evict_memory(self, jemalloc_interface) -> None: + def test_evict_memory(self, jemalloc_interface: Mock) -> None: mock_jemalloc_class = Mock(spec=JemallocStats) jemalloc_interface.return_value = mock_jemalloc_class @@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase): mock_jemalloc_class.get_stat.return_value = 924288000 setup_expire_lru_cache_entries(self.hs) - cache = LruCache(4, clock=self.hs.get_clock()) + cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock()) cache["key1"] = 1 cache["key2"] = 2 diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py index 40754a4711..f68377a05a 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py @@ -21,14 +21,14 @@ from tests.unittest import TestCase class MacaroonGeneratorTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, hs_clock = get_clock() self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret") self.other_macaroon_generator = MacaroonGenerator( hs_clock, "tesths", b"anothersecretkey" ) - def test_guest_access_token(self): + def test_guest_access_token(self) -> None: """Test the generation and verification of guest access tokens""" token = self.macaroon_generator.generate_guest_access_token("@user:tesths") user_id = self.macaroon_generator.verify_guest_token(token) @@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase): with self.assertRaises(MacaroonVerificationFailedException): self.macaroon_generator.verify_guest_token(token) - def test_delete_pusher_token(self): + def test_delete_pusher_token(self) -> None: """Test the generation and verification of delete_pusher tokens""" token = self.macaroon_generator.generate_delete_pusher_token( "@user:tesths", "m.mail", "john@example.com" @@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase): ) self.assertEqual(user_id, "@user:tesths") - def test_oidc_session_token(self): + def test_oidc_session_token(self) -> None: """Test the generation and verification of OIDC session cookies""" state = "arandomstate" session_data = OidcSessionData( diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 89d8656634..5b327b390e 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -13,16 +13,19 @@ # limitations under the License. from typing import Optional +from twisted.internet.defer import Deferred + from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import FederationRatelimitSettings from synapse.util.ratelimitutils import FederationRateLimiter -from tests.server import get_clock +from tests.server import ThreadedMemoryReactorClock, get_clock from tests.unittest import TestCase from tests.utils import default_config class FederationRateLimiterTestCase(TestCase): - def test_ratelimit(self): + def test_ratelimit(self) -> None: """A simple test with the default values""" reactor, clock = get_clock() rc_config = build_rc_config() @@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase): # shouldn't block self.successResultOf(d1) - def test_concurrent_limit(self): + def test_concurrent_limit(self) -> None: """Test what happens when we hit the concurrent limit""" reactor, clock = get_clock() rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) @@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase): cm2.__exit__(None, None, None) self.successResultOf(d3) - def test_sleep_limit(self): + def test_sleep_limit(self) -> None: """Test what happens when we hit the sleep limit""" reactor, clock = get_clock() rc_config = build_rc_config( @@ -79,7 +82,7 @@ class FederationRateLimiterTestCase(TestCase): self.assertAlmostEqual(sleep_time, 500, places=3) -def _await_resolution(reactor, d): +def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float: """advance the clock until the deferred completes. Returns the number of milliseconds it took to complete. @@ -90,7 +93,7 @@ def _await_resolution(reactor, d): return (reactor.seconds() - start_time) * 1000 -def build_rc_config(settings: Optional[dict] = None): +def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings: config_dict = default_config("test") config_dict.update(settings or {}) config = HomeServerConfig() diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 26cb71c640..9529ee53c8 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase class RetryLimiterTestCase(HomeserverTestCase): - def test_new_destination(self): + def test_new_destination(self) -> None: """A happy-path case with a new destination and a successful operation""" store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) @@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase): new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) - def test_limiter(self): + def test_limiter(self) -> None: """General test case which walks through the process of a failing request""" store = self.hs.get_datastores().main diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 5da04362a9..bc93de62eb 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase): acquired_d: "Deferred[None]" = Deferred() unblock_d: "Deferred[None]" = Deferred() - async def reader_or_writer(): + async def reader_or_writer() -> str: async with read_or_write(key): acquired_d.callback(None) await unblock_d @@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase): d.called, msg="deferred %d was unexpectedly resolved" % (i + n) ) - def test_rwlock(self): + def test_rwlock(self) -> None: rwlock = ReadWriteLock() key = "key" @@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase): _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader") self.assertTrue(acquired_d.called) - def test_lock_handoff_to_nonblocking_writer(self): + def test_lock_handoff_to_nonblocking_writer(self) -> None: """Test a writer handing the lock to another writer that completes instantly.""" rwlock = ReadWriteLock() key = "key" @@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase): d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") self.assertTrue(d3.called) - def test_cancellation_while_holding_read_lock(self): + def test_cancellation_while_holding_read_lock(self) -> None: """Test cancellation while holding a read lock. A waiting writer should be given the lock when the reader holding the lock is @@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase): ) self.assertEqual("write completed", self.successResultOf(writer_d)) - def test_cancellation_while_holding_write_lock(self): + def test_cancellation_while_holding_write_lock(self) -> None: """Test cancellation while holding a write lock. A waiting reader should be given the lock when the writer holding the lock is @@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase): ) self.assertEqual("read completed", self.successResultOf(reader_d)) - def test_cancellation_while_waiting_for_read_lock(self): + def test_cancellation_while_waiting_for_read_lock(self) -> None: """Test cancellation while waiting for a read lock. Tests that cancelling a waiting reader: @@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase): ) self.assertEqual("write 2 completed", self.successResultOf(writer2_d)) - def test_cancellation_while_waiting_for_write_lock(self): + def test_cancellation_while_waiting_for_write_lock(self) -> None: """Test cancellation while waiting for a write lock. Tests that cancelling a waiting writer: diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 9ed01f7e0c..1b0fa52ad1 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): Tests for StreamChangeCache. """ - def test_prefilled_cache(self): + def test_prefilled_cache(self) -> None: """ Providing a prefilled cache to StreamChangeCache will result in a cache with the prefilled-cache entered in. @@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2}) self.assertTrue(cache.has_entity_changed("user@foo.com", 1)) - def test_has_entity_changed(self): + def test_has_entity_changed(self) -> None: """ StreamChangeCache.entity_has_changed will mark entities as changed, and has_entity_changed will observe the changed entities. @@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - def test_entity_has_changed_pops_off_start(self): + def test_entity_has_changed_pops_off_start(self) -> None: """ StreamChangeCache.entity_has_changed will respect the max size and purge the oldest items upon reaching that max size. @@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): ) self.assertIsNone(cache.get_all_entities_changed(1)) - def test_get_all_entities_changed(self): + def test_get_all_entities_changed(self) -> None: """ StreamChangeCache.get_all_entities_changed will return all changed entities since the given position. If the position is before the start @@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): r = cache.get_all_entities_changed(3) self.assertTrue(r == ok1 or r == ok2) - def test_has_any_entity_changed(self): + def test_has_any_entity_changed(self) -> None: """ StreamChangeCache.has_any_entity_changed will return True if any entities have been changed since the provided stream position, and @@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertFalse(cache.has_any_entity_changed(2)) self.assertFalse(cache.has_any_entity_changed(3)) - def test_get_entities_changed(self): + def test_get_entities_changed(self) -> None: """ StreamChangeCache.get_entities_changed will return the entities in the given list that have changed since the provided stream ID. If the @@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): {"bar@baz.net"}, ) - def test_max_pos(self): + def test_max_pos(self) -> None: """ StreamChangeCache.get_max_pos_of_last_change will return the most recent point where the entity could have changed. If the entity is not diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py index ad4dd7f007..f137e05191 100644 --- a/tests/util/test_stringutils.py +++ b/tests/util/test_stringutils.py @@ -19,7 +19,7 @@ from .. import unittest class StringUtilsTestCase(unittest.TestCase): - def test_client_secret_regex(self): + def test_client_secret_regex(self) -> None: """Ensure that client_secret does not contain illegal characters""" good = [ "abcde12345", @@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase): with self.assertRaises(SynapseError): assert_valid_client_secret(client_secret) - def test_base62_encode(self): + def test_base62_encode(self) -> None: self.assertEqual("0", base62_encode(0)) self.assertEqual("10", base62_encode(62)) self.assertEqual("1c", base62_encode(100)) diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py index d957b953bb..3b35b8e4ec 100644 --- a/tests/util/test_threepids.py +++ b/tests/util/test_threepids.py @@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase class CanonicaliseEmailTests(HomeserverTestCase): - def test_no_at(self): + def test_no_at(self) -> None: with self.assertRaises(ValueError): canonicalise_email("address-without-at.bar") - def test_two_at(self): + def test_two_at(self) -> None: with self.assertRaises(ValueError): canonicalise_email("foo@foo@test.bar") - def test_bad_format(self): + def test_bad_format(self) -> None: with self.assertRaises(ValueError): canonicalise_email("user@bad.example.net@good.example.com") - def test_valid_format(self): + def test_valid_format(self) -> None: self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar") - def test_domain_to_lower(self): + def test_domain_to_lower(self) -> None: self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar") - def test_domain_with_umlaut(self): + def test_domain_with_umlaut(self) -> None: self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com") - def test_address_casefold(self): + def test_address_casefold(self) -> None: self.assertEqual( canonicalise_email("Strauß@Example.com"), "strauss@example.com" ) - def test_address_trim(self): + def test_address_trim(self) -> None: self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar") diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 567cb18468..fe3b4dc6a4 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -19,7 +19,7 @@ from .. import unittest class TreeCacheTestCase(unittest.TestCase): - def test_get_set_onelevel(self): + def test_get_set_onelevel(self) -> None: cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" @@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual(cache.get(("b",)), "B") self.assertEqual(len(cache), 2) - def test_pop_onelevel(self): + def test_pop_onelevel(self) -> None: cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" @@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual(cache.get(("b",)), "B") self.assertEqual(len(cache), 1) - def test_get_set_twolevel(self): + def test_get_set_twolevel(self) -> None: cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" @@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual(cache.get(("b", "a")), "BA") self.assertEqual(len(cache), 3) - def test_pop_twolevel(self): + def test_pop_twolevel(self) -> None: cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" @@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual(cache.pop(("b", "a")), None) self.assertEqual(len(cache), 1) - def test_pop_mixedlevel(self): + def test_pop_mixedlevel(self) -> None: cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" @@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) - def test_clear(self): + def test_clear(self) -> None: cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" cache.clear() self.assertEqual(len(cache), 0) - def test_contains(self): + def test_contains(self) -> None: cache = TreeCache() cache[("a",)] = "A" self.assertTrue(("a",) in cache) diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py index 0d5039de04..c9d22b6d8c 100644 --- a/tests/util/test_wheel_timer.py +++ b/tests/util/test_wheel_timer.py @@ -18,8 +18,8 @@ from .. import unittest class WheelTimerTestCase(unittest.TestCase): - def test_single_insert_fetch(self): - wheel = WheelTimer(bucket_size=5) + def test_single_insert_fetch(self) -> None: + wheel: WheelTimer[object] = WheelTimer(bucket_size=5) obj = object() wheel.insert(100, obj, 150) @@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(156), [obj]) self.assertListEqual(wheel.fetch(170), []) - def test_multi_insert(self): - wheel = WheelTimer(bucket_size=5) + def test_multi_insert(self) -> None: + wheel: WheelTimer[object] = WheelTimer(bucket_size=5) obj1 = object() obj2 = object() @@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(200), [obj3]) self.assertListEqual(wheel.fetch(210), []) - def test_insert_past(self): - wheel = WheelTimer(bucket_size=5) + def test_insert_past(self) -> None: + wheel: WheelTimer[object] = WheelTimer(bucket_size=5) obj = object() wheel.insert(100, obj, 50) self.assertListEqual(wheel.fetch(120), [obj]) - def test_insert_past_multi(self): - wheel = WheelTimer(bucket_size=5) + def test_insert_past_multi(self) -> None: + wheel: WheelTimer[object] = WheelTimer(bucket_size=5) obj1 = object() obj2 = object() -- cgit 1.5.1 From 6a8310f3dfe77acf59df2fe3e88a71b85b9b3ecc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 5 Dec 2022 09:00:59 -0500 Subject: Compare to the earliest known stream pos in the stream change cache. (#14435) The internal methods of the StreamChangeCache were inconsistently treating the earliest known stream position as valid. It is now treated as invalid, meaning the cache cannot determine if an entity at the earliest known stream position has changed or not. --- changelog.d/14435.bugfix | 1 + poetry.lock | 2 +- pyproject.toml | 3 +- synapse/util/caches/stream_change_cache.py | 142 +++++++++++++++++++++++------ tests/util/test_stream_change_cache.py | 38 +++----- 5 files changed, 133 insertions(+), 53 deletions(-) create mode 100644 changelog.d/14435.bugfix (limited to 'tests') diff --git a/changelog.d/14435.bugfix b/changelog.d/14435.bugfix new file mode 100644 index 0000000000..149ee99dd7 --- /dev/null +++ b/changelog.d/14435.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances. diff --git a/poetry.lock b/poetry.lock index 8c63134578..90b363a548 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1639,7 +1639,7 @@ url-preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "27811bd21d56ceeb0f68ded5a00375efcd1a004928f0736f5b02927ce8594cb0" +content-hash = "8c44ceeb9df5c3ab43040400e0a6b895de49417e61293a1ba027640b34f03263" [metadata.files] attrs = [ diff --git a/pyproject.toml b/pyproject.toml index af5ce2aa03..1368e4e688 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,7 +141,8 @@ pyasn1 = ">=0.1.9" pyasn1-modules = ">=0.0.7" bcrypt = ">=3.1.7" Pillow = ">=5.4.0" -sortedcontainers = ">=1.4.4" +# We use SortedDict.peekitem(), which was added in sortedcontainers 1.5.2. +sortedcontainers = ">=1.5.2" pymacaroons = ">=0.13.0" msgpack = ">=0.5.2" phonenumbers = ">=8.2.0" diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 666f4b6895..042de8d7c8 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -27,13 +27,17 @@ EntityType = str class StreamChangeCache: - """Keeps track of the stream positions of the latest change in a set of entities. + """ + Keeps track of the stream positions of the latest change in a set of entities. + + The entity will is typically a room ID or user ID, but can be any string. - Typically the entity will be a room or user id. + Can be queried for whether a specific entity has changed after a stream position + or for a list of changed entities after a stream position. See the individual + methods for more information. - Given a list of entities and a stream position, it will give a subset of - entities that may have changed since that position. If position key is too - old then the cache will simply return all given entities. + Only tracks to a maximum cache size, any position earlier than the earliest + known stream position must be treated as unknown. """ def __init__( @@ -45,16 +49,20 @@ class StreamChangeCache: ) -> None: self._original_max_size: int = max_size self._max_size = math.floor(max_size) - self._entity_to_key: Dict[EntityType, int] = {} - # map from stream id to the a set of entities which changed at that stream id. + # map from stream id to the set of entities which changed at that stream id. self._cache: SortedDict[int, Set[EntityType]] = SortedDict() + # map from entity to the stream ID of the latest change for that entity. + # + # Must be kept in sync with _cache. + self._entity_to_key: Dict[EntityType, int] = {} # the earliest stream_pos for which we can reliably answer # get_all_entities_changed. In other words, one less than the earliest # stream_pos for which we know _cache is valid. # self._earliest_known_stream_pos = current_stream_pos + self.name = name self.metrics = caches.register_cache( "cache", self.name, self._cache, resize_callback=self.set_cache_factor @@ -82,22 +90,46 @@ class StreamChangeCache: return False def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: - """Returns True if the entity may have been updated since stream_pos""" + """ + Returns True if the entity may have been updated after stream_pos. + + Args: + entity: The entity to check for changes. + stream_pos: The stream position to check for changes after. + + Return: + True if the entity may have been updated, this happens if: + * The given stream position is at or earlier than the earliest + known stream position. + * The given stream position is earlier than the latest change for + the entity. + + False otherwise: + * The entity is unknown. + * The given stream position is at or later than the latest change + for the entity. + """ assert isinstance(stream_pos, int) - if stream_pos < self._earliest_known_stream_pos: + # _cache is not valid at or before the earliest known stream position, so + # return that the entity has changed. + if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + # If the entity is unknown, it hasn't changed. latest_entity_change_pos = self._entity_to_key.get(entity, None) if latest_entity_change_pos is None: self.metrics.inc_hits() return False + # This is a known entity, return true if the stream position is earlier + # than the last change. if stream_pos < latest_entity_change_pos: self.metrics.inc_misses() return True + # Otherwise, the stream position is after the latest change: return false. self.metrics.inc_hits() return False @@ -105,15 +137,27 @@ class StreamChangeCache: self, entities: Collection[EntityType], stream_pos: int ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ - Returns subset of entities that have had new things since the given - position. Entities unknown to the cache will be returned. If the - position is too old it will just return the given list. + Returns the subset of the given entities that have had changes after the given position. + + Entities unknown to the cache will be returned. + + If the position is too old it will just return the given list. + + Args: + entities: Entities to check for changes. + stream_pos: The stream position to check for changes after. + + Return: + A subset of entities which have changed after the given stream position. + + This will be all entities if the given stream position is at or earlier + than the earliest known stream position. """ changed_entities = self.get_all_entities_changed(stream_pos) if changed_entities is not None: # We now do an intersection, trying to do so in the most efficient # way possible (some of these sets are *large*). First check in the - # given iterable is already set that we can reuse, otherwise we + # given iterable is already a set that we can reuse, otherwise we # create a set of the *smallest* of the two iterables and call # `intersection(..)` on it (this can be twice as fast as the reverse). if isinstance(entities, (set, frozenset)): @@ -130,29 +174,57 @@ class StreamChangeCache: return result def has_any_entity_changed(self, stream_pos: int) -> bool: - """Returns if any entity has changed""" - assert type(stream_pos) is int + """ + Returns true if any entity has changed after the given stream position. + + Args: + stream_pos: The stream position to check for changes after. + + Return: + True if any entity has changed after the given stream position or + if the given stream position is at or earlier than the earliest + known stream position. + + False otherwise. + """ + assert isinstance(stream_pos, int) if not self._cache: # If the cache is empty, nothing can have changed. return False - if stream_pos >= self._earliest_known_stream_pos: - self.metrics.inc_hits() - return self._cache.bisect_right(stream_pos) < len(self._cache) - else: + # _cache is not valid at or before the earliest known stream position, so + # return that an entity has changed. + if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + self.metrics.inc_hits() + return stream_pos < self._cache.peekitem()[0] + def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: - """Returns all entities that have had new things since the given - position. If the position is too old it will return None. + """ + Returns all entities that have had changes after the given position. + + If the stream change cache does not go far enough back, i.e. the position + is too old, it will return None. Returns the entities in the order that they were changed. + + Args: + stream_pos: The stream position to check for changes after. + + Return: + Entities which have changed after the given stream position. + + None if the given stream position is at or earlier than the earliest + known stream position. """ - assert type(stream_pos) is int + assert isinstance(stream_pos, int) - if stream_pos < self._earliest_known_stream_pos: + # _cache is not valid at or before the earliest known stream position, so + # return None to mark that it is unknown if an entity has changed. + if stream_pos <= self._earliest_known_stream_pos: return None changed_entities: List[EntityType] = [] @@ -162,11 +234,17 @@ class StreamChangeCache: return changed_entities def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: - """Informs the cache that the entity has been changed at the given - position. """ - assert type(stream_pos) is int + Informs the cache that the entity has been changed at the given position. + + Args: + entity: The entity to mark as changed. + stream_pos: The stream position to update the entity to. + """ + assert isinstance(stream_pos, int) + # For a change before _cache is valid (e.g. at or before the earliest known + # stream position) there's nothing to do. if stream_pos <= self._earliest_known_stream_pos: return @@ -189,6 +267,11 @@ class StreamChangeCache: self._evict() def _evict(self) -> None: + """ + Ensure the cache has not exceeded the maximum size. + + Evicts entries until it is at the maximum size. + """ # if the cache is too big, remove entries while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) @@ -199,5 +282,12 @@ class StreamChangeCache: def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an entity. + + Args: + entity: The entity to check. + + Return: + The stream position of the latest change for the given entity or + the earliest known stream position if the entitiy is unknown. """ return self._entity_to_key.get(entity, self._earliest_known_stream_pos) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 1b0fa52ad1..a29cc872f9 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -51,6 +51,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # return True, whether it's a known entity or not. self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) + self.assertTrue(cache.has_entity_changed("user@foo.com", 3)) + self.assertTrue(cache.has_entity_changed("not@here.website", 3)) def test_entity_has_changed_pops_off_start(self) -> None: """ @@ -65,15 +67,14 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # The cache is at the max size, 2 self.assertEqual(len(cache._cache), 2) + # The cache's earliest known position is 2. + self.assertEqual(cache._earliest_known_stream_pos, 2) # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) - self.assertEqual( - cache.get_all_entities_changed(2), - ["bar@baz.net", "user@elsewhere.org"], - ) - self.assertIsNone(cache.get_all_entities_changed(1)) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) + self.assertIsNone(cache.get_all_entities_changed(2)) # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) @@ -81,10 +82,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(2), + cache.get_all_entities_changed(3), ["user@elsewhere.org", "bar@baz.net"], ) - self.assertIsNone(cache.get_all_entities_changed(1)) + self.assertIsNone(cache.get_all_entities_changed(2)) def test_get_all_entities_changed(self) -> None: """ @@ -99,28 +100,15 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): cache.entity_has_changed("anotheruser@foo.com", 3) cache.entity_has_changed("user@elsewhere.org", 4) - r = cache.get_all_entities_changed(1) + r = cache.get_all_entities_changed(2) - # either of these are valid - ok1 = [ - "user@foo.com", - "bar@baz.net", - "anotheruser@foo.com", - "user@elsewhere.org", - ] - ok2 = [ - "user@foo.com", - "anotheruser@foo.com", - "bar@baz.net", - "user@elsewhere.org", - ] + # Results are ordered so either of these are valid. + ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"] + ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"] self.assertTrue(r == ok1 or r == ok2) - r = cache.get_all_entities_changed(2) - self.assertTrue(r == ok1[1:] or r == ok2[1:]) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertEqual(cache.get_all_entities_changed(0), None) + self.assertEqual(cache.get_all_entities_changed(1), None) # ... later, things gest more updates cache.entity_has_changed("user@foo.com", 5) -- cgit 1.5.1 From cee9445884eb62c070fb0b03a112a862e8dea7c4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 5 Dec 2022 20:19:14 +0000 Subject: Better return type for `get_all_entities_changed` (#14604) Help callers from using the return value incorrectly by ensuring that callers explicitly check if there was a cache hit or not. --- changelog.d/14604.bugfix | 1 + synapse/handlers/appservice.py | 4 +- synapse/handlers/presence.py | 12 ++-- synapse/handlers/sync.py | 6 +- synapse/handlers/typing.py | 8 +-- synapse/storage/databases/main/devices.py | 111 ++++++++++++++++++----------- synapse/util/caches/stream_change_cache.py | 52 ++++++++++---- tests/util/test_stream_change_cache.py | 20 +++--- 8 files changed, 138 insertions(+), 76 deletions(-) create mode 100644 changelog.d/14604.bugfix (limited to 'tests') diff --git a/changelog.d/14604.bugfix b/changelog.d/14604.bugfix new file mode 100644 index 0000000000..149ee99dd7 --- /dev/null +++ b/changelog.d/14604.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 66f5b8d108..f68027aaed 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -615,8 +615,8 @@ class ApplicationServicesHandler: ) # Fetch the users who have modified their device list since then. - users_with_changed_device_lists = ( - await self.store.get_users_whose_devices_changed(from_key, to_key=new_key) + users_with_changed_device_lists = await self.store.get_all_devices_changed( + from_key, to_key=new_key ) # Filter out any users the application service is not interested in diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1799174c2f..2af90b25a3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): if from_key is not None: # First get all users that have had a presence update - updated_users = stream_change_cache.get_all_entities_changed(from_key) + result = stream_change_cache.get_all_entities_changed(from_key) # Cross-reference users we're interested in with those that have had updates. - if updated_users is not None: + if result.hit: + updated_users = result.entities + # If we have the full list of changes for presence we can # simply check which ones share a room with the user. get_updates_counter.labels("stream").inc() @@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): updated_users = None if from_key: # Only return updates since the last sync - updated_users = self.store.presence_stream_cache.get_all_entities_changed( - from_key - ) + result = self.store.presence_stream_cache.get_all_entities_changed(from_key) + if result.hit: + updated_users = result.entities if updated_users is not None: # Get the actual presence update for each change diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c8858b22dd..0b395a104d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1528,10 +1528,12 @@ class SyncHandler: # # If we don't have that info cached then we get all the users that # share a room with our user and check if those users have changed. - changed_users = self.store.get_cached_device_list_changes( + cache_result = self.store.get_cached_device_list_changes( since_token.device_list_key ) - if changed_users is not None: + if cache_result.hit: + changed_users = cache_result.entities + result = await self.store.get_rooms_for_users(changed_users) for changed_user_id, entries in result.items(): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a0ea719430..3f656ea4f5 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler): if last_id == current_id: return [], current_id, False - changed_rooms: Optional[ - Iterable[str] - ] = self._typing_stream_change_cache.get_all_entities_changed(last_id) + result = self._typing_stream_change_cache.get_all_entities_changed(last_id) - if changed_rooms is None: + if result.hit: + changed_rooms: Iterable[str] = result.entities + else: changed_rooms = self._room_serials rows = [] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8ba995df3b..a5bb4d404e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.caches.stream_change_cache import ( + AllEntitiesChangedResult, + StreamChangeCache, +) from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_cached_device_list_changes( self, from_key: int, - ) -> Optional[List[str]]: + ) -> AllEntitiesChangedResult: """Get set of users whose devices have changed since `from_key`, or None if that information is not in our cache. """ return self._device_list_stream_cache.get_all_entities_changed(from_key) + @cancellable + async def get_all_devices_changed( + self, + from_key: int, + to_key: int, + ) -> Set[str]: + """Get all users whose devices have changed in the given range. + + Args: + from_key: The minimum device lists stream token to query device list + changes for, exclusive. + to_key: The maximum device lists stream token to query device list + changes for, inclusive. + + Returns: + The set of user_ids whose devices have changed since `from_key` + (exclusive) until `to_key` (inclusive). + """ + + result = self._device_list_stream_cache.get_all_entities_changed(from_key) + + if result.hit: + # We know which users might have changed devices. + if not result.entities: + # If no users then we can return early. + return set() + + # Otherwise we need to filter down the list + return await self.get_users_whose_devices_changed( + from_key, result.entities, to_key + ) + + # If the cache didn't tell us anything, we just need to query the full + # range. + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE ? < stream_id AND stream_id <= ? + """ + + rows = await self.db_pool.execute( + "get_all_devices_changed", + None, + sql, + from_key, + to_key, + ) + return {u for u, in rows} + @cancellable async def get_users_whose_devices_changed( self, from_key: int, - user_ids: Optional[Collection[str]] = None, + user_ids: Collection[str], to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that @@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - user_ids_to_check: Optional[Collection[str]] - if user_ids is None: - # Get set of all users that have had device list changes since 'from_key' - user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( - from_key - ) - else: - # The same as above, but filter results to only those users in 'user_ids' - user_ids_to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) # If an empty set was returned, there's nothing to do. - if user_ids_to_check is not None and not user_ids_to_check: + if not user_ids_to_check: return set() - def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: - stream_id_where_clause = "stream_id > ?" - sql_args = [from_key] - - if to_key: - stream_id_where_clause += " AND stream_id <= ?" - sql_args.append(to_key) + if to_key is None: + to_key = self._device_list_id_gen.get_current_token() - sql = f""" + def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ SELECT DISTINCT user_id FROM device_lists_stream - WHERE {stream_id_where_clause} + WHERE ? < stream_id AND stream_id <= ? AND %s """ - # If the stream change cache gave us no information, fetch *all* - # users between the stream IDs. - if user_ids_to_check is None: - txn.execute(sql, sql_args) - return {user_id for user_id, in txn} + changes: Set[str] = set() - # Otherwise, fetch changes for the given users. - else: - changes: Set[str] = set() - - # Query device changes with a batch of users at a time - for chunk in batch_iter(user_ids_to_check, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "user_id", chunk - ) - txn.execute(sql + " AND " + clause, sql_args + args) - changes.update(user_id for user_id, in txn) + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql % (clause,), [from_key, to_key] + args) + changes.update(user_id for user_id, in txn) return changes diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 042de8d7c8..c8b17acb59 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -16,6 +16,7 @@ import logging import math from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union +import attr from sortedcontainers import SortedDict from synapse.util import caches @@ -26,6 +27,29 @@ logger = logging.getLogger(__name__) EntityType = str +@attr.s(auto_attribs=True, frozen=True, slots=True) +class AllEntitiesChangedResult: + """Return type of `get_all_entities_changed`. + + Callers must check that there was a cache hit, via `result.hit`, before + using the entities in `result.entities`. + + This specifically does *not* implement helpers such as `__bool__` to ensure + that callers do the correct checks. + """ + + _entities: Optional[List[EntityType]] + + @property + def hit(self) -> bool: + return self._entities is not None + + @property + def entities(self) -> List[EntityType]: + assert self._entities is not None + return self._entities + + class StreamChangeCache: """ Keeps track of the stream positions of the latest change in a set of entities. @@ -153,19 +177,19 @@ class StreamChangeCache: This will be all entities if the given stream position is at or earlier than the earliest known stream position. """ - changed_entities = self.get_all_entities_changed(stream_pos) - if changed_entities is not None: + cache_result = self.get_all_entities_changed(stream_pos) + if cache_result.hit: # We now do an intersection, trying to do so in the most efficient # way possible (some of these sets are *large*). First check in the # given iterable is already a set that we can reuse, otherwise we # create a set of the *smallest* of the two iterables and call # `intersection(..)` on it (this can be twice as fast as the reverse). if isinstance(entities, (set, frozenset)): - result = entities.intersection(changed_entities) - elif len(changed_entities) < len(entities): - result = set(changed_entities).intersection(entities) + result = entities.intersection(cache_result.entities) + elif len(cache_result.entities) < len(entities): + result = set(cache_result.entities).intersection(entities) else: - result = set(entities).intersection(changed_entities) + result = set(entities).intersection(cache_result.entities) self.metrics.inc_hits() else: result = set(entities) @@ -202,12 +226,12 @@ class StreamChangeCache: self.metrics.inc_hits() return stream_pos < self._cache.peekitem()[0] - def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: + def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult: """ Returns all entities that have had changes after the given position. - If the stream change cache does not go far enough back, i.e. the position - is too old, it will return None. + If the stream change cache does not go far enough back, i.e. the + position is too old, it will return None. Returns the entities in the order that they were changed. @@ -215,23 +239,21 @@ class StreamChangeCache: stream_pos: The stream position to check for changes after. Return: - Entities which have changed after the given stream position. - - None if the given stream position is at or earlier than the earliest - known stream position. + A class indicating if we have the requested data cached, and if so + includes the entities in the order they were changed. """ assert isinstance(stream_pos, int) # _cache is not valid at or before the earliest known stream position, so # return None to mark that it is unknown if an entity has changed. if stream_pos <= self._earliest_known_stream_pos: - return None + return AllEntitiesChangedResult(None) changed_entities: List[EntityType] = [] for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): changed_entities.extend(self._cache[k]) - return changed_entities + return AllEntitiesChangedResult(changed_entities) def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """ diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index a29cc872f9..0305741c99 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -73,8 +73,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertIsNone(cache.get_all_entities_changed(2)) + self.assertEqual( + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"] + ) + self.assertFalse(cache.get_all_entities_changed(2).hit) # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) @@ -82,10 +84,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(3), + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org", "bar@baz.net"], ) - self.assertIsNone(cache.get_all_entities_changed(2)) + self.assertFalse(cache.get_all_entities_changed(2).hit) def test_get_all_entities_changed(self) -> None: """ @@ -105,10 +107,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # Results are ordered so either of these are valid. ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"] ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"] - self.assertTrue(r == ok1 or r == ok2) + self.assertTrue(r.entities == ok1 or r.entities == ok2) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertEqual(cache.get_all_entities_changed(1), None) + self.assertEqual( + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"] + ) + self.assertFalse(cache.get_all_entities_changed(1).hit) # ... later, things gest more updates cache.entity_has_changed("user@foo.com", 5) @@ -128,7 +132,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "anotheruser@foo.com", ] r = cache.get_all_entities_changed(3) - self.assertTrue(r == ok1 or r == ok2) + self.assertTrue(r.entities == ok1 or r.entities == ok2) def test_has_any_entity_changed(self) -> None: """ -- cgit 1.5.1 From cb59e080627745d089d073d9dac276362d9abaf6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 6 Dec 2022 09:52:55 +0000 Subject: Improve logging and opentracing for to-device message handling (#14598) A batch of changes intended to make it easier to trace to-device messages through the system. The intention here is that a client can set a property org.matrix.msgid in any to-device message it sends. That ID is then included in any tracing or logging related to the message. (Suggestions as to where this field should be documented welcome. I'm not enthusiastic about speccing it - it's very much an optional extra to help with debugging.) I've also generally improved the data we send to opentracing for these messages. --- changelog.d/14598.feature | 1 + synapse/api/constants.py | 3 + synapse/federation/sender/per_destination_queue.py | 2 +- synapse/handlers/appservice.py | 3 - synapse/handlers/devicemessage.py | 36 +++++---- synapse/handlers/sync.py | 26 ++++-- synapse/logging/opentracing.py | 11 ++- synapse/rest/client/sendtodevice.py | 1 - synapse/storage/databases/main/deviceinbox.py | 92 ++++++++++++++++++---- tests/handlers/test_appservice.py | 7 +- 10 files changed, 136 insertions(+), 46 deletions(-) create mode 100644 changelog.d/14598.feature (limited to 'tests') diff --git a/changelog.d/14598.feature b/changelog.d/14598.feature new file mode 100644 index 0000000000..88d561e286 --- /dev/null +++ b/changelog.d/14598.feature @@ -0,0 +1 @@ +Improve opentracing and logging for to-device message handling. \ No newline at end of file diff --git a/synapse/api/constants.py b/synapse/api/constants.py index bc04a0755b..89723d24fa 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -230,6 +230,9 @@ class EventContentFields: # The authorising user for joining a restricted room. AUTHORISING_USER: Final = "join_authorised_via_users_server" + # an unspecced field added to to-device messages to identify them uniquely-ish + TO_DEVICE_MSGID: Final = "org.matrix.msgid" + class RoomTypes: """Understood values of the room_type field of m.room.create events.""" diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 5af2784f1e..ffc9d95ee7 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -641,7 +641,7 @@ class PerDestinationQueue: if not message_id: continue - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id) edus = [ Edu( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index f68027aaed..5d1d21cdc8 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -578,9 +578,6 @@ class ApplicationServicesHandler: device_id, ), messages in recipient_device_to_messages.items(): for message_json in messages: - # Remove 'message_id' from the to-device message, as it's an internal ID - message_json.pop("message_id", None) - message_payload.append( { "to_user_id": user_id, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 444c08bc2e..75e89850f5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict -from synapse.api.constants import EduTypes, ToDeviceEventTypes +from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes from synapse.api.errors import SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background @@ -216,14 +216,24 @@ class DeviceMessageHandler: """ sender_user_id = requester.user.to_string() - message_id = random_string(16) - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) - - log_kv({"number_of_to_device_messages": len(messages)}) - set_tag("sender", sender_user_id) + set_tag(SynapseTags.TO_DEVICE_TYPE, message_type) + set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id) local_messages = {} remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): + # add an opentracing log entry for each message + for device_id, message_content in by_device.items(): + log_kv( + { + "event": "send_to_device_message", + "user_id": user_id, + "device_id": device_id, + EventContentFields.TO_DEVICE_MSGID: message_content.get( + EventContentFields.TO_DEVICE_MSGID + ), + } + ) + # Ratelimit local cross-user key requests by the sending device. if ( message_type == ToDeviceEventTypes.RoomKeyRequest @@ -233,6 +243,7 @@ class DeviceMessageHandler: requester, (sender_user_id, requester.device_id) ) if not allowed: + log_kv({"message": f"dropping key requests to {user_id}"}) logger.info( "Dropping room_key_request from %s to %s due to rate limit", sender_user_id, @@ -247,18 +258,11 @@ class DeviceMessageHandler: "content": message_content, "type": message_type, "sender": sender_user_id, - "message_id": message_id, } for device_id, message_content in by_device.items() } if messages_by_device: local_messages[user_id] = messages_by_device - log_kv( - { - "user_id": user_id, - "device_id": list(messages_by_device), - } - ) else: destination = get_domain_from_id(user_id) remote_messages.setdefault(destination, {})[user_id] = by_device @@ -267,7 +271,11 @@ class DeviceMessageHandler: remote_edu_contents = {} for destination, messages in remote_messages.items(): - log_kv({"destination": destination}) + # The EDU contains a "message_id" property which is used for + # idempotence. Make up a random one. + message_id = random_string(16) + log_kv({"destination": destination, "message_id": message_id}) + remote_edu_contents[destination] = { "messages": messages, "sender": sender_user_id, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0b395a104d..dace9b606f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,14 +31,20 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context -from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span +from synapse.logging.opentracing import ( + SynapseTags, + log_kv, + set_tag, + start_active_span, + trace, +) from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary @@ -1586,6 +1592,7 @@ class SyncHandler: else: return DeviceListUpdates() + @trace async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" ) -> None: @@ -1605,11 +1612,16 @@ class SyncHandler: ) for message in messages: - # We pop here as we shouldn't be sending the message ID down - # `/sync` - message_id = message.pop("message_id", None) - if message_id: - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + log_kv( + { + "event": "to_device_message", + "sender": message["sender"], + "type": message["type"], + EventContentFields.TO_DEVICE_MSGID: message["content"].get( + EventContentFields.TO_DEVICE_MSGID + ), + } + ) logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index b69060854f..a705af8356 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -292,8 +292,15 @@ logger = logging.getLogger(__name__) class SynapseTags: - # The message ID of any to_device message processed - TO_DEVICE_MESSAGE_ID = "to_device.message_id" + # The message ID of any to_device EDU processed + TO_DEVICE_EDU_ID = "to_device.edu_id" + + # Details about to-device messages + TO_DEVICE_TYPE = "to_device.type" + TO_DEVICE_SENDER = "to_device.sender" + TO_DEVICE_RECIPIENT = "to_device.recipient" + TO_DEVICE_RECIPIENT_DEVICE = "to_device.recipient_device" + TO_DEVICE_MSGID = "to_device.msgid" # client-generated ID # Whether the sync response has new data to be returned to the client. SYNC_RESULT = "sync.new_data" diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py index 46a8b03829..55d52f0b28 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py @@ -46,7 +46,6 @@ class SendToDeviceRestServlet(servlet.RestServlet): def on_PUT( self, request: SynapseRequest, message_type: str, txn_id: str ) -> Awaitable[Tuple[int, JsonDict]]: - set_tag("message_type", message_type) set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( request, self._put, request, message_type, txn_id diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 73c95ffb6f..48a54d9cb8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -26,8 +26,15 @@ from typing import ( cast, ) +from synapse.api.constants import EventContentFields from synapse.logging import issue9533_logger -from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.logging.opentracing import ( + SynapseTags, + log_kv, + set_tag, + start_active_span, + trace, +) from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -397,6 +404,17 @@ class DeviceInboxWorkerStore(SQLBaseStore): (recipient_user_id, recipient_device_id), [] ).append(message_dict) + # start a new span for each message, so that we can tag each separately + with start_active_span("get_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, recipient_user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, recipient_device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + if limit is not None and rowcount == limit: # We ended up bumping up against the message limit. There may be more messages # to retrieve. Return what we have, as well as the last stream position that @@ -678,12 +696,35 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - if remote_messages_by_destination: - issue9533_logger.debug( - "Queued outgoing to-device messages with stream_id %i for %s", - stream_id, - list(remote_messages_by_destination.keys()), - ) + for destination, edu in remote_messages_by_destination.items(): + if issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Queued outgoing to-device messages with " + "stream_id %i, EDU message_id %s, type %s for %s: %s", + stream_id, + edu["message_id"], + edu["type"], + destination, + [ + f"{user_id}/{device_id} (msgid " + f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})" + for (user_id, messages_by_device) in edu["messages"].items() + for (device_id, msg) in messages_by_device.items() + ], + ) + + for (user_id, messages_by_device) in edu["messages"].items(): + for (device_id, msg) in messages_by_device.items(): + with start_active_span("store_outgoing_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) + set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) + set_tag(SynapseTags.TO_DEVICE_TYPE, edu["type"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + msg.get(EventContentFields.TO_DEVICE_MSGID), + ) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self._clock.time_msec() @@ -801,7 +842,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): # Only insert into the local inbox if the device exists on # this server device_id = row["device_id"] - message_json = json_encoder.encode(messages_by_device[device_id]) + + with start_active_span("serialise_to_device_message"): + msg = messages_by_device[device_id] + set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, msg["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + msg["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + message_json = json_encoder.encode(msg) + messages_json_for_user[device_id] = message_json if messages_json_for_user: @@ -821,15 +874,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - issue9533_logger.debug( - "Stored to-device messages with stream_id %i for %s", - stream_id, - [ - (user_id, device_id) - for (user_id, messages_by_device) in local_by_user_then_device.items() - for device_id in messages_by_device.keys() - ], - ) + if issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Stored to-device messages with stream_id %i: %s", + stream_id, + [ + f"{user_id}/{device_id} (msgid " + f"{msg['content'].get(EventContentFields.TO_DEVICE_MSGID)})" + for ( + user_id, + messages_by_device, + ) in messages_by_user_then_device.items() + for (device_id, msg) in messages_by_device.items() + ], + ) class DeviceInboxBackgroundUpdateStore(SQLBaseStore): diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 9ed26d87a7..57bfbd7734 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -765,7 +765,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)] messages = { self.exclusive_as_user: { - device_id: to_device_message_content for device_id in fake_device_ids + device_id: { + "type": "test_to_device_message", + "sender": "@some:sender", + "content": to_device_message_content, + } + for device_id in fake_device_ids } } -- cgit 1.5.1 From 9e82caac45cd8eccd7b22c60c2cdbeec9aab3a2d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 6 Dec 2022 15:48:42 +0000 Subject: Faster remote room joins: unblock tasks waiting for full room state when the un-partial-stating of that room is received over the replication stream. [rei:frrj/streams/unpsr] (#14474) --- changelog.d/14474.misc | 1 + synapse/replication/tcp/client.py | 11 ++++ .../replication/tcp/streams/test_partial_state.py | 65 ++++++++++++++++++++++ 3 files changed, 77 insertions(+) create mode 100644 changelog.d/14474.misc create mode 100644 tests/replication/tcp/streams/test_partial_state.py (limited to 'tests') diff --git a/changelog.d/14474.misc b/changelog.d/14474.misc new file mode 100644 index 0000000000..deccd4e91a --- /dev/null +++ b/changelog.d/14474.misc @@ -0,0 +1 @@ +Faster remote room joins: stream the un-partial-stating of rooms over replication. \ No newline at end of file diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 18252a2958..b4dad47b45 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -36,12 +36,14 @@ from synapse.replication.tcp.streams import ( TagAccountDataStream, ToDeviceStream, TypingStream, + UnPartialStatedRoomStream, ) from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, EventsStreamRow, ) +from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStreamRow from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID from synapse.util.async_helpers import Linearizer, timeout_deferred from synapse.util.metrics import Measure @@ -117,6 +119,7 @@ class ReplicationDataHandler: self._streams = hs.get_replication_streams() self._instance_name = hs.get_instance_name() self._typing_handler = hs.get_typing_handler() + self._state_storage_controller = hs.get_storage_controllers().state self._notify_pushers = hs.config.worker.start_pushers self._pusher_pool = hs.get_pusherpool() @@ -236,6 +239,14 @@ class ReplicationDataHandler: self.notifier.notify_user_joined_room( row.data.event_id, row.data.room_id ) + elif stream_name == UnPartialStatedRoomStream.NAME: + for row in rows: + assert isinstance(row, UnPartialStatedRoomStreamRow) + + # Wake up any tasks waiting for the room to be un-partial-stated. + self._state_storage_controller.notify_room_un_partial_stated( + row.room_id + ) await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py new file mode 100644 index 0000000000..2c10eab4db --- /dev/null +++ b/tests/replication/tcp/streams/test_partial_state.py @@ -0,0 +1,65 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet.defer import ensureDeferred + +from synapse.rest.client import room + +from tests.replication._base import BaseMultiWorkerStreamTestCase + + +class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase): + servlets = [room.register_servlets] + hijack_auth = True + user_id = "@bob:test" + + def setUp(self): + super().setUp() + self.store = self.hs.get_datastores().main + + def test_un_partial_stated_room_unblocks_over_replication(self) -> None: + """ + Tests that, when a room is un-partial-stated on another worker, + pending calls to `await_full_state` get unblocked. + """ + + # Make a room. + room_id = self.helper.create_room_as("@bob:test") + # Mark the room as partial-stated. + self.get_success( + self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1") + ) + + worker = self.make_worker_hs("synapse.app.generic_worker") + + # On the worker, attempt to get the current hosts in the room + d = ensureDeferred( + worker.get_storage_controllers().state.get_current_hosts_in_room(room_id) + ) + + self.reactor.advance(0.1) + + # This should block + self.assertFalse( + d.called, "get_current_hosts_in_room/await_full_state did not block" + ) + + # On the master, clear the partial state flag. + self.get_success(self.store.clear_partial_state_room(room_id)) + + self.reactor.advance(0.1) + + # The worker should have unblocked + self.assertTrue( + d.called, "get_current_hosts_in_room/await_full_state did not unblock" + ) -- cgit 1.5.1 From cf1059d045640485a5a0b1e3d945b796b0e6f228 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 7 Dec 2022 11:19:43 +0000 Subject: Fix a long-standing bug where the user directory would return 1 more row than requested. (#14631) --- changelog.d/14631.bugfix | 1 + synapse/rest/client/user_directory.py | 4 ++-- synapse/storage/databases/main/user_directory.py | 2 +- tests/storage/test_user_directory.py | 6 ++++++ 4 files changed, 10 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14631.bugfix (limited to 'tests') diff --git a/changelog.d/14631.bugfix b/changelog.d/14631.bugfix new file mode 100644 index 0000000000..c5376bab9f --- /dev/null +++ b/changelog.d/14631.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory would return 1 more row than requested. \ No newline at end of file diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index 116c982ce6..4670fad608 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -63,8 +63,8 @@ class UserDirectorySearchRestServlet(RestServlet): body = parse_json_object_from_request(request) - limit = body.get("limit", 10) - limit = min(limit, 50) + limit = int(body.get("limit", 10)) + limit = max(min(limit, 50), 0) try: search_term = body["search_term"] diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 044435deab..af9952f513 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -886,7 +886,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): limited = len(results) > limit - return {"limited": limited, "results": results} + return {"limited": limited, "results": results[0:limit]} def _parse_query_sqlite(search_term: str) -> str: diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 5b60cf5285..88c7d5fec0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -448,6 +448,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None}, ) + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_limit_correct(self) -> None: + r = self.get_success(self.store.search_user_dir(ALICE, "bob", 1)) + self.assertTrue(r["limited"]) + self.assertEqual(1, len(r["results"])) + @override_config({"user_directory": {"search_all_users": True}}) def test_search_user_dir_stop_words(self) -> None: """Tests that a user can look up another user by searching for the start if its -- cgit 1.5.1 From 60c3fea3271468dd1f9e9c5fae2d22dd9778293b Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 7 Dec 2022 17:35:41 +0000 Subject: Reject receipt requests with invalid room or event IDs. (#14632) If the room or event IDs are empty or of an invalid form they should be rejected. --- changelog.d/14632.bugfix | 1 + synapse/rest/client/receipts.py | 5 ++- tests/rest/client/test_receipts.py | 76 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14632.bugfix create mode 100644 tests/rest/client/test_receipts.py (limited to 'tests') diff --git a/changelog.d/14632.bugfix b/changelog.d/14632.bugfix new file mode 100644 index 0000000000..323d10f1b0 --- /dev/null +++ b/changelog.d/14632.bugfix @@ -0,0 +1 @@ +Reject invalid read receipt requests with empty room or event IDs. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 18a282b22c..28b7d30ea8 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -20,7 +20,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict +from synapse.types import EventID, JsonDict, RoomID from ._base import client_patterns @@ -56,6 +56,9 @@ class ReceiptRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) + if not RoomID.is_valid(room_id) or not event_id.startswith(EventID.SIGIL): + raise SynapseError(400, "A valid room ID and event ID must be specified") + if receipt_type not in self._known_receipt_types: raise SynapseError( 400, diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py new file mode 100644 index 0000000000..2a7fcea386 --- /dev/null +++ b/tests/rest/client/test_receipts.py @@ -0,0 +1,76 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, receipts, register +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest + + +class ReceiptsTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + register.register_servlets, + receipts.register_servlets, + synapse.rest.admin.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + + def test_send_receipt(self) -> None: + channel = self.make_request( + "POST", + "/rooms/!abc:beep/receipt/m.read/$def", + content={}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + def test_send_receipt_invalid_room_id(self) -> None: + channel = self.make_request( + "POST", + "/rooms/not-a-room-id/receipt/m.read/$def", + content={}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["error"], "A valid room ID and event ID must be specified" + ) + + def test_send_receipt_invalid_event_id(self) -> None: + channel = self.make_request( + "POST", + "/rooms/!abc:beep/receipt/m.read/not-an-event-id", + content={}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["error"], "A valid room ID and event ID must be specified" + ) + + def test_send_receipt_invalid_receipt_type(self) -> None: + channel = self.make_request( + "POST", + "/rooms/!abc:beep/receipt/invalid-receipt-type/$def", + content={}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) -- cgit 1.5.1 From da777207528513c858395758bf4c023da2c2c1a3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 8 Dec 2022 11:35:49 -0500 Subject: Check the stream position before checking if the cache is empty. (#14639) An empty cache does not mean the entity has no changed, if it is earlier than the earliest known stream position return that the entity *has* changed since the cache cannot accurately answer that query. --- changelog.d/14639.bugfix | 1 + synapse/util/caches/stream_change_cache.py | 9 +++++---- tests/util/test_stream_change_cache.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14639.bugfix (limited to 'tests') diff --git a/changelog.d/14639.bugfix b/changelog.d/14639.bugfix new file mode 100644 index 0000000000..8730b10afe --- /dev/null +++ b/changelog.d/14639.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory and room/user stats might be out of sync. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index c8b17acb59..1657459549 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -213,16 +213,17 @@ class StreamChangeCache: """ assert isinstance(stream_pos, int) - if not self._cache: - # If the cache is empty, nothing can have changed. - return False - # _cache is not valid at or before the earliest known stream position, so # return that an entity has changed. if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + # If the cache is empty, nothing can have changed. + if not self._cache: + self.metrics.inc_misses() + return False + self.metrics.inc_hits() return stream_pos < self._cache.peekitem()[0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 0305741c99..3df053493b 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -144,9 +144,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): """ cache = StreamChangeCache("#test", 1) - # With no entities, it returns False for the past, present, and future. - self.assertFalse(cache.has_any_entity_changed(0)) - self.assertFalse(cache.has_any_entity_changed(1)) + # With no entities, it returns True for the past, present, and False for + # the future. + self.assertTrue(cache.has_any_entity_changed(0)) + self.assertTrue(cache.has_any_entity_changed(1)) self.assertFalse(cache.has_any_entity_changed(2)) # We add an entity -- cgit 1.5.1 From 9d8a3234ba1d3ff831a7647f45c67946773d88a7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 8 Dec 2022 11:37:05 -0500 Subject: Respond with proper error responses on unknown paths. (#14621) Returns a proper 404 with an errcode of M_RECOGNIZED for unknown endpoints per MSC3743. --- changelog.d/14621.bugfix | 1 + synapse/api/errors.py | 6 ++---- synapse/http/server.py | 19 ++++++++++++++++++- synapse/rest/media/v1/media_repository.py | 4 ++-- synapse/util/httpresourcetree.py | 6 ++++-- tests/rest/admin/test_user.py | 2 +- tests/rest/client/test_login_token_request.py | 4 ++-- tests/rest/client/test_rendezvous.py | 2 +- tests/test_server.py | 2 +- 9 files changed, 32 insertions(+), 14 deletions(-) create mode 100644 changelog.d/14621.bugfix (limited to 'tests') diff --git a/changelog.d/14621.bugfix b/changelog.d/14621.bugfix new file mode 100644 index 0000000000..cb95a87d92 --- /dev/null +++ b/changelog.d/14621.bugfix @@ -0,0 +1 @@ +Return spec-compliant JSON errors when unknown endpoints are requested. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e2cfcea0f2..76ef12ed3a 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -300,10 +300,8 @@ class InteractiveAuthIncompleteError(Exception): class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" - def __init__( - self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED - ): - super().__init__(400, msg, errcode) + def __init__(self, msg: str = "Unrecognized request", code: int = 400): + super().__init__(code, msg, Codes.UNRECOGNIZED) class NotFoundError(SynapseError): diff --git a/synapse/http/server.py b/synapse/http/server.py index 051a1899a0..2563858f3c 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -577,7 +577,24 @@ def _unrecognised_request_handler(request: Request) -> NoReturn: Args: request: Unused, but passed in to match the signature of ServletCallback. """ - raise UnrecognizedRequestError() + raise UnrecognizedRequestError(code=404) + + +class UnrecognizedRequestResource(resource.Resource): + """ + Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an + errcode of M_UNRECOGNIZED. + """ + + def render(self, request: SynapseRequest) -> int: + f = failure.Failure(UnrecognizedRequestError(code=404)) + return_json_error(f, request, None) + # A response has already been sent but Twisted requires either NOT_DONE_YET + # or the response bytes as a return value. + return NOT_DONE_YET + + def getChild(self, name: str, request: Request) -> resource.Resource: + return self class RootRedirect(resource.Resource): diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 40b0d39eb2..c70e1837af 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -24,7 +24,6 @@ from matrix_common.types.mxc_uri import MXCUri import twisted.internet.error import twisted.web.http from twisted.internet.defer import Deferred -from twisted.web.resource import Resource from synapse.api.errors import ( FederationDeniedError, @@ -35,6 +34,7 @@ from synapse.api.errors import ( ) from synapse.config._base import ConfigError from synapse.config.repository import ThumbnailRequirement +from synapse.http.server import UnrecognizedRequestResource from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process @@ -1046,7 +1046,7 @@ class MediaRepository: return removed_media, len(removed_media) -class MediaRepositoryResource(Resource): +class MediaRepositoryResource(UnrecognizedRequestResource): """File uploading and downloading. Uploads are POSTed to a resource which returns a token which is used to GET diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index a0606851f7..39fab4fe06 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -15,7 +15,9 @@ import logging from typing import Dict -from twisted.web.resource import NoResource, Resource +from twisted.web.resource import Resource + +from synapse.http.server import UnrecognizedRequestResource logger = logging.getLogger(__name__) @@ -49,7 +51,7 @@ def create_resource_tree( for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource: Resource = NoResource() + child_resource: Resource = UnrecognizedRequestResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e8c9457794..5c1ced355f 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -3994,7 +3994,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): """ Tests that shadow-banning for a user that is not a local returns a 400 """ - url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/shadow_ban" channel = self.make_request(method, url, access_token=self.admin_user_tok) self.assertEqual(400, channel.code, msg=channel.json_body) diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index c2e1e08811..6aedc1a11c 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -48,13 +48,13 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): def test_disabled(self) -> None: channel = self.make_request("POST", endpoint, {}, access_token=None) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) self.register_user(self.user, self.password) token = self.login(self.user, self.password) channel = self.make_request("POST", endpoint, {}, access_token=token) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) @override_config({"experimental_features": {"msc3882_enabled": True}}) def test_require_auth(self) -> None: diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index ad00a476e1..c0eb5d01a6 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -36,7 +36,7 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase): def test_disabled(self) -> None: channel = self.make_request("POST", endpoint, {}, access_token=None) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) def test_redirect(self) -> None: diff --git a/tests/test_server.py b/tests/test_server.py index 2d9a0257d4..d67d7722a4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -174,7 +174,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar" ) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") -- cgit 1.5.1 From c2de2ca63060324cf2f80ddf3289b0fd7a4d861b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Dec 2022 09:37:07 +0000 Subject: Delete stale non-e2e devices for users, take 2 (#14595) This should help reduce the number of devices e.g. simple bots the repeatedly login rack up. We only delete non-e2e devices as they should be safe to delete, whereas if we delete e2e devices for a user we may accidentally break their ability to receive e2e keys for a message. --- changelog.d/14595.misc | 1 + synapse/handlers/device.py | 31 +++++++++++- synapse/storage/databases/main/devices.py | 79 ++++++++++++++++++++++++++++++- tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 5 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14595.misc (limited to 'tests') diff --git a/changelog.d/14595.misc b/changelog.d/14595.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/14595.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d4750a32e6..7674c187ef 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -52,6 +52,7 @@ from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable +from synapse.util.iterutils import batch_iter from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -421,6 +422,9 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) + # Prune the user's device list if they already have a lot of devices. + await self._prune_too_many_devices(user_id) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -452,6 +456,31 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") + async def _prune_too_many_devices(self, user_id: str) -> None: + """Delete any excess old devices this user may have.""" + device_ids = await self.store.check_too_many_devices_for_user(user_id) + if not device_ids: + return + + # We don't want to block and try and delete tonnes of devices at once, + # so we cap the number of devices we delete synchronously. + first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] + await self.delete_devices(user_id, first_batch) + + if not remaining_device_ids: + return + + # Now spawn a background loop that deletes the rest. + async def _prune_too_many_devices_loop() -> None: + for batch in batch_iter(remaining_device_ids, 10): + await self.delete_devices(user_id, batch) + + await self.clock.sleep(1) + + run_as_background_process( + "_prune_too_many_devices_loop", _prune_too_many_devices_loop + ) + async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. @@ -481,7 +510,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a5bb4d404e..08ccd46a2b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,6 +1569,72 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def check_too_many_devices_for_user(self, user_id: str) -> List[str]: + """Check if the user has a lot of devices, and if so return the set of + devices we can prune. + + This does *not* return hidden devices or devices with E2E keys. + """ + + num_devices = await self.db_pool.simple_select_one_onecol( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcol="COALESCE(COUNT(*), 0)", + desc="count_devices", + ) + + # We let users have up to ten devices without pruning. + if num_devices <= 10: + return [] + + # We prune everything older than N days. + max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 + + if num_devices > 50: + # If the user has more than 50 devices, then we chose a last seen + # that ensures we keep at most 50 devices. + sql = """ + SELECT last_seen FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen IS NOT NULL + AND key_json IS NULL + ORDER BY last_seen DESC + LIMIT 1 + OFFSET 50 + """ + + rows = await self.db_pool.execute( + "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) + ) + if rows: + max_last_seen = max(rows[0][0], max_last_seen) + + # Now fetch the devices to delete. + sql = """ + SELECT DISTINCT device_id FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen < ? + AND key_json IS NULL + ORDER BY last_seen + """ + + def check_too_many_devices_for_user_txn( + txn: LoggingTransaction, + ) -> List[str]: + txn.execute(sql, (user_id, max_last_seen)) + return [device_id for device_id, in txn] + + return await self.db_pool.runInteraction( + "check_too_many_devices_for_user", + check_too_many_devices_for_user_txn, + ) + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1627,6 +1693,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, + "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1672,7 +1739,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + @cached(max_entries=0) + async def delete_device(self, user_id: str, device_id: str) -> None: + raise NotImplementedError() + + # Note: sometimes deleting rows out of `device_inbox` can take a long time, + # so we use a cache so that we deduplicate in flight requests to delete + # devices. + @cachedList(cached_method_name="delete_device", list_name="device_ids") + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> dict: """Deletes several devices. Args: @@ -1709,6 +1784,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) + return {} + async def update_device( self, user_id: str, device_id: str, new_display_name: Optional[str] = None ) -> None: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ce7525e29c..a456bffd63 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": None, + "last_seen_ts": 1000000, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 49ad3c1324..a9af1babed 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,6 +169,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) + last_seen = self.clock.time_msec() + if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -189,7 +191,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": None, + "last_seen": last_seen, }, ], ) -- cgit 1.5.1 From 94bc21e69f89ad873ad7a0deb6d9c4ff3cb480ef Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Dec 2022 13:31:32 +0000 Subject: Limit the number of devices we delete at once (#14649) --- changelog.d/14649.misc | 1 + synapse/handlers/device.py | 4 +++- synapse/storage/databases/main/devices.py | 11 ++++++++--- tests/handlers/test_device.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14649.misc (limited to 'tests') diff --git a/changelog.d/14649.misc b/changelog.d/14649.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/14649.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7674c187ef..c935c7be90 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -458,10 +458,12 @@ class DeviceHandler(DeviceWorkerHandler): async def _prune_too_many_devices(self, user_id: str) -> None: """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id) + device_ids = await self.store.check_too_many_devices_for_user(user_id, 100) if not device_ids: return + logger.info("Pruning %d old devices for user %s", len(device_ids), user_id) + # We don't want to block and try and delete tonnes of devices at once, # so we cap the number of devices we delete synchronously. first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 08ccd46a2b..95d4c0622d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,11 +1569,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user(self, user_id: str) -> List[str]: + async def check_too_many_devices_for_user( + self, user_id: str, limit: int + ) -> List[str]: """Check if the user has a lot of devices, and if so return the set of devices we can prune. This does *not* return hidden devices or devices with E2E keys. + + Returns at most `limit` number of devices, ordered by last seen. """ num_devices = await self.db_pool.simple_select_one_onecol( @@ -1614,7 +1618,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): # Now fetch the devices to delete. sql = """ - SELECT DISTINCT device_id FROM devices + SELECT device_id FROM devices LEFT JOIN e2e_device_keys_json USING (user_id, device_id) WHERE user_id = ? @@ -1622,12 +1626,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): AND last_seen < ? AND key_json IS NULL ORDER BY last_seen + LIMIT ? """ def check_too_many_devices_for_user_txn( txn: LoggingTransaction, ) -> List[str]: - txn.execute(sql, (user_id, max_last_seen)) + txn.execute(sql, (user_id, max_last_seen, limit)) return [device_id for device_id, in txn] return await self.db_pool.runInteraction( diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a456bffd63..e51cac9b33 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -20,6 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler +from synapse.rest import admin +from synapse.rest.client import account, login from synapse.server import HomeServer from synapse.util import Clock @@ -30,6 +32,12 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + account.register_servlets, + ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) handler = hs.get_device_handler() @@ -229,6 +237,29 @@ class DeviceTestCase(unittest.HomeserverTestCase): NotFoundError, ) + def test_login_delete_old_devices(self) -> None: + """Delete old devices if the user already has too many.""" + + user_id = self.register_user("user", "pass") + + # Create a bunch of devices + for _ in range(50): + self.login("user", "pass") + self.reactor.advance(1) + + # Advance the clock for ages (as we only delete old devices) + self.reactor.advance(60 * 60 * 24 * 300) + + # Log in again to start the pruning + self.login("user", "pass") + + # Give the background job time to do its thing + self.reactor.pump([1.0] * 100) + + # We should now only have the most recent device. + devices = self.get_success(self.handler.get_devices_by_user(user_id)) + self.assertEqual(len(devices), 1) + def _record_users(self) -> None: # check this works for both devices which have a recorded client_ip, # and those which don't. -- cgit 1.5.1 From 3ac412b4e2f8c5ba11dc962b8a9d871c1efdce9b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 9 Dec 2022 12:36:32 -0500 Subject: Require types in tests.storage. (#14646) Adds missing type hints to `tests.storage` package and does not allow untyped definitions. --- changelog.d/14646.misc | 1 + mypy.ini | 14 +-- synapse/storage/databases/main/end_to_end_keys.py | 2 +- tests/storage/databases/main/test_deviceinbox.py | 10 +- tests/storage/databases/main/test_events_worker.py | 27 ++--- tests/storage/databases/main/test_lock.py | 18 +-- tests/storage/databases/main/test_receipts.py | 8 +- tests/storage/databases/main/test_room.py | 10 +- tests/storage/test__base.py | 2 +- tests/storage/test_account_data.py | 12 +- tests/storage/test_appservice.py | 22 ++-- tests/storage/test_base.py | 30 ++--- tests/storage/test_cleanup_extrems.py | 37 +++--- tests/storage/test_client_ips.py | 58 +++++----- tests/storage/test_database.py | 2 +- tests/storage/test_devices.py | 35 ++++-- tests/storage/test_directory.py | 12 +- tests/storage/test_e2e_room_keys.py | 8 +- tests/storage/test_end_to_end_keys.py | 15 ++- tests/storage/test_event_chain.py | 29 +++-- tests/storage/test_event_federation.py | 71 ++++++------ tests/storage/test_event_metrics.py | 2 +- tests/storage/test_events.py | 39 ++++--- tests/storage/test_keys.py | 9 +- tests/storage/test_monthly_active_users.py | 30 ++--- tests/storage/test_purge.py | 15 ++- tests/storage/test_receipts.py | 12 +- tests/storage/test_redaction.py | 125 ++++++++++++--------- tests/storage/test_rollback_worker.py | 15 ++- tests/storage/test_room.py | 24 ++-- tests/storage/test_room_search.py | 10 +- tests/storage/test_state.py | 46 +++++--- tests/storage/test_stream.py | 18 ++- tests/storage/test_transactions.py | 18 ++- tests/storage/test_txn_limit.py | 14 ++- .../util/test_partial_state_events_tracker.py | 30 ++--- 36 files changed, 489 insertions(+), 341 deletions(-) create mode 100644 changelog.d/14646.misc (limited to 'tests') diff --git a/changelog.d/14646.misc b/changelog.d/14646.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14646.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index c3fbd1a955..a4a1e4511a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -88,6 +88,9 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False +[mypy-tests.handlers.test_sso] +disallow_untyped_defs = True + [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True @@ -103,16 +106,7 @@ disallow_untyped_defs = True [mypy-tests.state.test_profile] disallow_untyped_defs = True -[mypy-tests.storage.test_id_generators] -disallow_untyped_defs = True - -[mypy-tests.storage.test_profile] -disallow_untyped_defs = True - -[mypy-tests.handlers.test_sso] -disallow_untyped_defs = True - -[mypy-tests.storage.test_user_directory] +[mypy-tests.storage.*] disallow_untyped_defs = True [mypy-tests.rest.*] diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 643c47d608..4c691642e2 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -140,7 +140,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cancellable async def get_e2e_device_keys_for_cs_api( self, - query_list: List[Tuple[str, Optional[str]]], + query_list: Collection[Tuple[str, Optional[str]]], include_displaynames: bool = True, ) -> Dict[str, Dict[str, JsonDict]]: """Fetch a list of device keys, formatted suitably for the C/S API. diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 50c20c5b92..373707b275 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest import admin from synapse.rest.client import devices +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): devices.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") - def test_background_remove_deleted_devices_from_device_inbox(self): + def test_background_remove_deleted_devices_from_device_inbox(self) -> None: """Test that the background task to delete old device_inboxes works properly.""" # create a valid device @@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): self.assertEqual(1, len(res)) self.assertEqual(res[0], "cur_device") - def test_background_remove_hidden_devices_from_device_inbox(self): + def test_background_remove_hidden_devices_from_device_inbox(self) -> None: """Test that the background task to delete hidden devices from device_inboxes works properly.""" diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5773172ab8..9f33afcca0 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs self.store: EventsWorkerStore = hs.get_datastores().main @@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.event_ids.append(event.event_id) - def test_simple(self): + def test_simple(self) -> None: with LoggingContext(name="test") as ctx: res = self.get_success( self.store.have_seen_events( @@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) - def test_persisting_event_invalidates_cache(self): + def test_persisting_event_invalidates_cache(self) -> None: """ Test to make sure that the `have_seen_event` cache is invalidated after we persist an event and returns @@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # That should result in a single db query to lookup self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) - def test_invalidate_cache_by_room_id(self): + def test_invalidate_cache_by_room_id(self) -> None: """ Test to make sure that all events associated with the given `(room_id,)` are invalidated in the `have_seen_event` cache. @@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") @@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # Reset the event cache so the tests start with it empty self.get_success(self.store._get_event_cache.clear()) - def test_simple(self): + def test_simple(self) -> None: """Test that we cache events that we pull from the DB.""" with LoggingContext("test") as ctx: @@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # We should have fetched the event from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) - def test_event_ref(self): + def test_event_ref(self) -> None: """Test that we reuse events that are still in memory but have fallen out of the cache, rather than requesting them from the DB. """ @@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0) - def test_dedupe(self): + def test_dedupe(self) -> None: """Test that if we request the same event multiple times we only pull it out once. """ @@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): class DatabaseOutageTestCase(unittest.HomeserverTestCase): """Test event fetching during a database outage.""" - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" @@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") @@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): unblock: "Deferred[None]" = Deferred() original_runWithConnection = self.store.db_pool.runWithConnection - async def runWithConnection(*args, **kwargs): + # Don't bother with the types here, we just pass into the original function. + async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def] await unblock return await original_runWithConnection(*args, **kwargs) @@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1) self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0) - def test_first_get_event_cancelled(self): + def test_first_get_event_cancelled(self) -> None: """Test cancellation of the first `get_event` call sharing a database fetch. The first `get_event` call is the one which initiates the fetch. We expect the @@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): # The second `get_event` call should complete successfully. self.get_success(get_event2) - def test_second_get_event_cancelled(self): + def test_second_get_event_cancelled(self) -> None: """Test cancellation of the second `get_event` call sharing a database fetch.""" with self.blocking_get_event_calls() as (unblock, get_event1, get_event2): # Cancel the second `get_event` call. diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 3cc2a58d8d..56cb49d9b5 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -15,18 +15,20 @@ from twisted.internet import defer, reactor from twisted.internet.base import ReactorBase from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS +from synapse.util import Clock from tests import unittest class LockTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def test_acquire_contention(self): + def test_acquire_contention(self) -> None: # Track the number of tasks holding the lock. # Should be at most 1. in_lock = 0 @@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase): release_lock: "Deferred[None]" = Deferred() - async def task(): + async def task() -> None: nonlocal in_lock nonlocal max_in_lock @@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase): # At most one task should have held the lock at a time. self.assertEqual(max_in_lock, 1) - def test_simple_lock(self): + def test_simple_lock(self) -> None: """Test that we can take out a lock and that while we hold it nobody else can take it out. """ @@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase): self.get_success(lock3.__aenter__()) self.get_success(lock3.__aexit__(None, None, None)) - def test_maintain_lock(self): + def test_maintain_lock(self) -> None: """Test that we don't time out locks while they're still active""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) @@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase): self.get_success(lock.__aexit__(None, None, None)) - def test_timeout_lock(self): + def test_timeout_lock(self) -> None: """Test that we time out locks if they're not updated for ages""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) @@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase): self.assertFalse(self.get_success(lock.is_still_valid())) - def test_drop(self): + def test_drop(self) -> None: """Test that dropping the context manager means we stop renewing the lock""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) @@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase): lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) self.assertIsNotNone(lock2) - def test_shutdown(self): + def test_shutdown(self) -> None: """Test that shutting down Synapse releases the locks""" # Acquire two locks lock = self.get_success(self.store.try_acquire_lock("name", "key1")) diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index c4f12d81d7..68026e2830 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") @@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): table: str, receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]], expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]], - ): + ) -> None: """Test that the background update to uniqueify non-thread receipts in the given receipts table works properly. @@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): f"Background update did not remove all duplicate receipts from {table}", ) - def test_background_receipts_linearized_unique_index(self): + def test_background_receipts_linearized_unique_index(self) -> None: """Test that the background update to uniqueify non-thread receipts in `receipts_linearized` works properly. """ @@ -177,7 +177,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): }, ) - def test_background_receipts_graph_unique_index(self): + def test_background_receipts_graph_unique_index(self) -> None: """Test that the background update to uniqueify non-thread receipts in `receipts_graph` works properly. """ diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 1edb619630..7d961fac64 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -14,10 +14,14 @@ import json +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import RoomTypes from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.storage.databases.main.room import _BackgroundUpdates +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -30,7 +34,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") @@ -40,7 +44,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): return room_id - def test_background_populate_rooms_creator_column(self): + def test_background_populate_rooms_creator_column(self) -> None: """Test that the background update to populate the rooms creator column works properly. """ @@ -95,7 +99,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) self.assertEqual(room_creator_after, self.user_id) - def test_background_add_room_type_column(self): + def test_background_add_room_type_column(self) -> None: """Test that the background update to populate the `room_type` column in `room_stats_state` works properly. """ diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 09cb06d614..8bbf936ae9 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase): {(1, "user1", "hello"), (2, "user2", "bleb")}, ) - def test_simple_update_many(self): + def test_simple_update_many(self) -> None: """ simple_update_many performs many updates at once. """ diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 72bf5b3d31..1bfd11ceae 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -14,13 +14,17 @@ from typing import Iterable, Optional, Set +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import AccountDataTypes +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest class IgnoredUsersTestCase(unittest.HomeserverTestCase): - def prepare(self, hs, reactor, clock): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main self.user = "@user:test" @@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): expected_ignored_user_ids, ) - def test_ignoring_users(self): + def test_ignoring_users(self) -> None: """Basic adding/removing of users from the ignore list.""" self._update_ignore_list("@other:test", "@another:remote") self.assert_ignored(self.user, {"@other:test", "@another:remote"}) @@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): # Check the removed user. self.assert_ignorers("@another:remote", {self.user}) - def test_caching(self): + def test_caching(self) -> None: """Ensure that caching works properly between different users.""" # The first user ignores a user. self._update_ignore_list("@other:test") @@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", {"@second:test"}) - def test_invalid_data(self): + def test_invalid_data(self) -> None: """Invalid data ends up clearing out the ignored users list.""" # Add some data and ensure it is there. self._update_ignore_list("@other:test") diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1047ed09c8..5e1324a169 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError from synapse.events import EventBase from synapse.server import HomeServer -from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, @@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): - def setUp(self): + def setUp(self) -> None: super(ApplicationServiceStoreTestCase, self).setUp() self.as_yaml_files: List[str] = [] @@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): super(ApplicationServiceStoreTestCase, self).tearDown() - def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: + def _add_appservice( + self, as_token: str, id: str, url: str, hs_token: str, sender: str + ) -> None: as_yaml = { "url": url, "as_token": as_token, @@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): database, make_conn(db_config, self.engine, "test"), self.hs ) - def _add_service(self, url, as_token, id) -> None: + def _add_service(self, url: str, as_token: str, id: str) -> None: as_yaml = { "url": url, "as_token": as_token, @@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def _set_state(self, id: str, state: ApplicationServiceState): + def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred: return self.db_pool.runOperation( self.engine.convert_param_style( "INSERT INTO application_services_state(as_id, state) VALUES(?,?)" @@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (id, state.value), ) - def _insert_txn(self, as_id, txn_id, events): + def _insert_txn( + self, as_id: str, txn_id: int, events: List[Mock] + ) -> "defer.Deferred[None]": return self.db_pool.runOperation( self.engine.convert_param_style( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " @@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: DatabasePool, db_conn, hs) -> None: + def __init__( + self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer + ) -> None: super().__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): - def _write_config(self, suffix, **kwargs) -> str: + def _write_config(self, suffix: str, **kwargs: str) -> str: vals = { "id": "id" + suffix, "url": "url" + suffix, diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 40e58f8199..256d28e4c9 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - from collections import OrderedDict +from typing import Generator from unittest.mock import Mock from twisted.internet import defer @@ -30,7 +30,7 @@ from tests.utils import default_config class SQLBaseStoreTestCase(unittest.TestCase): """Test the "simple" SQL generating methods in SQLBaseStore.""" - def setUp(self): + def setUp(self) -> None: self.db_pool = Mock(spec=["runInteraction"]) self.mock_txn = Mock() self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"]) @@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_conn.rollback.return_value = None # Our fake runInteraction just runs synchronously inline - def runInteraction(func, *args, **kwargs): + def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def] return defer.succeed(func(self.mock_txn, *args, **kwargs)) self.db_pool.runInteraction = runInteraction - def runWithConnection(func, *args, **kwargs): + def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def] return defer.succeed(func(self.mock_conn, *args, **kwargs)) self.db_pool.runWithConnection = runWithConnection @@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type] @defer.inlineCallbacks - def test_insert_1col(self): + def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( @@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_insert_3cols(self): + def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( @@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_select_one_1col(self): + def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) @@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_select_one_3col(self): + def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) @@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_select_one_missing(self): + def test_select_one_missing( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None @@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.assertFalse(ret) @defer.inlineCallbacks - def test_select_list(self): + def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 3 self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) @@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_update_one_1col(self): + def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( @@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_update_one_4cols(self): + def test_update_one_4cols( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( @@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_delete_one(self): + def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index b998ad42d9..d570684c99 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -15,11 +15,16 @@ import os.path from unittest.mock import Mock, patch +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.storage import prepare_database +from synapse.storage.types import Cursor from synapse.types import UserID, create_requester +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Test the background update to clean forward extremities table. """ - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() @@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] - def run_background_update(self): + def run_background_update(self) -> None: """Re run the background update to clean up the extremities.""" # Make sure we don't clash with in progress updates. self.assertTrue( @@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): "delete_forward_extremities.sql", ) - def run_delta_file(txn): + def run_delta_file(txn: Cursor) -> None: prepare_database.executescript(txn, schema_path) self.get_success( @@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): (room_id,) ) - def test_soft_failed_extremities_handled_correctly(self): + def test_soft_failed_extremities_handled_correctly(self) -> None: """Test that extremities are correctly calculated in the presence of soft failed events. @@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): self.assertEqual(latest_event_ids, [event_id_4]) - def test_basic_cleanup(self): + def test_basic_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of soft failed events. @@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): ) self.assertEqual(latest_event_ids, [event_id_b]) - def test_chain_of_fail_cleanup(self): + def test_chain_of_fail_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of soft failed events. @@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): ) self.assertEqual(latest_event_ids, [event_id_b]) - def test_forked_graph_cleanup(self): + def test_forked_graph_cleanup(self) -> None: r"""Test that extremities are correctly calculated in the presence of soft failed events. @@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["cleanup_extremities_with_dummy_events"] = True return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() self.event_creator_handler = homeserver.get_event_creation_handler() @@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.event_creator = homeserver.get_event_creation_handler() homeserver.config.consent.user_consent_version = self.CONSENT_VERSION - def test_send_dummy_event(self): + def test_send_dummy_event(self) -> None: self._create_extremity_rich_graph() # Pump the reactor repeatedly so that the background updates have a @@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) - def test_send_dummy_events_when_insufficient_power(self): + def test_send_dummy_events_when_insufficient_power(self) -> None: self._create_extremity_rich_graph() # Criple power levels self.helper.send_state( @@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) - def test_expiry_logic(self): + def test_expiry_logic(self) -> None: """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() expires old entries correctly. """ @@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): 0, ) - def _create_extremity_rich_graph(self): + def _create_extremity_rich_graph(self) -> None: """Helper method to create bushy graph on demand""" event_id_start = self.create_and_send_event(self.room_id, self.user) @@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): ) self.assertEqual(len(latest_event_ids), 50) - def _enable_consent_checking(self): + def _enable_consent_checking(self) -> None: """Helper method to enable consent checking""" self.event_creator._block_events_without_consent_error = "No consent from user" consent_uri_builder = Mock() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index a9af1babed..81e4e596e4 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -13,15 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict from unittest.mock import Mock from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.http.site import XForwardedForRequest from synapse.rest.client import login +from synapse.server import HomeServer from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.types import UserID +from synapse.util import Clock from tests import unittest from tests.server import make_request @@ -30,14 +35,10 @@ from tests.unittest import override_config class ClientIpStoreTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver() - return hs - - def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastores().main + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main - def test_insert_new_client_ip(self): + def test_insert_new_client_ip(self) -> None: self.reactor.advance(12345678) user_id = "@user:id" @@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): r, ) - def test_insert_new_client_ip_none_device_id(self): + def test_insert_new_client_ip_none_device_id(self) -> None: """ An insert with a device ID of NULL will not create a new entry, but update an existing entry in the user_ips table. @@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) @parameterized.expand([(False,), (True,)]) - def test_get_last_client_ip_by_device(self, after_persisting: bool): + def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None: """Test `get_last_client_ip_by_device` for persisted and unpersisted data""" self.reactor.advance(12345678) @@ -213,7 +214,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): }, ) - def test_get_last_client_ip_by_device_combined_data(self): + def test_get_last_client_ip_by_device_combined_data(self) -> None: """Test that `get_last_client_ip_by_device` combines persisted and unpersisted data together correctly """ @@ -312,7 +313,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) @parameterized.expand([(False,), (True,)]) - def test_get_user_ip_and_agents(self, after_persisting: bool): + def test_get_user_ip_and_agents(self, after_persisting: bool) -> None: """Test `get_user_ip_and_agents` for persisted and unpersisted data""" self.reactor.advance(12345678) @@ -352,7 +353,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) - def test_get_user_ip_and_agents_combined_data(self): + def test_get_user_ip_and_agents_combined_data(self) -> None: """Test that `get_user_ip_and_agents` combines persisted and unpersisted data together correctly """ @@ -429,7 +430,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) @override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) - def test_disabled_monthly_active_user(self): + def test_disabled_monthly_active_user(self) -> None: user_id = "@user:server" self.get_success( self.store.insert_client_ip( @@ -440,7 +441,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertFalse(active) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) - def test_adding_monthly_active_user_when_full(self): + def test_adding_monthly_active_user_when_full(self) -> None: lots_of_users = 100 user_id = "@user:server" @@ -456,7 +457,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertFalse(active) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) - def test_adding_monthly_active_user_when_space(self): + def test_adding_monthly_active_user_when_space(self) -> None: user_id = "@user:server" active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) @@ -473,7 +474,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertTrue(active) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) - def test_updating_monthly_active_user_when_space(self): + def test_updating_monthly_active_user_when_space(self) -> None: user_id = "@user:server" self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) @@ -491,7 +492,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) - def test_devices_last_seen_bg_update(self): + def test_devices_last_seen_bg_update(self) -> None: # First make sure we have completed all updates. self.wait_for_background_updates() @@ -576,7 +577,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): r, ) - def test_old_user_ips_pruned(self): + def test_old_user_ips_pruned(self) -> None: # First make sure we have completed all updates. self.wait_for_background_updates() @@ -639,11 +640,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertEqual(result, []) # But we should still get the correct values for the device - result = self.get_success( + result2 = self.get_success( self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, device_id)] + r = result2[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, @@ -663,15 +664,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver() - return hs - - def prepare(self, hs, reactor, clock): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main self.user_id = self.register_user("bob", "abc123", True) - def test_request_with_xforwarded(self): + def test_request_with_xforwarded(self) -> None: """ The IP in X-Forwarded-For is entered into the client IPs table. """ @@ -681,14 +678,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): {"request": XForwardedForRequest}, ) - def test_request_from_getPeer(self): + def test_request_from_getPeer(self) -> None: """ The IP returned by getPeer is entered into the client IPs table, if there's no X-Forwarded-For header. """ self._runtest({}, "127.0.0.1", {}) - def _runtest(self, headers, expected_ip, make_request_args): + def _runtest( + self, + headers: Dict[bytes, bytes], + expected_ip: str, + make_request_args: Dict[str, Any], + ) -> None: device_id = "bleb" access_token = self.login("bob", "abc123", device_id=device_id) diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index a40fc20ef9..543cce6b3e 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -31,7 +31,7 @@ from tests import unittest class TupleComparisonClauseTestCase(unittest.TestCase): - def test_native_tuple_comparison(self): + def test_native_tuple_comparison(self) -> None: clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 8e7db2c4ec..f03807c8f9 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -12,17 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Collection, List, Tuple + +from twisted.test.proto_helpers import MemoryReactor + import synapse.api.errors from synapse.api.constants import EduTypes +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase class DeviceStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def add_device_change(self, user_id, device_ids, host): + def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None: """Add a device list change for the given device to `device_lists_outbound_pokes` table. """ @@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase): ) ) - def test_store_new_device(self): + def test_store_new_device(self) -> None: self.get_success( self.store.store_device("user_id", "device_id", "display_name") ) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertDictContainsSubset( { "user_id": "user_id", @@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase): res, ) - def test_get_devices_by_user(self): + def test_get_devices_by_user(self) -> None: self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) @@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase): res["device2"], ) - def test_count_devices_by_users(self): + def test_count_devices_by_users(self) -> None: self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) @@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase): ) self.assertEqual(3, res) - def test_get_device_updates_by_remote(self): + def test_get_device_updates_by_remote(self) -> None: device_ids = ["device_id1", "device_id2"] # Add two device updates with sequential `stream_id`s @@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - def test_get_device_updates_by_remote_can_limit_properly(self): + def test_get_device_updates_by_remote_can_limit_properly(self) -> None: """ Tests that `get_device_updates_by_remote` returns an appropriate stream_id to resume fetching from (without skipping any results). @@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase): ) self.assertEqual(device_updates, []) - def _check_devices_in_updates(self, expected_device_ids, device_updates): + def _check_devices_in_updates( + self, + expected_device_ids: Collection[str], + device_updates: List[Tuple[str, JsonDict]], + ) -> None: """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) @@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase): } self.assertEqual(received_device_ids, set(expected_device_ids)) - def test_update_device(self): + def test_update_device(self) -> None: self.get_success( self.store.store_device("user_id", "device_id", "display_name 1") ) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 1", res["display_name"]) # do a no-op first self.get_success(self.store.update_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 1", res["display_name"]) # do the update @@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase): # check it worked res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 2", res["display_name"]) - def test_update_unknown_device(self): + def test_update_unknown_device(self) -> None: exc = self.get_failure( self.store.update_device( "user_id", "unknown_device_id", new_display_name="display_name 2" diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 20bf3ca17b..8bedc6bdf3 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -12,19 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer from synapse.types import RoomAlias, RoomID +from synapse.util import Clock from tests.unittest import HomeserverTestCase class DirectoryStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") - def test_room_to_alias(self): + def test_room_to_alias(self) -> None: self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] @@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), ) - def test_alias_to_room(self): + def test_alias_to_room(self) -> None: self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] @@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): (self.get_success(self.store.get_association_from_room_alias(self.alias))), ) - def test_delete_alias(self): + def test_delete_alias(self) -> None: self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index fb96ab3a2f..9cb326d90a 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer from synapse.storage.databases.main.e2e_room_keys import RoomKey +from synapse.util import Clock from tests import unittest @@ -26,12 +30,12 @@ room_key: RoomKey = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) self.store = hs.get_datastores().main return hs - def test_room_keys_version_delete(self): + def test_room_keys_version_delete(self) -> None: # test that deleting a room key backup deletes the keys version1 = self.get_success( self.store.create_e2e_room_keys_version( diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 0f04493ad0..5fde3b9c78 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.util import Clock + from tests.unittest import HomeserverTestCase class EndToEndKeyStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def test_key_without_device_name(self): + def test_key_without_device_name(self) -> None: now = 1470174257070 json = {"key": "value"} @@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): dev = res["user"]["device"] self.assertDictContainsSubset(json, dev) - def test_reupload_key(self): + def test_reupload_key(self) -> None: now = 1470174257070 json = {"key": "value"} @@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): ) self.assertFalse(changed) - def test_get_key_with_device_name(self): + def test_get_key_with_device_name(self) -> None: now = 1470174257070 json = {"key": "value"} @@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) - def test_multiple_devices(self): + def test_multiple_devices(self) -> None: now = 1470174257070 self.get_success(self.store.store_device("user1", "device1", None)) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index de9f4af2de..c070278db8 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -14,6 +14,7 @@ from typing import Dict, List, Set, Tuple +from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest from synapse.api.constants import EventTypes @@ -22,18 +23,22 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.events import _LinkMap +from synapse.storage.types import Cursor from synapse.types import create_requester +from synapse.util import Clock from tests.unittest import HomeserverTestCase class EventChainStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self._next_stream_ordering = 1 - def test_simple(self): + def test_simple(self) -> None: """Test that the example in `docs/auth_chain_difference_algorithm.md` works. """ @@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase): ), ) - def test_out_of_order_events(self): + def test_out_of_order_events(self) -> None: """Test that we handle persisting events that we don't have the full auth chain for yet (which should only happen for out of band memberships). """ @@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase): def persist( self, events: List[EventBase], - ): + ) -> None: """Persist the given events and check that the links generated match those given. """ @@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase): e.internal_metadata.stream_ordering = self._next_stream_ordering self._next_stream_ordering += 1 - def _persist(txn): + def _persist(txn: LoggingTransaction) -> None: # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( @@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase): class LinkMapTestCase(unittest.TestCase): - def test_simple(self): + def test_simple(self) -> None: """Basic tests for the LinkMap.""" link_map = _LinkMap() @@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") @@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Delete the chain cover info. - def _delete_tables(txn): + def _delete_tables(txn: Cursor) -> None: txn.execute("DELETE FROM event_auth_chains") txn.execute("DELETE FROM event_auth_chain_links") @@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): return room_id, [state1, state2] - def test_background_update_single_room(self): + def test_background_update_single_room(self) -> None: """Test that the background update to calculate auth chains for historic rooms works correctly. """ @@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) - def test_background_update_multiple_rooms(self): + def test_background_update_multiple_rooms(self) -> None: """Test that the background update to calculate auth chains for historic rooms works correctly. """ @@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) - def test_background_update_single_large_room(self): + def test_background_update_single_large_room(self) -> None: """Test that the background update to calculate auth chains for historic rooms works correctly. """ @@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) - def test_background_update_multiple_large_room(self): + def test_background_update_multiple_large_room(self) -> None: """Test that the background update to calculate auth chains for historic rooms works correctly. """ diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 853db930d6..7fd3e01364 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,7 +13,7 @@ # limitations under the License. import datetime -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union, cast import attr from parameterized import parameterized @@ -26,11 +26,12 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersion, ) -from synapse.events import _EventInternalMetadata +from synapse.events import EventBase, _EventInternalMetadata from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction +from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import Clock, json_encoder @@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def test_get_prev_events_for_room(self): + def test_get_prev_events_for_room(self) -> None: room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities - def insert_event(txn, i): + def insert_event(txn: Cursor, i: int) -> None: event_id = "$event_%i:local" % i txn.execute( @@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): for i in range(0, 10): self.assertEqual("$event_%i:local" % (19 - i), r[i]) - def test_get_rooms_with_many_extremities(self): + def test_get_rooms_with_many_extremities(self) -> None: room1 = "#room1" room2 = "#room2" room3 = "#room3" - def insert_event(txn, i, room_id): + def insert_event(txn: Cursor, i: int, room_id: str) -> None: event_id = "$event_%i:local" % i txn.execute( ( @@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # | | # K J - auth_graph = { + auth_graph: Dict[str, List[str]] = { "a": ["e"], "b": ["e"], "c": ["g", "i"], @@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # Mark the room as maybe having a cover index. - def store_room(txn): + def store_room(txn: LoggingTransaction) -> None: self.store.db_pool.simple_insert_txn( txn, "rooms", @@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. - def insert_event(txn): + def insert_event(txn: LoggingTransaction) -> None: stream_ordering = 0 for event_id in auth_graph: @@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, [ - FakeEvent(event_id, room_id, auth_graph[event_id]) + cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) for event_id in auth_graph ], ) @@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): return room_id @parameterized.expand([(True,), (False,)]) - def test_auth_chain_ids(self, use_chain_cover_index: bool): + def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None: room_id = self._setup_auth_chain(use_chain_cover_index) # a and b have the same auth chain. @@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.assertCountEqual(auth_chain_ids, ["i", "j"]) @parameterized.expand([(True,), (False,)]) - def test_auth_difference(self, use_chain_cover_index: bool): + def test_auth_difference(self, use_chain_cover_index: bool) -> None: room_id = self._setup_auth_chain(use_chain_cover_index) # Now actually test that various combinations give the right result: @@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) self.assertSetEqual(difference, set()) - def test_auth_difference_partial_cover(self): + def test_auth_difference_partial_cover(self) -> None: """Test that we correctly handle rooms where not all events have a chain cover calculated. This can happen in some obscure edge cases, including during the background update that calculates the chain cover for old @@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # | | # K J - auth_graph = { + auth_graph: Dict[str, List[str]] = { "a": ["e"], "b": ["e"], "c": ["g", "i"], @@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. - def insert_event(txn): + def insert_event(txn: LoggingTransaction) -> None: # First insert the room and mark it as having a chain cover. self.store.db_pool.simple_insert_txn( txn, @@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, [ - FakeEvent(event_id, room_id, auth_graph[event_id]) + cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) for event_id in auth_graph if event_id != "b" ], @@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, - [FakeEvent("b", room_id, auth_graph["b"])], + [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], ) self.store.db_pool.simple_update_txn( @@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): @parameterized.expand( [(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()] ) - def test_prune_inbound_federation_queue(self, room_version: RoomVersion): + def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None: """Test that pruning of inbound federation queues work""" room_id = "some_room_id" @@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): stream_ordering += 1 - def populate_db(txn: LoggingTransaction): + def populate_db(txn: LoggingTransaction) -> None: # Insert the room to satisfy the foreign key constraint of # `event_failed_pull_attempts` self.store.db_pool.simple_insert_txn( @@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) - def test_get_backfill_points_in_room(self): + def test_get_backfill_points_in_room(self) -> None: """ Test to make sure only backfill points that are older and come before the `current_depth` are returned. @@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_backfill_points_in_room_excludes_events_we_have_attempted( self, - ): + ) -> None: """ Test to make sure that events we have attempted to backfill (and within backoff timeout duration) do not show up as an event to backfill again. @@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration( self, - ): + ) -> None: """ Test to make sure after we fake attempt to backfill event "b3" many times, we can see retry and see the "b3" again after the backoff timeout duration @@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "5": 7, } - def populate_db(txn: LoggingTransaction): + def populate_db(txn: LoggingTransaction) -> None: # Insert the room to satisfy the foreign key constraint of # `event_failed_pull_attempts` self.store.db_pool.simple_insert_txn( @@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) - def test_get_insertion_event_backward_extremities_in_room(self): + def test_get_insertion_event_backward_extremities_in_room(self) -> None: """ Test to make sure only insertion event backward extremities that are older and come before the `current_depth` are returned. @@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( self, - ): + ) -> None: """ Test to make sure that insertion events we have attempted to backfill (and within backoff timeout duration) do not show up as an event to @@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( self, - ): + ) -> None: """ Test to make sure after we fake attempt to backfill event "insertion_eventA" many times, we can see retry and see the @@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertEqual(backfill_event_ids, ["insertion_eventA"]) - def test_get_event_ids_to_not_pull_from_backoff( - self, - ): + def test_get_event_ids_to_not_pull_from_backoff(self) -> None: """ Test to make sure only event IDs we should backoff from are returned. """ @@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( self, - ): + ) -> None: """ Test to make sure no event IDs are returned after the backoff duration has elapsed. @@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.assertEqual(event_ids_to_backoff, []) -@attr.s +@attr.s(auto_attribs=True) class FakeEvent: - event_id = attr.ib() - room_id = attr.ib() - auth_events = attr.ib() + event_id: str + room_id: str + auth_events: List[str] type = "foo" state_key = "foo" internal_metadata = _EventInternalMetadata({}) - def auth_event_ids(self): + def auth_event_ids(self) -> List[str]: return self.auth_events - def is_state(self): + def is_state(self) -> bool: return True diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 6f1135eef4..a91411168c 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class ExtremStatisticsTestCase(HomeserverTestCase): - def test_exposed_to_prometheus(self): + def test_exposed_to_prometheus(self) -> None: """ Forward extremity counts are exposed via Prometheus. """ diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 3ce4f35cb7..05661a537d 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import StateMap +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.state = self.hs.get_state_handler() self._persistence = self.hs.get_storage_controllers().persistence self._state_storage_controller = self.hs.get_storage_controllers().state @@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check that the current extremities is the remote event. self.assert_extremities([self.remote_event_1.event_id]) - def persist_event(self, event, state=None): + def persist_event( + self, event: EventBase, state: Optional[StateMap[str]] = None + ) -> None: """Persist the event, with optional state""" context = self.get_success( self.state.compute_event_context( @@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase): ) self.get_success(self._persistence.persist_event(event, context)) - def assert_extremities(self, expected_extremities): + def assert_extremities(self, expected_extremities: List[str]) -> None: """Assert the current extremities for the room""" extremities = self.get_success( self.store.get_prev_events_for_room(self.room_id) ) self.assertCountEqual(extremities, expected_extremities) - def test_prune_gap(self): + def test_prune_gap(self) -> None: """Test that we drop extremities after a gap when we see an event from the same domain. """ @@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) - def test_do_not_prune_gap_if_state_different(self): + def test_do_not_prune_gap_if_state_different(self) -> None: """Test that we don't prune extremities after a gap if the resolved state is different. """ @@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check that we haven't dropped the old extremity. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) - def test_prune_gap_if_old(self): + def test_prune_gap_if_old(self) -> None: """Test that we drop extremities after a gap when the previous extremity is "old" """ @@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) - def test_do_not_prune_gap_if_other_server(self): + def test_do_not_prune_gap_if_other_server(self) -> None: """Test that we do not drop extremities after a gap when we see an event from a different domain. """ @@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) - def test_prune_gap_if_dummy_remote(self): + def test_prune_gap_if_dummy_remote(self) -> None: """Test that we drop extremities after a gap when the previous extremity is a local dummy event and only points to remote events. """ @@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) - def test_prune_gap_if_dummy_local(self): + def test_prune_gap_if_dummy_local(self) -> None: """Test that we don't drop extremities after a gap when the previous extremity is a local dummy event and points to local events. """ @@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id, local_message_event_id]) - def test_do_not_prune_gap_if_not_dummy(self): + def test_do_not_prune_gap_if_not_dummy(self) -> None: """Test that we do not drop extremities after a gap when the previous extremity is not a dummy event. """ @@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.state = self.hs.get_state_handler() self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main - def test_remote_user_rooms_cache_invalidated(self): + def test_remote_user_rooms_cache_invalidated(self) -> None: """Test that if the server leaves a room the `get_rooms_for_user` cache is invalidated for remote users. """ @@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) self.assertEqual(set(rooms), set()) - def test_room_remote_user_cache_invalidated(self): + def test_room_remote_user_cache_invalidated(self) -> None: """Test that if the server leaves a room the `get_users_in_room` cache is invalidated for remote users. """ diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 9059095525..aa4b5bd3b1 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -13,6 +13,7 @@ # limitations under the License. import signedjson.key +import signedjson.types import unpaddedbase64 from twisted.internet.defer import Deferred @@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult import tests.unittest -def decode_verify_key_base64(key_id: str, key_base64: str): +def decode_verify_key_base64( + key_id: str, key_base64: str +) -> signedjson.types.VerifyKey: key_bytes = unpaddedbase64.decode_base64(key_base64) return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) @@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64( class KeyStoreTestCase(tests.unittest.HomeserverTestCase): - def test_get_server_verify_keys(self): + def test_get_server_verify_keys(self) -> None: store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" @@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) - def test_cache(self): + def test_cache(self) -> None: """Check that updates correctly invalidate the cache.""" store = self.hs.get_datastores().main diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index c55c4db970..2827738379 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.reactor.advance(FORTY_DAYS) @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) - def test_initialise_reserved_users(self): + def test_initialise_reserved_users(self) -> None: threepids = self.hs.config.server.mau_limits_reserved_threepids # register three users, of which two have reserved 3pids, and a third @@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): active_count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(active_count, 3) - def test_can_insert_and_count_mau(self): + def test_can_insert_and_count_mau(self) -> None: count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 0) @@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 1) - def test_appservice_user_not_counted_in_mau(self): + def test_appservice_user_not_counted_in_mau(self) -> None: self.get_success( self.store.register_user( user_id="@appservice_user:server", appservice_id="wibble" @@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 0) - def test_user_last_seen_monthly_active(self): + def test_user_last_seen_monthly_active(self) -> None: user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" @@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertIsNone(result) @override_config({"max_mau_value": 5}) - def test_reap_monthly_active_users(self): + def test_reap_monthly_active_users(self) -> None: initial_users = 10 for i in range(initial_users): self.get_success( @@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Note that below says mau_limit (no s), this is the name of the config # value, although it gets stored on the config object as mau_limits. @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) - def test_reap_monthly_active_users_reserved_users(self): + def test_reap_monthly_active_users_reserved_users(self) -> None: """Tests that reaping correctly handles reaping where reserved users are present""" threepids = self.hs.config.server.mau_limits_reserved_threepids @@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, self.hs.config.server.max_mau_value) - def test_populate_monthly_users_is_guest(self): + def test_populate_monthly_users_is_guest(self) -> None: # Test that guest users are not added to mau list user_id = "@user_id:host" @@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() - def test_populate_monthly_users_should_update(self): + def test_populate_monthly_users_should_update(self) -> None: self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] @@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() - def test_populate_monthly_users_should_not_update(self): + def test_populate_monthly_users_should_not_update(self) -> None: self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] @@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() - def test_get_reserved_real_user_account(self): + def test_get_reserved_real_user_account(self) -> None: # Test no reserved users, or reserved threepids users = self.get_success(self.store.get_registered_reserved_users()) self.assertEqual(len(users), 0) @@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): users = self.get_success(self.store.get_registered_reserved_users()) self.assertEqual(len(users), len(threepids)) - def test_support_user_not_add_to_mau_limits(self): + def test_support_user_not_add_to_mau_limits(self) -> None: support_user_id = "@support:test" count = self.get_success(self.store.get_monthly_active_count()) @@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config( {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} ) - def test_track_monthly_users_without_cap(self): + def test_track_monthly_users_without_cap(self) -> None: count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(0, count) @@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEqual(2, count) @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) - def test_no_users_when_not_tracking(self): + def test_no_users_when_not_tracking(self) -> None: self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.get_success(self.store.populate_monthly_active_users("@user:sever")) self.store.upsert_monthly_active_user.assert_not_called() - def test_get_monthly_active_count_by_service(self): + def test_get_monthly_active_count_by_service(self) -> None: appservice1_user1 = "@appservice1_user1:example.com" appservice1_user2 = "@appservice1_user2:example.com" @@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEqual(result[service2], 1) self.assertEqual(result[native], 1) - def test_get_monthly_active_users_by_service(self): + def test_get_monthly_active_users_by_service(self) -> None: # (No users, no filtering) -> empty result result = self.get_success(self.store.get_monthly_active_users_by_service()) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 9c1182ed16..010cc74c31 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import NotFoundError, SynapseError from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase): user_id = "@red:server" servlets = [room.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) self.store = hs.get_datastores().main self._storage_controllers = self.hs.get_storage_controllers() - def test_purge_history(self): + def test_purge_history(self) -> None: """ Purging a room history will delete everything before the topological point. """ @@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase): self.get_failure(self.store.get_event(third["event_id"]), NotFoundError) self.get_success(self.store.get_event(last["event_id"])) - def test_purge_history_wont_delete_extrems(self): + def test_purge_history_wont_delete_extrems(self) -> None: """ Purging a room history will delete everything before the topological point. """ @@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) + assert token.topological is not None event = f"t{token.topological + 1}-{token.stream + 1}" # Purge everything before this topological token @@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase): self.get_success(self.store.get_event(third["event_id"])) self.get_success(self.store.get_event(last["event_id"])) - def test_purge_room(self): + def test_purge_room(self) -> None: """ Purging a room will delete everything about it. """ diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 81253d0361..d8d84152dc 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -14,8 +14,12 @@ from typing import Collection, Optional +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import ReceiptTypes +from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util import Clock from tests.test_utils.event_injection import create_event from tests.unittest import HomeserverTestCase @@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test" class ReceiptTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, homeserver) -> None: + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: super().prepare(reactor, clock, homeserver) self.store = homeserver.get_datastores().main @@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase): ) self.assertEqual(res, {}) - res = self.get_last_unthreaded_receipt( + res2 = self.get_last_unthreaded_receipt( [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) - self.assertEqual(res, None) + self.assertIsNone(res2) def test_get_receipts_for_user(self) -> None: # Send some events into the first room diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 6c4e63b77c..df4740f9d9 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -11,27 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, cast from canonicaljson import json +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.types import RoomID, UserID +from synapse.events import EventBase, _EventInternalMetadata +from synapse.events.builder import EventBuilder +from synapse.server import HomeServer +from synapse.types import JsonDict, RoomID, UserID +from synapse.util import Clock from tests import unittest from tests.utils import create_room class RedactionTestCase(unittest.HomeserverTestCase): - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config["redaction_retention_period"] = "30d" return config - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self._storage = hs.get_storage_controllers() + storage = hs.get_storage_controllers() + assert storage.persistence is not None + self._persistence = storage.persistence self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.depth = 1 - def inject_room_member( + def inject_room_member( # type: ignore[override] self, - room, - user, - membership, - replaces_state=None, - extra_content: Optional[dict] = None, - ): + room: RoomID, + user: UserID, + membership: str, + extra_content: Optional[JsonDict] = None, + ) -> EventBase: content = {"membership": membership} content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( @@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self._storage.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) return event - def inject_message(self, room, user, body): + def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase: self.depth += 1 builder = self.event_builder_factory.for_room_version( @@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self._storage.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) return event - def inject_redaction(self, room, event_id, user, reason): + def inject_redaction( + self, room: RoomID, event_id: str, user: UserID, reason: str + ) -> EventBase: builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { @@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self._storage.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) return event - def test_redact(self): + def test_redact(self) -> None: self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) msg_event = self.inject_message(self.room1, self.u_alice, "t") @@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): event.unsigned["redacted_because"], ) - def test_redact_join(self): + def test_redact_join(self) -> None: self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) msg_event = self.inject_room_member( @@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): event.unsigned["redacted_because"], ) - def test_circular_redaction(self): + def test_circular_redaction(self) -> None: redaction_event_id1 = "$redaction1_id:test" redaction_event_id2 = "$redaction2_id:test" class EventIdManglingBuilder: - def __init__(self, base_builder, event_id): + def __init__(self, base_builder: EventBuilder, event_id: str): self._base_builder = base_builder self._event_id = event_id @@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase): prev_event_ids: List[str], auth_event_ids: Optional[List[str]], depth: Optional[int] = None, - ): + ) -> EventBase: built_event = await self._base_builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids ) - built_event._event_id = self._event_id + built_event._event_id = self._event_id # type: ignore[attr-defined] built_event._dict["event_id"] = self._event_id assert built_event.event_id == self._event_id return built_event @property - def room_id(self): + def room_id(self) -> str: return self._base_builder.room_id @property - def type(self): + def type(self) -> str: return self._base_builder.type @property - def internal_metadata(self): + def internal_metadata(self) -> _EventInternalMetadata: return self._base_builder.internal_metadata event_1, context_1 = self.get_success( self.event_creation_handler.create_new_client_event( - EventIdManglingBuilder( - self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Redaction, - "sender": self.u_alice.to_string(), - "room_id": self.room1.to_string(), - "content": {"reason": "test"}, - "redacts": redaction_event_id2, - }, + cast( + EventBuilder, + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id2, + }, + ), + redaction_event_id1, ), - redaction_event_id1, ) ) ) - self.get_success(self._storage.persistence.persist_event(event_1, context_1)) + self.get_success(self._persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( - EventIdManglingBuilder( - self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Redaction, - "sender": self.u_alice.to_string(), - "room_id": self.room1.to_string(), - "content": {"reason": "test"}, - "redacts": redaction_event_id1, - }, + cast( + EventBuilder, + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id1, + }, + ), + redaction_event_id2, ), - redaction_event_id2, ) ) ) - self.get_success(self._storage.persistence.persist_event(event_2, context_2)) + self.get_success(self._persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): fetched.unsigned["redacted_because"].event_id, redaction_event_id2 ) - def test_redact_censor(self): + def test_redact_censor(self) -> None: """Test that a redacted event gets censored in the DB after a month""" self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.assert_dict({"content": {}}, json.loads(event_json)) - def test_redact_redaction(self): + def test_redact_redaction(self) -> None: """Tests that we can redact a redaction and can fetch it again.""" self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.store.get_event(first_redact_event.event_id, allow_none=True) ) - def test_store_redacted_redaction(self): + def test_store_redacted_redaction(self) -> None: """Tests that we can store a redacted redaction.""" self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success( - self._storage.persistence.persist_event(redaction_event, context) - ) + self.get_success(self._persistence.persist_event(redaction_event, context)) # Now lets jump to the future where we have censored the redaction event # in the DB. diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index 0baa54312e..966aafea6f 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -14,10 +14,15 @@ from typing import List from unittest import mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.app.generic_worker import GenericWorkerServer +from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database from synapse.storage.schema import SCHEMA_VERSION +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]: class WorkerSchemaTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs - def default_config(self): + def default_config(self) -> JsonDict: conf = super().default_config() # Mark this as a worker app. @@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase): return conf - def test_rolling_back(self): + def test_rolling_back(self) -> None: """Test that workers can start if the DB is a newer schema version""" db_pool = self.hs.get_datastores().main.db_pool @@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase): prepare_database(db_conn, db_pool.engine, self.hs.config) - def test_not_upgraded_old_schema_version(self): + def test_not_upgraded_old_schema_version(self) -> None: """Test that workers don't start if the DB has an older schema version""" db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( @@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase): with self.assertRaises(PrepareDatabaseException): prepare_database(db_conn, db_pool.engine, self.hs.config) - def test_not_upgraded_current_schema_version_with_outstanding_deltas(self): + def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None: """ Test that workers don't start if the DB is on the current schema version, but there are still outstanding delta migrations to run. diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 3405efb6a8..71ec74eadc 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.room_versions import RoomVersions +from synapse.server import HomeServer from synapse.types import RoomAlias, RoomID, UserID +from synapse.util import Clock from tests.unittest import HomeserverTestCase class RoomStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table self.store = hs.get_datastores().main @@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase): ) ) - def test_get_room(self): + def test_get_room(self) -> None: + res = self.get_success(self.store.get_room(self.room.to_string())) + assert res is not None self.assertDictContainsSubset( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "is_public": True, }, - (self.get_success(self.store.get_room(self.room.to_string()))), + res, ) - def test_get_room_unknown_room(self): + def test_get_room_unknown_room(self) -> None: self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) - def test_get_room_with_stats(self): + def test_get_room_with_stats(self) -> None: + res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) + assert res is not None self.assertDictContainsSubset( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "public": True, }, - (self.get_success(self.store.get_room_with_stats(self.room.to_string()))), + res, ) - def test_get_room_with_stats_unknown_room(self): + def test_get_room_with_stats_unknown_room(self) -> None: self.assertIsNone( - (self.get_success(self.store.get_room_with_stats("!uknown:test"))), + self.get_success(self.store.get_room_with_stats("!uknown:test")) ) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index ef850daa73..14d872514d 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase): room.register_servlets, ] - def test_null_byte(self): + def test_null_byte(self) -> None: """ Postgres/SQLite don't like null bytes going into the search tables. Internally we replace those with a space. @@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase): if isinstance(store.database_engine, PostgresEngine): self.assertIn("alice", result.get("highlights")) - def test_non_string(self): + def test_non_string(self) -> None: """Test that non-string `value`s are not inserted into `event_search`. This is particularly important when using sqlite, since a sqlite column can hold @@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase): self.assertEqual(f.value.code, 404) @skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite") - def test_sqlite_non_string_deletion_background_update(self): + def test_sqlite_non_string_deletion_background_update(self) -> None: """Test the background update to delete bad rows from `event_search`.""" store = self.hs.get_datastores().main @@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase): "results array length should match count", ) - def test_postgres_web_search_for_phrase(self): + def test_postgres_web_search_for_phrase(self) -> None: """ Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. This test is skipped unless the postgres instance supports websearch_to_tsquery. @@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase): self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES) - def test_sqlite_search(self): + def test_sqlite_search(self) -> None: """ Test sqlite searching for phrases. """ diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 5564161750..d4e6d4236c 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -16,10 +16,15 @@ import logging from frozendict import frozendict +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.server import HomeServer from synapse.storage.state import StateFilter -from synapse.types import RoomID, UserID +from synapse.types import JsonDict, RoomID, StateMap, UserID +from synapse.util import Clock from tests.unittest import HomeserverTestCase, TestCase @@ -27,7 +32,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.storage = hs.get_storage_controllers() self.state_datastore = self.storage.state.stores.state @@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase): ) ) - def inject_state_event(self, room, sender, typ, state_key, content): + def inject_state_event( + self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict + ) -> EventBase: builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { @@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) + assert self.storage.persistence is not None self.get_success(self.storage.persistence.persist_event(event, context)) return event - def assertStateMapEqual(self, s1, s2): + def assertStateMapEqual( + self, s1: StateMap[EventBase], s2: StateMap[EventBase] + ) -> None: for t in s1: # just compare event IDs for simplicity self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(len(s1), len(s2)) - def test_get_state_groups_ids(self): + def test_get_state_groups_ids(self) -> None: e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = self.get_success( - self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) + self.storage.state.get_state_groups_ids( + self.room.to_string(), [e2.event_id] + ) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase): {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) - def test_get_state_groups(self): + def test_get_state_groups(self) -> None: e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = self.get_success( - self.storage.state.get_state_groups(self.room, [e2.event_id]) + self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) - def test_get_state_for_event(self): + def test_get_state_for_event(self) -> None: # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) @@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase): class StateFilterDifferenceTestCase(TestCase): def assert_difference( self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter - ): + ) -> None: self.assertEqual( minuend.approx_difference(subtrahend), expected, f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", ) - def test_state_filter_difference_no_include_other_minus_no_include_other(self): + def test_state_filter_difference_no_include_other_minus_no_include_other( + self, + ) -> None: """ Tests the StateFilter.approx_difference method where, in a.approx_difference(b), both a and b do not have the @@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase): ), ) - def test_state_filter_difference_include_other_minus_no_include_other(self): + def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: """ Tests the StateFilter.approx_difference method where, in a.approx_difference(b), only a has the include_others flag set. @@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase): ), ) - def test_state_filter_difference_include_other_minus_include_other(self): + def test_state_filter_difference_include_other_minus_include_other(self) -> None: """ Tests the StateFilter.approx_difference method where, in a.approx_difference(b), both a and b have the include_others @@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase): ), ) - def test_state_filter_difference_no_include_other_minus_include_other(self): + def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: """ Tests the StateFilter.approx_difference method where, in a.approx_difference(b), only b has the include_others flag set. @@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase): ), ) - def test_state_filter_difference_simple_cases(self): + def test_state_filter_difference_simple_cases(self) -> None: """ Tests some very simple cases of the StateFilter approx_difference, that are not explicitly tested by the more in-depth tests. @@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase): class StateFilterTestCase(TestCase): - def test_return_expanded(self): + def test_return_expanded(self) -> None: """ Tests the behaviour of the return_expanded() function that expands StateFilters to include more state types (for the sake of cache hit rate). diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 34fa810cf6..bc090ebce0 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -14,11 +14,15 @@ from typing import List +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, RelationTypes from synapse.api.filtering import Filter from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase): login.register_servlets, ] - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config["experimental_features"] = {"msc3874_enabled": True} return config - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) @@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase): return [ev.event_id for ev in events] - def test_filter_relation_senders(self): + def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) @@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase): chunk = self._filter_messages(filter) self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) - def test_filter_relation_type(self): + def test_filter_relation_type(self) -> None: # Messages which have annotations. filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) @@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase): chunk = self._filter_messages(filter) self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) - def test_filter_relation_senders_and_type(self): + def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { "related_by_senders": [self.second_user_id], @@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase): chunk = self._filter_messages(filter) self.assertEqual(chunk, [self.event_id_1]) - def test_duplicate_relation(self): + def test_duplicate_relation(self) -> None: """An event should only be returned once if there are multiple relations to it.""" self.helper.send_event( room_id=self.room_id, diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index e05daa285e..db9ee9955e 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -12,17 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer from synapse.storage.databases.main.transactions import DestinationRetryTimings +from synapse.util import Clock from synapse.util.retryutils import MAX_RETRY_INTERVAL from tests.unittest import HomeserverTestCase class TransactionStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main - def test_get_set_transactions(self): + def test_get_set_transactions(self) -> None: """Tests that we can successfully get a non-existent entry for destination retries, as well as testing tht we can set and get correctly. @@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase): r, ) - def test_initial_set_transactions(self): + def test_initial_set_transactions(self) -> None: """Tests that we can successfully set the destination retries (there was a bug around invalidating the cache that broke this) """ d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(d) - def test_large_destination_retry(self): + def test_large_destination_retry(self) -> None: d = self.store.set_destination_retry_timings( "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL ) self.get_success(d) - d = self.store.get_destination_retry_timings("example.com") - self.get_success(d) + d2 = self.store.get_destination_retry_timings("example.com") + self.get_success(d2) diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index ace82cbf42..15ea4770bd 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -12,21 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.types import Cursor +from synapse.util import Clock + from tests import unittest class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): """Test SQL transaction limit doesn't break transactions.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(db_txn_limit=1000) - def test_config(self): + def test_config(self) -> None: db_config = self.hs.config.database.get_single_database() self.assertEqual(db_config.config["txn_limit"], 1000) - def test_select(self): - def do_select(txn): + def test_select(self) -> None: + def do_select(txn: Cursor) -> None: txn.execute("SELECT 1") db_pool = self.hs.get_datastores().databases[0] diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index cae14151c0..0e3fc2a77f 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from typing import Collection, Dict from unittest import mock from twisted.internet.defer import CancelledError, ensureDeferred @@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase): # the results to be returned by the mocked get_partial_state_events self._events_dict: Dict[str, bool] = {} - async def get_partial_state_events(events): + async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]: return {e: self._events_dict[e] for e in events} self.mock_store = mock.Mock(spec_set=["get_partial_state_events"]) @@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker = PartialStateEventsTracker(self.mock_store) - def test_does_not_block_for_full_state_events(self): + def test_does_not_block_for_full_state_events(self) -> None: self._events_dict = {"event1": False, "event2": False} self.successResultOf( @@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase): ["event1", "event2"] ) - def test_blocks_for_partial_state_events(self): + def test_blocks_for_partial_state_events(self) -> None: self._events_dict = {"event1": True, "event2": False} d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) @@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("event1") self.successResultOf(d) - def test_un_partial_state_race(self): + def test_un_partial_state_race(self) -> None: # if the event is un-partial-stated between the initial check and the # registration of the listener, it should not block. self._events_dict = {"event1": True, "event2": False} - async def get_partial_state_events(events): + async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]: res = {e: self._events_dict[e] for e in events} # change the result for next time self._events_dict = {"event1": False, "event2": False} @@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase): ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) ) - def test_un_partial_state_during_get_partial_state_events(self): + def test_un_partial_state_during_get_partial_state_events(self) -> None: # we should correctly handle a call to notify_un_partial_stated during the # second call to get_partial_state_events. self._events_dict = {"event1": True, "event2": False} - async def get_partial_state_events1(events): + async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]: self.mock_store.get_partial_state_events.side_effect = ( get_partial_state_events2 ) return {e: self._events_dict[e] for e in events} - async def get_partial_state_events2(events): + async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]: self.tracker.notify_un_partial_stated("event1") self._events_dict["event1"] = False return {e: self._events_dict[e] for e in events} @@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase): ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) ) - def test_cancellation(self): + def test_cancellation(self) -> None: self._events_dict = {"event1": True, "event2": False} d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) @@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase): self.tracker = PartialCurrentStateTracker(self.mock_store) - def test_does_not_block_for_full_state_rooms(self): + def test_does_not_block_for_full_state_rooms(self) -> None: self.mock_store.is_partial_state_room.return_value = make_awaitable(False) self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) - def test_blocks_for_partial_room_state(self): + def test_blocks_for_partial_room_state(self) -> None: self.mock_store.is_partial_state_room.return_value = make_awaitable(True) d = ensureDeferred(self.tracker.await_full_state("room_id")) @@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("room_id") self.successResultOf(d) - def test_un_partial_state_race(self): + def test_un_partial_state_race(self) -> None: # We should correctly handle race between awaiting the state and us # un-partialling the state - async def is_partial_state_room(events): + async def is_partial_state_room(room_id: str) -> bool: self.tracker.notify_un_partial_stated("room_id") return True @@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase): self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) - def test_cancellation(self): + def test_cancellation(self) -> None: self.mock_store.is_partial_state_room.return_value = make_awaitable(True) d1 = ensureDeferred(self.tracker.await_full_state("room_id")) -- cgit 1.5.1 From 2a3cd59dd06411a79fb7500970db1b98f0d87695 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 12 Dec 2022 13:21:17 +0100 Subject: Add optional ICU support for user search (#14464) Fixes #13655 This change uses ICU (International Components for Unicode) to improve boundary detection in user search. This change also adds a new dependency on libicu-dev and pkg-config for the Debian packages, which are available in all supported distros. --- changelog.d/14464.feature | 1 + debian/changelog | 7 +++ debian/control | 2 + docker/Dockerfile | 2 + docker/Dockerfile-dhvirtualenv | 2 + poetry.lock | 16 +++++- pyproject.toml | 7 +++ stubs/icu.pyi | 25 +++++++++ synapse/storage/databases/main/user_directory.py | 67 ++++++++++++++++++++++-- tests/storage/test_user_directory.py | 43 +++++++++++++++ 10 files changed, 166 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14464.feature create mode 100644 stubs/icu.pyi (limited to 'tests') diff --git a/changelog.d/14464.feature b/changelog.d/14464.feature new file mode 100644 index 0000000000..688ea32117 --- /dev/null +++ b/changelog.d/14464.feature @@ -0,0 +1 @@ +Improve user search for international display names. diff --git a/debian/changelog b/debian/changelog index 163b7210bf..5d3c4f7d6b 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,10 @@ +matrix-synapse-py3 (1.74.0~rc1) UNRELEASED; urgency=medium + + * New dependency on libicu-dev to provide improved results for user + search. + + -- Synapse Packaging team Tue, 06 Dec 2022 15:28:10 +0000 + matrix-synapse-py3 (1.73.0) stable; urgency=medium * New Synapse release 1.73.0. diff --git a/debian/control b/debian/control index 86f5a66d02..bc628cec08 100644 --- a/debian/control +++ b/debian/control @@ -8,6 +8,8 @@ Build-Depends: dh-virtualenv (>= 1.1), libsystemd-dev, libpq-dev, + libicu-dev, + pkg-config, lsb-release, python3-dev, python3, diff --git a/docker/Dockerfile b/docker/Dockerfile index 185d5bc3d4..7e5123210a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -97,6 +97,8 @@ RUN \ zlib1g-dev \ git \ curl \ + libicu-dev \ + pkg-config \ && rm -rf /var/lib/apt/lists/* diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv index 73165f6f85..f3b5b00ce6 100644 --- a/docker/Dockerfile-dhvirtualenv +++ b/docker/Dockerfile-dhvirtualenv @@ -84,6 +84,8 @@ RUN apt-get update -qq -o Acquire::Languages=none \ python3-venv \ sqlite3 \ libpq-dev \ + libicu-dev \ + pkg-config \ xmlsec1 # Install rust and ensure it's in the PATH diff --git a/poetry.lock b/poetry.lock index cac22e2ef0..ccda8a23fb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -837,6 +837,14 @@ category = "dev" optional = false python-versions = ">=3.5" +[[package]] +name = "pyicu" +version = "2.10.2" +description = "Python extension wrapping the ICU C++ API" +category = "main" +optional = true +python-versions = "*" + [[package]] name = "pyjwt" version = "2.4.0" @@ -1622,7 +1630,7 @@ docs = ["Sphinx", "repoze.sphinx.autointerface"] test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"] [extras] -all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler"] +all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler", "pyicu"] cache-memory = ["Pympler"] jwt = ["authlib"] matrix-synapse-ldap3 = ["matrix-synapse-ldap3"] @@ -1635,11 +1643,12 @@ sentry = ["sentry-sdk"] systemd = ["systemd-python"] test = ["parameterized", "idna"] url-preview = ["lxml"] +user-search = ["pyicu"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "8c44ceeb9df5c3ab43040400e0a6b895de49417e61293a1ba027640b34f03263" +content-hash = "f20007013f33bc35a01e412c48adc62a936030f3074e06286674c5ad7f44d300" [metadata.files] attrs = [ @@ -2427,6 +2436,9 @@ pygments = [ {file = "Pygments-2.11.2-py3-none-any.whl", hash = "sha256:44238f1b60a76d78fc8ca0528ee429702aae011c265fe6a8dd8b63049ae41c65"}, {file = "Pygments-2.11.2.tar.gz", hash = "sha256:4e426f72023d88d03b2fa258de560726ce890ff3b630f88c21cbb8b2503b8c6a"}, ] +pyicu = [ + {file = "PyICU-2.10.2.tar.gz", hash = "sha256:0c3309eea7fab6857507ace62403515b60fe096cbfb4f90d14f55ff75c5441c1"}, +] pyjwt = [ {file = "PyJWT-2.4.0-py3-none-any.whl", hash = "sha256:72d1d253f32dbd4f5c88eaf1fdc62f3a19f676ccbadb9dbc5d07e951b2b26daf"}, {file = "PyJWT-2.4.0.tar.gz", hash = "sha256:d42908208c699b3b973cbeb01a969ba6a96c821eefb1c5bfe4c390c01d67abba"}, diff --git a/pyproject.toml b/pyproject.toml index df59fa0562..bb383683cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -208,6 +208,7 @@ hiredis = { version = "*", optional = true } Pympler = { version = "*", optional = true } parameterized = { version = ">=0.7.4", optional = true } idna = { version = ">=2.5", optional = true } +pyicu = { version = ">=2.10.2", optional = true } [tool.poetry.extras] # NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified @@ -230,6 +231,10 @@ redis = ["txredisapi", "hiredis"] # Required to use experimental `caches.track_memory_usage` config option. cache-memory = ["pympler"] test = ["parameterized", "idna"] +# Allows for better search for international characters in the user directory. This +# requires libicu's development headers installed on the system (e.g. libicu-dev on +# Debian-based distributions). +user-search = ["pyicu"] # The duplication here is awful. I hate hate hate hate hate it. However, for now I want # to ensure you can still `pip install matrix-synapse[all]` like today. Two motivations: @@ -261,6 +266,8 @@ all = [ "txredisapi", "hiredis", # cache-memory "pympler", + # improved user search + "pyicu", # omitted: # - test: it's useful to have this separate from dev deps in the olddeps job # - systemd: this is a system-based requirement diff --git a/stubs/icu.pyi b/stubs/icu.pyi new file mode 100644 index 0000000000..efeda7938a --- /dev/null +++ b/stubs/icu.pyi @@ -0,0 +1,25 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stub for PyICU. + +class Locale: + @staticmethod + def getDefault() -> Locale: ... + +class BreakIterator: + @staticmethod + def createWordInstance(locale: Locale) -> BreakIterator: ... + def setText(self, text: str) -> None: ... + def nextBoundary(self) -> int: ... diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index af9952f513..14ef5b040d 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,6 +26,14 @@ from typing import ( cast, ) +try: + # Figure out if ICU support is available for searching users. + import icu + + USE_ICU = True +except ModuleNotFoundError: + USE_ICU = False + from typing_extensions import TypedDict from synapse.api.errors import StoreError @@ -900,7 +908,7 @@ def _parse_query_sqlite(search_term: str) -> str: """ # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + results = _parse_words(search_term) return " & ".join("(%s* OR %s)" % (result, result) for result in results) @@ -910,12 +918,63 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: We use this so that we can add prefix matching, which isn't something that is supported by default. """ - - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + results = _parse_words(search_term) both = " & ".join("(%s:* | %s)" % (result, result) for result in results) exact = " & ".join("%s" % (result,) for result in results) prefix = " & ".join("%s:*" % (result,) for result in results) return both, exact, prefix + + +def _parse_words(search_term: str) -> List[str]: + """Split the provided search string into a list of its words. + + If support for ICU (International Components for Unicode) is available, use it. + Otherwise, fall back to using a regex to detect word boundaries. This latter + solution works well enough for most latin-based languages, but doesn't work as well + with other languages. + + Args: + search_term: The search string. + + Returns: + A list of the words in the search string. + """ + if USE_ICU: + return _parse_words_with_icu(search_term) + + return re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + +def _parse_words_with_icu(search_term: str) -> List[str]: + """Break down the provided search string into its individual words using ICU + (International Components for Unicode). + + Args: + search_term: The search string. + + Returns: + A list of the words in the search string. + """ + results = [] + breaker = icu.BreakIterator.createWordInstance(icu.Locale.getDefault()) + breaker.setText(search_term) + i = 0 + while True: + j = breaker.nextBoundary() + if j < 0: + break + + result = search_term[i:j] + + # libicu considers spaces and punctuation between words as words, but we don't + # want to include those in results as they would result in syntax errors in SQL + # queries (e.g. "foo bar" would result in the search query including "foo & & + # bar"). + if len(re.findall(r"([\w\-]+)", result, re.UNICODE)): + results.append(result) + + i = j + + return results diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 88c7d5fec0..3ba896ecf3 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from typing import Any, Dict, Set, Tuple from unittest import mock from unittest.mock import Mock, patch @@ -30,6 +31,12 @@ from synapse.util import Clock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config +try: + import icu +except ImportError: + icu = None # type: ignore + + ALICE = "@alice:a" BOB = "@bob:b" BOBBY = "@bobby:a" @@ -467,3 +474,39 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): r["results"][0], {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, ) + + +class UserDirectoryICUTestCase(HomeserverTestCase): + if not icu: + skip = "Requires PyICU" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.user_dir_helper = GetUserDirectoryTables(self.store) + + def test_icu_word_boundary(self) -> None: + """Tests that we correctly detect word boundaries when ICU (International + Components for Unicode) support is available. + """ + + display_name = "Gáo" + + # This word is not broken down correctly by Python's regular expressions, + # likely because á is actually a lowercase a followed by a U+0301 combining + # acute accent. This is specifically something that ICU support fixes. + matches = re.findall(r"([\w\-]+)", display_name, re.UNICODE) + self.assertEqual(len(matches), 2) + + self.get_success( + self.store.update_profile_in_user_dir(ALICE, display_name, None) + ) + self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE,))) + + # Check that searching for this user yields the correct result. + r = self.get_success(self.store.search_user_dir(BOB, display_name, 10)) + self.assertFalse(r["limited"]) + self.assertEqual(len(r["results"]), 1) + self.assertDictEqual( + r["results"][0], + {"user_id": ALICE, "display_name": display_name, "avatar_url": None}, + ) -- cgit 1.5.1 From 74b89c27613a34ec9b291ad3066db7ce0adff1db Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 12 Dec 2022 13:55:23 +0000 Subject: Revert the deletion of stale devices due to performance issues. (#14662) --- changelog.d/14595.misc | 1 - changelog.d/14649.misc | 1 - changelog.d/14662.removal | 1 + synapse/handlers/device.py | 33 +----------- synapse/storage/databases/main/devices.py | 84 +------------------------------ tests/handlers/test_device.py | 33 +----------- tests/storage/test_client_ips.py | 4 +- 7 files changed, 5 insertions(+), 152 deletions(-) delete mode 100644 changelog.d/14595.misc delete mode 100644 changelog.d/14649.misc create mode 100644 changelog.d/14662.removal (limited to 'tests') diff --git a/changelog.d/14595.misc b/changelog.d/14595.misc deleted file mode 100644 index f9bfc581ad..0000000000 --- a/changelog.d/14595.misc +++ /dev/null @@ -1 +0,0 @@ -Prune user's old devices on login if they have too many. diff --git a/changelog.d/14649.misc b/changelog.d/14649.misc deleted file mode 100644 index f9bfc581ad..0000000000 --- a/changelog.d/14649.misc +++ /dev/null @@ -1 +0,0 @@ -Prune user's old devices on login if they have too many. diff --git a/changelog.d/14662.removal b/changelog.d/14662.removal new file mode 100644 index 0000000000..19a387bbb4 --- /dev/null +++ b/changelog.d/14662.removal @@ -0,0 +1 @@ +(remove from changelog: unreleased) Revert the deletion of stale devices due to performance issues. \ No newline at end of file diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c935c7be90..d4750a32e6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -52,7 +52,6 @@ from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable -from synapse.util.iterutils import batch_iter from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -422,9 +421,6 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) - # Prune the user's device list if they already have a lot of devices. - await self._prune_too_many_devices(user_id) - if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -456,33 +452,6 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") - async def _prune_too_many_devices(self, user_id: str) -> None: - """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id, 100) - if not device_ids: - return - - logger.info("Pruning %d old devices for user %s", len(device_ids), user_id) - - # We don't want to block and try and delete tonnes of devices at once, - # so we cap the number of devices we delete synchronously. - first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] - await self.delete_devices(user_id, first_batch) - - if not remaining_device_ids: - return - - # Now spawn a background loop that deletes the rest. - async def _prune_too_many_devices_loop() -> None: - for batch in batch_iter(remaining_device_ids, 10): - await self.delete_devices(user_id, batch) - - await self.clock.sleep(1) - - run_as_background_process( - "_prune_too_many_devices_loop", _prune_too_many_devices_loop - ) - async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. @@ -512,7 +481,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 95d4c0622d..a5bb4d404e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,77 +1569,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user( - self, user_id: str, limit: int - ) -> List[str]: - """Check if the user has a lot of devices, and if so return the set of - devices we can prune. - - This does *not* return hidden devices or devices with E2E keys. - - Returns at most `limit` number of devices, ordered by last seen. - """ - - num_devices = await self.db_pool.simple_select_one_onecol( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcol="COALESCE(COUNT(*), 0)", - desc="count_devices", - ) - - # We let users have up to ten devices without pruning. - if num_devices <= 10: - return [] - - # We prune everything older than N days. - max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 - - if num_devices > 50: - # If the user has more than 50 devices, then we chose a last seen - # that ensures we keep at most 50 devices. - sql = """ - SELECT last_seen FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen IS NOT NULL - AND key_json IS NULL - ORDER BY last_seen DESC - LIMIT 1 - OFFSET 50 - """ - - rows = await self.db_pool.execute( - "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) - ) - if rows: - max_last_seen = max(rows[0][0], max_last_seen) - - # Now fetch the devices to delete. - sql = """ - SELECT device_id FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen < ? - AND key_json IS NULL - ORDER BY last_seen - LIMIT ? - """ - - def check_too_many_devices_for_user_txn( - txn: LoggingTransaction, - ) -> List[str]: - txn.execute(sql, (user_id, max_last_seen, limit)) - return [device_id for device_id, in txn] - - return await self.db_pool.runInteraction( - "check_too_many_devices_for_user", - check_too_many_devices_for_user_txn, - ) - class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1698,7 +1627,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, - "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1744,15 +1672,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - @cached(max_entries=0) - async def delete_device(self, user_id: str, device_id: str) -> None: - raise NotImplementedError() - - # Note: sometimes deleting rows out of `device_inbox` can take a long time, - # so we use a cache so that we deduplicate in flight requests to delete - # devices. - @cachedList(cached_method_name="delete_device", list_name="device_ids") - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> dict: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: @@ -1789,8 +1709,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) - return {} - async def update_device( self, user_id: str, device_id: str, new_display_name: Optional[str] = None ) -> None: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index e51cac9b33..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -20,8 +20,6 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler -from synapse.rest import admin -from synapse.rest.client import account, login from synapse.server import HomeServer from synapse.util import Clock @@ -32,12 +30,6 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - admin.register_servlets, - account.register_servlets, - ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) handler = hs.get_device_handler() @@ -123,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": 1000000, + "last_seen_ts": None, }, device_map["xyz"], ) @@ -237,29 +229,6 @@ class DeviceTestCase(unittest.HomeserverTestCase): NotFoundError, ) - def test_login_delete_old_devices(self) -> None: - """Delete old devices if the user already has too many.""" - - user_id = self.register_user("user", "pass") - - # Create a bunch of devices - for _ in range(50): - self.login("user", "pass") - self.reactor.advance(1) - - # Advance the clock for ages (as we only delete old devices) - self.reactor.advance(60 * 60 * 24 * 300) - - # Log in again to start the pruning - self.login("user", "pass") - - # Give the background job time to do its thing - self.reactor.pump([1.0] * 100) - - # We should now only have the most recent device. - devices = self.get_success(self.handler.get_devices_by_user(user_id)) - self.assertEqual(len(devices), 1) - def _record_users(self) -> None: # check this works for both devices which have a recorded client_ip, # and those which don't. diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 81e4e596e4..7f7f4ef892 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -170,8 +170,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) - last_seen = self.clock.time_msec() - if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -192,7 +190,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": last_seen, + "last_seen": None, }, ], ) -- cgit 1.5.1 From b5b5f6608462a988b05502a3b70b6a57ca3846d2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 16:19:30 +0000 Subject: Move `StateFilter` to `synapse.types` (#14668) * Move `StateFilter` to `synapse.types` * Changelog --- changelog.d/14668.misc | 1 + synapse/events/builder.py | 2 +- synapse/events/snapshot.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/module_api/__init__.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/mailer.py | 2 +- synapse/rest/admin/rooms.py | 2 +- synapse/rest/client/room.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/controllers/persist_events.py | 2 +- synapse/storage/controllers/state.py | 2 +- synapse/storage/databases/main/state.py | 2 +- synapse/storage/databases/state/bg_updates.py | 2 +- synapse/storage/databases/state/store.py | 2 +- synapse/storage/state.py | 567 ---------------- synapse/types.py | 928 -------------------------- synapse/types/__init__.py | 928 ++++++++++++++++++++++++++ synapse/types/state.py | 567 ++++++++++++++++ synapse/visibility.py | 2 +- tests/storage/test_state.py | 2 +- 29 files changed, 1520 insertions(+), 1519 deletions(-) create mode 100644 changelog.d/14668.misc delete mode 100644 synapse/storage/state.py delete mode 100644 synapse/types.py create mode 100644 synapse/types/__init__.py create mode 100644 synapse/types/state.py (limited to 'tests') diff --git a/changelog.d/14668.misc b/changelog.d/14668.misc new file mode 100644 index 0000000000..5269d8a97d --- /dev/null +++ b/changelog.d/14668.misc @@ -0,0 +1 @@ +Move `StateFilter` to `synapse.types`. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index d62906043f..94dd1298e1 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -28,8 +28,8 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.storage.state import StateFilter from synapse.types import EventID, JsonDict +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.stringutils import random_string diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 1c0e96bec7..6eaef8b57a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,7 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore - from synapse.storage.state import StateFilter + from synapse.types.state import StateFilter @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3398fcaf7d..b2784d7333 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -70,8 +70,8 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f7223b03c3..d2facdab60 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -75,7 +75,6 @@ from synapse.replication.http.federation import ( from synapse.state import StateResolutionStore from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, @@ -83,6 +82,7 @@ from synapse.types import ( UserID, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5cbe89f4fd..d6e90ef259 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,7 +59,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( MutableStateMap, PersistedEventPosition, @@ -70,6 +69,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index c572508a02..8c8ff18a1a 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -27,9 +27,9 @@ from synapse.handlers.room import ShutdownRoomResponse from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6307fa9c5d..c611efb760 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -46,8 +46,8 @@ from synapse.replication.http.register import ( ReplicationRegisterServlet, ) from synapse.spam_checker_api import RegistrationBehaviour -from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester +from synapse.types.state import StateFilter if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6dcfd86fdf..f81241c2b3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -62,7 +62,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( JsonDict, @@ -77,6 +76,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import stringutils from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6ad2b38b8f..0c39e852a1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -34,7 +34,6 @@ from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.logging import opentracing from synapse.module_api import NOT_SPAM -from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, Requester, @@ -45,6 +44,7 @@ from synapse.types import ( create_requester, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index bcab98c6d5..33115ce488 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -23,8 +23,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.storage.state import StateFilter from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client if TYPE_CHECKING: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dace9b606f..7d6a653747 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -49,7 +49,6 @@ from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary -from synapse.storage.state import StateFilter from synapse.types import ( DeviceListUpdates, JsonDict, @@ -61,6 +60,7 @@ from synapse.types import ( StreamToken, UserID, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 96a661177a..0092a03c59 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -111,7 +111,6 @@ from synapse.storage.background_updates import ( ) from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo -from synapse.storage.state import StateFilter from synapse.types import ( DomainSpecificString, JsonDict, @@ -124,6 +123,7 @@ from synapse.types import ( UserProfile, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable from synapse.util.caches.descriptors import CachedFunction, cached diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9ed35d8461..36e5b327ef 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -35,8 +35,8 @@ from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership -from synapse.storage.state import StateFilter from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator +from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index c2575ba3d9..93b255ced5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -37,8 +37,8 @@ from synapse.push.push_types import ( TemplateVars, ) from synapse.storage.databases.main.event_push_actions import EmailPushAction -from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 747e6fda83..e957aa28ca 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,9 +34,9 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, RoomID, UserID, create_requester +from synapse.types.state import StateFilter from synapse.util import json_decoder if TYPE_CHECKING: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 514eb6afc8..790614d721 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -55,9 +55,9 @@ from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types.state import StateFilter from synapse.util import json_decoder from synapse.util.cancellation import cancellable from synapse.util.stringutils import parse_and_validate_server_name, random_string diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 833ffec3de..ee5469d5a8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -44,8 +44,8 @@ from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import StateMap +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 33ffef521b..f1d2c71c91 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -58,13 +58,13 @@ from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, StateMap, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 2b31ce54bb..26d79c6e62 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -31,12 +31,12 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging.opentracing import tag_args, trace from synapse.storage.roommember import ProfileInfo -from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, PartialStateEventsTracker, ) from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util.cancellation import cancellable if TYPE_CHECKING: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index af7bebee80..c801a93b5b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -33,8 +33,8 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 4a4ad0f492..d743282f13 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -22,8 +22,8 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.storage.state import StateFilter from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string if TYPE_CHECKING: diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f8cfcaca83..1a7232b276 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -25,10 +25,10 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap +from synapse.types.state import StateFilter from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.cancellation import cancellable diff --git a/synapse/storage/state.py b/synapse/storage/state.py deleted file mode 100644 index 0004d955b4..0000000000 --- a/synapse/storage/state.py +++ /dev/null @@ -1,567 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2022 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import ( - TYPE_CHECKING, - Callable, - Collection, - Dict, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - TypeVar, -) - -import attr -from frozendict import frozendict - -from synapse.api.constants import EventTypes -from synapse.types import MutableStateMap, StateKey, StateMap - -if TYPE_CHECKING: - from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad - - -logger = logging.getLogger(__name__) - -# Used for generic functions below -T = TypeVar("T") - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class StateFilter: - """A filter used when querying for state. - - Attributes: - types: Map from type to set of state keys (or None). This specifies - which state_keys for the given type to fetch from the DB. If None - then all events with that type are fetched. If the set is empty - then no events with that type are fetched. - include_others: Whether to fetch events with types that do not - appear in `types`. - """ - - types: "frozendict[str, Optional[FrozenSet[str]]]" - include_others: bool = False - - def __attrs_post_init__(self) -> None: - # If `include_others` is set we canonicalise the filter by removing - # wildcards from the types dictionary - if self.include_others: - # this is needed to work around the fact that StateFilter is frozen - object.__setattr__( - self, - "types", - frozendict({k: v for k, v in self.types.items() if v is not None}), - ) - - @staticmethod - def all() -> "StateFilter": - """Returns a filter that fetches everything. - - Returns: - The state filter. - """ - return _ALL_STATE_FILTER - - @staticmethod - def none() -> "StateFilter": - """Returns a filter that fetches nothing. - - Returns: - The new state filter. - """ - return _NONE_STATE_FILTER - - @staticmethod - def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": - """Creates a filter that only fetches the given types - - Args: - types: A list of type and state keys to fetch. A state_key of None - fetches everything for that type - - Returns: - The new state filter. - """ - type_dict: Dict[str, Optional[Set[str]]] = {} - for typ, s in types: - if typ in type_dict: - if type_dict[typ] is None: - continue - - if s is None: - type_dict[typ] = None - continue - - type_dict.setdefault(typ, set()).add(s) # type: ignore - - return StateFilter( - types=frozendict( - (k, frozenset(v) if v is not None else None) - for k, v in type_dict.items() - ) - ) - - @staticmethod - def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": - """Creates a filter that returns all non-member events, plus the member - events for the given users - - Args: - members: Set of user IDs - - Returns: - The new state filter - """ - return StateFilter( - types=frozendict({EventTypes.Member: frozenset(members)}), - include_others=True, - ) - - @staticmethod - def freeze( - types: Mapping[str, Optional[Collection[str]]], include_others: bool - ) -> "StateFilter": - """ - Returns a (frozen) StateFilter with the same contents as the parameters - specified here, which can be made of mutable types. - """ - types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} - for state_types, state_keys in types.items(): - if state_keys is not None: - types_with_frozen_values[state_types] = frozenset(state_keys) - else: - types_with_frozen_values[state_types] = None - - return StateFilter( - frozendict(types_with_frozen_values), include_others=include_others - ) - - def return_expanded(self) -> "StateFilter": - """Creates a new StateFilter where type wild cards have been removed - (except for memberships). The returned filter is a superset of the - current one, i.e. anything that passes the current filter will pass - the returned filter. - - This helps the caching as the DictionaryCache knows if it has *all* the - state, but does not know if it has all of the keys of a particular type, - which makes wildcard lookups expensive unless we have a complete cache. - Hence, if we are doing a wildcard lookup, populate the cache fully so - that we can do an efficient lookup next time. - - Note that since we have two caches, one for membership events and one for - other events, we can be a bit more clever than simply returning - `StateFilter.all()` if `has_wildcards()` is True. - - We return a StateFilter where: - 1. the list of membership events to return is the same - 2. if there is a wildcard that matches non-member events we - return all non-member events - - Returns: - The new state filter. - """ - - if self.is_full(): - # If we're going to return everything then there's nothing to do - return self - - if not self.has_wildcards(): - # If there are no wild cards, there's nothing to do - return self - - if EventTypes.Member in self.types: - get_all_members = self.types[EventTypes.Member] is None - else: - get_all_members = self.include_others - - has_non_member_wildcard = self.include_others or any( - state_keys is None - for t, state_keys in self.types.items() - if t != EventTypes.Member - ) - - if not has_non_member_wildcard: - # If there are no non-member wild cards we can just return ourselves - return self - - if get_all_members: - # We want to return everything. - return StateFilter.all() - elif EventTypes.Member in self.types: - # We want to return all non-members, but only particular - # memberships - return StateFilter( - types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), - include_others=True, - ) - else: - # We want to return all non-members - return _ALL_NON_MEMBER_STATE_FILTER - - def make_sql_filter_clause(self) -> Tuple[str, List[str]]: - """Converts the filter to an SQL clause. - - For example: - - f = StateFilter.from_types([("m.room.create", "")]) - clause, args = f.make_sql_filter_clause() - clause == "(type = ? AND state_key = ?)" - args == ['m.room.create', ''] - - - Returns: - The SQL string (may be empty) and arguments. An empty SQL string is - returned when the filter matches everything (i.e. is "full"). - """ - - where_clause = "" - where_args: List[str] = [] - - if self.is_full(): - return where_clause, where_args - - if not self.include_others and not self.types: - # i.e. this is an empty filter, so we need to return a clause that - # will match nothing - return "1 = 2", [] - - # First we build up a lost of clauses for each type/state_key combo - clauses = [] - for etype, state_keys in self.types.items(): - if state_keys is None: - clauses.append("(type = ?)") - where_args.append(etype) - continue - - for state_key in state_keys: - clauses.append("(type = ? AND state_key = ?)") - where_args.extend((etype, state_key)) - - # This will match anything that appears in `self.types` - where_clause = " OR ".join(clauses) - - # If we want to include stuff that's not in the types dict then we add - # a `OR type NOT IN (...)` clause to the end. - if self.include_others: - if where_clause: - where_clause += " OR " - - where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) - where_args.extend(self.types) - - return where_clause, where_args - - def max_entries_returned(self) -> Optional[int]: - """Returns the maximum number of entries this filter will return if - known, otherwise returns None. - - For example a simple state filter asking for `("m.room.create", "")` - will return 1, whereas the default state filter will return None. - - This is used to bail out early if the right number of entries have been - fetched. - """ - if self.has_wildcards(): - return None - - return len(self.concrete_types()) - - def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: - """Returns the state filtered with by this StateFilter. - - Args: - state: The state map to filter - - Returns: - The filtered state map. - This is a copy, so it's safe to mutate. - """ - if self.is_full(): - return dict(state_dict) - - filtered_state = {} - for k, v in state_dict.items(): - typ, state_key = k - if typ in self.types: - state_keys = self.types[typ] - if state_keys is None or state_key in state_keys: - filtered_state[k] = v - elif self.include_others: - filtered_state[k] = v - - return filtered_state - - def is_full(self) -> bool: - """Whether this filter fetches everything or not - - Returns: - True if the filter fetches everything. - """ - return self.include_others and not self.types - - def has_wildcards(self) -> bool: - """Whether the filter includes wildcards or is attempting to fetch - specific state. - - Returns: - True if the filter includes wildcards. - """ - - return self.include_others or any( - state_keys is None for state_keys in self.types.values() - ) - - def concrete_types(self) -> List[Tuple[str, str]]: - """Returns a list of concrete type/state_keys (i.e. not None) that - will be fetched. This will be a complete list if `has_wildcards` - returns False, but otherwise will be a subset (or even empty). - - Returns: - A list of type/state_keys tuples. - """ - return [ - (t, s) - for t, state_keys in self.types.items() - if state_keys is not None - for s in state_keys - ] - - def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: - """Return the filter split into two: one which assumes it's exclusively - matching against member state, and one which assumes it's matching - against non member state. - - This is useful due to the returned filters giving correct results for - `is_full()`, `has_wildcards()`, etc, when operating against maps that - either exclusively contain member events or only contain non-member - events. (Which is the case when dealing with the member vs non-member - state caches). - - Returns: - The member and non member filters - """ - - if EventTypes.Member in self.types: - state_keys = self.types[EventTypes.Member] - if state_keys is None: - member_filter = StateFilter.all() - else: - member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) - elif self.include_others: - member_filter = StateFilter.all() - else: - member_filter = StateFilter.none() - - non_member_filter = StateFilter( - types=frozendict( - {k: v for k, v in self.types.items() if k != EventTypes.Member} - ), - include_others=self.include_others, - ) - - return member_filter, non_member_filter - - def _decompose_into_four_parts( - self, - ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: - """ - Decomposes this state filter into 4 constituent parts, which can be - thought of as this: - all? - minus_wildcards + plus_wildcards + plus_state_keys - - where - * all represents ALL state - * minus_wildcards represents entire state types to remove - * plus_wildcards represents entire state types to add - * plus_state_keys represents individual state keys to add - - See `recompose_from_four_parts` for the other direction of this - correspondence. - """ - is_all = self.include_others - excluded_types: Set[str] = {t for t in self.types if is_all} - wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} - concrete_keys: Set[StateKey] = set(self.concrete_types()) - - return (is_all, excluded_types), (wildcard_types, concrete_keys) - - @staticmethod - def _recompose_from_four_parts( - all_part: bool, - minus_wildcards: Set[str], - plus_wildcards: Set[str], - plus_state_keys: Set[StateKey], - ) -> "StateFilter": - """ - Recomposes a state filter from 4 parts. - - See `decompose_into_four_parts` (the other direction of this - correspondence) for descriptions on each of the parts. - """ - - # {state type -> set of state keys OR None for wildcard} - # (The same structure as that of a StateFilter.) - new_types: Dict[str, Optional[Set[str]]] = {} - - # if we start with all, insert the excluded statetypes as empty sets - # to prevent them from being included - if all_part: - new_types.update({state_type: set() for state_type in minus_wildcards}) - - # insert the plus wildcards - new_types.update({state_type: None for state_type in plus_wildcards}) - - # insert the specific state keys - for state_type, state_key in plus_state_keys: - if state_type in new_types: - entry = new_types[state_type] - if entry is not None: - entry.add(state_key) - elif not all_part: - # don't insert if the entire type is already included by - # include_others as this would actually shrink the state allowed - # by this filter. - new_types[state_type] = {state_key} - - return StateFilter.freeze(new_types, include_others=all_part) - - def approx_difference(self, other: "StateFilter") -> "StateFilter": - """ - Returns a state filter which represents `self - other`. - - This is useful for determining what state remains to be pulled out of the - database if we want the state included by `self` but already have the state - included by `other`. - - The returned state filter - - MUST include all state events that are included by this filter (`self`) - unless they are included by `other`; - - MUST NOT include state events not included by this filter (`self`); and - - MAY be an over-approximation: the returned state filter - MAY additionally include some state events from `other`. - - This implementation attempts to return the narrowest such state filter. - In the case that `self` contains wildcards for state types where - `other` contains specific state keys, an approximation must be made: - the returned state filter keeps the wildcard, as state filters are not - able to express 'all state keys except some given examples'. - e.g. - StateFilter(m.room.member -> None (wildcard)) - minus - StateFilter(m.room.member -> {'@wombat:example.org'}) - is approximated as - StateFilter(m.room.member -> None (wildcard)) - """ - - # We first transform self and other into an alternative representation: - # - whether or not they include all events to begin with ('all') - # - if so, which event types are excluded? ('excludes') - # - which entire event types to include ('wildcards') - # - which concrete state keys to include ('concrete state keys') - (self_all, self_excludes), ( - self_wildcards, - self_concrete_keys, - ) = self._decompose_into_four_parts() - (other_all, other_excludes), ( - other_wildcards, - other_concrete_keys, - ) = other._decompose_into_four_parts() - - # Start with an estimate of the difference based on self - new_all = self_all - # Wildcards from the other can be added to the exclusion filter - new_excludes = self_excludes | other_wildcards - # We remove wildcards that appeared as wildcards in the other - new_wildcards = self_wildcards - other_wildcards - # We filter out the concrete state keys that appear in the other - # as wildcards or concrete state keys. - new_concrete_keys = { - (state_type, state_key) - for (state_type, state_key) in self_concrete_keys - if state_type not in other_wildcards - } - other_concrete_keys - - if other_all: - if self_all: - # If self starts with all, then we add as wildcards any - # types which appear in the other's exclusion filter (but - # aren't in the self exclusion filter). This is as the other - # filter will return everything BUT the types in its exclusion, so - # we need to add those excluded types that also match the self - # filter as wildcard types in the new filter. - new_wildcards |= other_excludes.difference(self_excludes) - - # If other is an `include_others` then the difference isn't. - new_all = False - # (We have no need for excludes when we don't start with all, as there - # is nothing to exclude.) - new_excludes = set() - - # We also filter out all state types that aren't in the exclusion - # list of the other. - new_wildcards &= other_excludes - new_concrete_keys = { - (state_type, state_key) - for (state_type, state_key) in new_concrete_keys - if state_type in other_excludes - } - - # Transform our newly-constructed state filter from the alternative - # representation back into the normal StateFilter representation. - return StateFilter._recompose_from_four_parts( - new_all, new_excludes, new_wildcards, new_concrete_keys - ) - - def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: - """Check if we need to wait for full state to complete to calculate this state - - If we have a state filter which is completely satisfied even with partial - state, then we don't need to await_full_state before we can return it. - - Args: - is_mine_id: a callable which confirms if a given state_key matches a mxid - of a local user - """ - # if we haven't requested membership events, then it depends on the value of - # 'include_others' - if EventTypes.Member not in self.types: - return self.include_others - - # if we're looking for *all* membership events, then we have to wait - member_state_keys = self.types[EventTypes.Member] - if member_state_keys is None: - return True - - # otherwise, consider whose membership we are looking for. If it's entirely - # local users, then we don't need to wait. - for state_key in member_state_keys: - if not is_mine_id(state_key): - # remote user - return True - - # local users only - return False - - -_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) -_ALL_NON_MEMBER_STATE_FILTER = StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True -) -_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/types.py b/synapse/types.py deleted file mode 100644 index f2d436ddc3..0000000000 --- a/synapse/types.py +++ /dev/null @@ -1,928 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import abc -import re -import string -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - List, - Mapping, - Match, - MutableMapping, - NoReturn, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) - -import attr -from frozendict import frozendict -from signedjson.key import decode_verify_key_bytes -from signedjson.types import VerifyKey -from typing_extensions import Final, TypedDict -from unpaddedbase64 import decode_base64 -from zope.interface import Interface - -from twisted.internet.defer import CancelledError -from twisted.internet.interfaces import ( - IReactorCore, - IReactorPluggableNameResolver, - IReactorSSL, - IReactorTCP, - IReactorThreads, - IReactorTime, -) - -from synapse.api.errors import Codes, SynapseError -from synapse.util.cancellation import cancellable -from synapse.util.stringutils import parse_and_validate_server_name - -if TYPE_CHECKING: - from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore, PurgeEventsStore - from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore - -# Define a state map type from type/state_key to T (usually an event ID or -# event) -T = TypeVar("T") -StateKey = Tuple[str, str] -StateMap = Mapping[StateKey, T] -MutableStateMap = MutableMapping[StateKey, T] - -# JSON types. These could be made stronger, but will do for now. -# A JSON-serialisable dict. -JsonDict = Dict[str, Any] -# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. -# Useful when you have a TypedDict which isn't going to be mutated and you don't want -# to cast to JsonDict everywhere. -JsonMapping = Mapping[str, Any] -# A JSON-serialisable object. -JsonSerializable = object - - -# Note that this seems to require inheriting *directly* from Interface in order -# for mypy-zope to realize it is an interface. -class ISynapseReactor( - IReactorTCP, - IReactorSSL, - IReactorPluggableNameResolver, - IReactorTime, - IReactorCore, - IReactorThreads, - Interface, -): - """The interfaces necessary for Synapse to function.""" - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class Requester: - """ - Represents the user making a request - - Attributes: - user: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest: True if the user making this request is a guest user - shadow_banned: True if the user making this request has been shadow-banned. - device_id: device_id which was set at authentication time - app_service: the AS requesting on behalf of the user - authenticated_entity: The entity that authenticated when making the request. - This is different to the user_id when an admin user or the server is - "puppeting" the user. - """ - - user: "UserID" - access_token_id: Optional[int] - is_guest: bool - shadow_banned: bool - device_id: Optional[str] - app_service: Optional["ApplicationService"] - authenticated_entity: str - - def serialize(self) -> Dict[str, Any]: - """Converts self to a type that can be serialized as JSON, and then - deserialized by `deserialize` - - Returns: - dict - """ - return { - "user_id": self.user.to_string(), - "access_token_id": self.access_token_id, - "is_guest": self.is_guest, - "shadow_banned": self.shadow_banned, - "device_id": self.device_id, - "app_server_id": self.app_service.id if self.app_service else None, - "authenticated_entity": self.authenticated_entity, - } - - @staticmethod - def deserialize( - store: "ApplicationServiceWorkerStore", input: Dict[str, Any] - ) -> "Requester": - """Converts a dict that was produced by `serialize` back into a - Requester. - - Args: - store: Used to convert AS ID to AS object - input: A dict produced by `serialize` - - Returns: - Requester - """ - appservice = None - if input["app_server_id"]: - appservice = store.get_app_service_by_id(input["app_server_id"]) - - return Requester( - user=UserID.from_string(input["user_id"]), - access_token_id=input["access_token_id"], - is_guest=input["is_guest"], - shadow_banned=input["shadow_banned"], - device_id=input["device_id"], - app_service=appservice, - authenticated_entity=input["authenticated_entity"], - ) - - -def create_requester( - user_id: Union[str, "UserID"], - access_token_id: Optional[int] = None, - is_guest: bool = False, - shadow_banned: bool = False, - device_id: Optional[str] = None, - app_service: Optional["ApplicationService"] = None, - authenticated_entity: Optional[str] = None, -) -> Requester: - """ - Create a new ``Requester`` object - - Args: - user_id: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest: True if the user making this request is a guest user - shadow_banned: True if the user making this request is shadow-banned. - device_id: device_id which was set at authentication time - app_service: the AS requesting on behalf of the user - authenticated_entity: The entity that authenticated when making the request. - This is different to the user_id when an admin user or the server is - "puppeting" the user. - - Returns: - Requester - """ - if not isinstance(user_id, UserID): - user_id = UserID.from_string(user_id) - - if authenticated_entity is None: - authenticated_entity = user_id.to_string() - - return Requester( - user_id, - access_token_id, - is_guest, - shadow_banned, - device_id, - app_service, - authenticated_entity, - ) - - -def get_domain_from_id(string: str) -> str: - idx = string.find(":") - if idx == -1: - raise SynapseError(400, "Invalid ID: %r" % (string,)) - return string[idx + 1 :] - - -def get_localpart_from_id(string: str) -> str: - idx = string.find(":") - if idx == -1: - raise SynapseError(400, "Invalid ID: %r" % (string,)) - return string[1:idx] - - -DS = TypeVar("DS", bound="DomainSpecificString") - - -@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) -class DomainSpecificString(metaclass=abc.ABCMeta): - """Common base class among ID/name strings that have a local part and a - domain name, prefixed with a sigil. - - Has the fields: - - 'localpart' : The local part of the name (without the leading sigil) - 'domain' : The domain part of the name - """ - - SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore - - localpart: str - domain: str - - # Because this is a frozen class, it is deeply immutable. - def __copy__(self: DS) -> DS: - return self - - def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: - return self - - @classmethod - def from_string(cls: Type[DS], s: str) -> DS: - """Parse the string given by 's' into a structure object.""" - if len(s) < 1 or s[0:1] != cls.SIGIL: - raise SynapseError( - 400, - "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), - Codes.INVALID_PARAM, - ) - - parts = s[1:].split(":", 1) - if len(parts) != 2: - raise SynapseError( - 400, - "Expected %s of the form '%slocalname:domain'" - % (cls.__name__, cls.SIGIL), - Codes.INVALID_PARAM, - ) - - domain = parts[1] - # This code will need changing if we want to support multiple domain - # names on one HS - return cls(localpart=parts[0], domain=domain) - - def to_string(self) -> str: - """Return a string encoding the fields of the structure object.""" - return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) - - @classmethod - def is_valid(cls: Type[DS], s: str) -> bool: - """Parses the input string and attempts to ensure it is valid.""" - # TODO: this does not reject an empty localpart or an overly-long string. - # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar - try: - obj = cls.from_string(s) - # Apply additional validation to the domain. This is only done - # during is_valid (and not part of from_string) since it is - # possible for invalid data to exist in room-state, etc. - parse_and_validate_server_name(obj.domain) - return True - except Exception: - return False - - __repr__ = to_string - - -@attr.s(slots=True, frozen=True, repr=False) -class UserID(DomainSpecificString): - """Structure representing a user ID.""" - - SIGIL = "@" - - -@attr.s(slots=True, frozen=True, repr=False) -class RoomAlias(DomainSpecificString): - """Structure representing a room name.""" - - SIGIL = "#" - - -@attr.s(slots=True, frozen=True, repr=False) -class RoomID(DomainSpecificString): - """Structure representing a room id.""" - - SIGIL = "!" - - -@attr.s(slots=True, frozen=True, repr=False) -class EventID(DomainSpecificString): - """Structure representing an event id.""" - - SIGIL = "$" - - -mxid_localpart_allowed_characters = set( - "_-./=" + string.ascii_lowercase + string.digits -) - - -def contains_invalid_mxid_characters(localpart: str) -> bool: - """Check for characters not allowed in an mxid or groupid localpart - - Args: - localpart: the localpart to be checked - - Returns: - True if there are any naughty characters - """ - return any(c not in mxid_localpart_allowed_characters for c in localpart) - - -UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") - -# the following is a pattern which matches '=', and bytes which are not allowed in a mxid -# localpart. -# -# It works by: -# * building a string containing the allowed characters (excluding '=') -# * escaping every special character with a backslash (to stop '-' being interpreted as a -# range operator) -# * wrapping it in a '[^...]' regex -# * converting the whole lot to a 'bytes' sequence, so that we can use it to match -# bytes rather than strings -# -NON_MXID_CHARACTER_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( - "ascii" - ) -) - - -def map_username_to_mxid_localpart( - username: Union[str, bytes], case_sensitive: bool = False -) -> str: - """Map a username onto a string suitable for a MXID - - This follows the algorithm laid out at - https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. - - Args: - username: username to be mapped - case_sensitive: true if TEST and test should be mapped - onto different mxids - - Returns: - string suitable for a mxid localpart - """ - if not isinstance(username, bytes): - username = username.encode("utf-8") - - # first we sort out upper-case characters - if case_sensitive: - - def f1(m: Match[bytes]) -> bytes: - return b"_" + m.group().lower() - - username = UPPER_CASE_PATTERN.sub(f1, username) - else: - username = username.lower() - - # then we sort out non-ascii characters by converting to the hex equivalent. - def f2(m: Match[bytes]) -> bytes: - return b"=%02x" % (m.group()[0],) - - username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) - - # we also do the =-escaping to mxids starting with an underscore. - username = re.sub(b"^_", b"=5f", username) - - # we should now only have ascii bytes left, so can decode back to a string. - return username.decode("ascii") - - -@attr.s(frozen=True, slots=True, order=False) -class RoomStreamToken: - """Tokens are positions between events. The token "s1" comes after event 1. - - s0 s1 - | | - [0] ▼ [1] ▼ [2] - - Tokens can either be a point in the live event stream or a cursor going - through historic events. - - When traversing the live event stream, events are ordered by - `stream_ordering` (when they arrived at the homeserver). - - When traversing historic events, events are first ordered by their `depth` - (`topological_ordering` in the event graph) and tie-broken by - `stream_ordering` (when the event arrived at the homeserver). - - If you're looking for more info about what a token with all of the - underscores means, ex. - `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring - for `StreamToken` below. - - --- - - Live tokens start with an "s" followed by the `stream_ordering` of the event - that comes before the position of the token. Said another way: - `stream_ordering` uniquely identifies a persisted event. The live token - means "the position just after the event identified by `stream_ordering`". - An example token is: - - s2633508 - - --- - - Historic tokens start with a "t" followed by the `depth` - (`topological_ordering` in the event graph) of the event that comes before - the position of the token, followed by "-", followed by the - `stream_ordering` of the event that comes before the position of the token. - An example token is: - - t426-2633508 - - --- - - There is also a third mode for live tokens where the token starts with "m", - which is sometimes used when using sharded event persisters. In this case - the events stream is considered to be a set of streams (one for each writer) - and the token encodes the vector clock of positions of each writer in their - respective streams. - - The format of the token in such case is an initial integer min position, - followed by the mapping of instance ID to position separated by '.' and '~': - - m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... - - The `min_pos` corresponds to the minimum position all writers have persisted - up to, and then only writers that are ahead of that position need to be - encoded. An example token is: - - m56~2.58~3.59 - - Which corresponds to a set of three (or more writers) where instances 2 and - 3 (these are instance IDs that can be looked up in the DB to fetch the more - commonly used instance names) are at positions 58 and 59 respectively, and - all other instances are at position 56. - - Note: The `RoomStreamToken` cannot have both a topological part and an - instance map. - - --- - - For caching purposes, `RoomStreamToken`s and by extension, all their - attributes, must be hashable. - """ - - topological: Optional[int] = attr.ib( - validator=attr.validators.optional(attr.validators.instance_of(int)), - ) - stream: int = attr.ib(validator=attr.validators.instance_of(int)) - - instance_map: "frozendict[str, int]" = attr.ib( - factory=frozendict, - validator=attr.validators.deep_mapping( - key_validator=attr.validators.instance_of(str), - value_validator=attr.validators.instance_of(int), - mapping_validator=attr.validators.instance_of(frozendict), - ), - ) - - def __attrs_post_init__(self) -> None: - """Validates that both `topological` and `instance_map` aren't set.""" - - if self.instance_map and self.topological: - raise ValueError( - "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." - ) - - @classmethod - async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - if string[0] == "t": - parts = string[1:].split("-", 1) - return cls(topological=int(parts[0]), stream=int(parts[1])) - if string[0] == "m": - parts = string[1:].split("~") - stream = int(parts[0]) - - instance_map = {} - for part in parts[1:]: - key, value = part.split(".") - instance_id = int(key) - pos = int(value) - - instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] - instance_map[instance_name] = pos - - return cls( - topological=None, - stream=stream, - instance_map=frozendict(instance_map), - ) - except CancelledError: - raise - except Exception: - pass - raise SynapseError(400, "Invalid room stream token %r" % (string,)) - - @classmethod - def parse_stream_token(cls, string: str) -> "RoomStreamToken": - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - except Exception: - pass - raise SynapseError(400, "Invalid room stream token %r" % (string,)) - - def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": - """Return a new token such that if an event is after both this token and - the other token, then its after the returned token too. - """ - - if self.topological or other.topological: - raise Exception("Can't advance topological tokens") - - max_stream = max(self.stream, other.stream) - - instance_map = { - instance: max( - self.instance_map.get(instance, self.stream), - other.instance_map.get(instance, other.stream), - ) - for instance in set(self.instance_map).union(other.instance_map) - } - - return RoomStreamToken(None, max_stream, frozendict(instance_map)) - - def as_historical_tuple(self) -> Tuple[int, int]: - """Returns a tuple of `(topological, stream)` for historical tokens. - - Raises if not an historical token (i.e. doesn't have a topological part). - """ - if self.topological is None: - raise Exception( - "Cannot call `RoomStreamToken.as_historical_tuple` on live token" - ) - - return self.topological, self.stream - - def get_stream_pos_for_instance(self, instance_name: str) -> int: - """Get the stream position that the given writer was at at this token. - - This only makes sense for "live" tokens that may have a vector clock - component, and so asserts that this is a "live" token. - """ - assert self.topological is None - - # If we don't have an entry for the instance we can assume that it was - # at `self.stream`. - return self.instance_map.get(instance_name, self.stream) - - def get_max_stream_pos(self) -> int: - """Get the maximum stream position referenced in this token. - - The corresponding "min" position is, by definition just `self.stream`. - - This is used to handle tokens that have non-empty `instance_map`, and so - reference stream positions after the `self.stream` position. - """ - return max(self.instance_map.values(), default=self.stream) - - async def to_string(self, store: "DataStore") -> str: - if self.topological is not None: - return "t%d-%d" % (self.topological, self.stream) - elif self.instance_map: - entries = [] - for name, pos in self.instance_map.items(): - instance_id = await store.get_id_for_instance(name) - entries.append(f"{instance_id}.{pos}") - - encoded_map = "~".join(entries) - return f"m{self.stream}~{encoded_map}" - else: - return "s%d" % (self.stream,) - - -class StreamKeyType: - """Known stream types. - - A stream is a list of entities ordered by an incrementing "stream token". - """ - - ROOM: Final = "room_key" - PRESENCE: Final = "presence_key" - TYPING: Final = "typing_key" - RECEIPT: Final = "receipt_key" - ACCOUNT_DATA: Final = "account_data_key" - PUSH_RULES: Final = "push_rules_key" - TO_DEVICE: Final = "to_device_key" - DEVICE_LIST: Final = "device_list_key" - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class StreamToken: - """A collection of keys joined together by underscores in the following - order and which represent the position in their respective streams. - - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` - 1. `room_key`: `s2633508` which is a `RoomStreamToken` - - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - - See the docstring for `RoomStreamToken` for more details. - 2. `presence_key`: `17` - 3. `typing_key`: `338` - 4. `receipt_key`: `6732159` - 5. `account_data_key`: `1082514` - 6. `push_rules_key`: `541479` - 7. `to_device_key`: `274711` - 8. `device_list_key`: `265584` - 9. `groups_key`: `1` (note that this key is now unused) - - You can see how many of these keys correspond to the various - fields in a "/sync" response: - ```json - { - "next_batch": "s12_4_0_1_1_1_1_4_1", - "presence": { - "events": [] - }, - "device_lists": { - "changed": [] - }, - "rooms": { - "join": { - "!QrZlfIDQLNLdZHqTnt:hs1": { - "timeline": { - "events": [], - "prev_batch": "s10_4_0_1_1_1_1_4_1", - "limited": false - }, - "state": { - "events": [] - }, - "account_data": { - "events": [] - }, - "ephemeral": { - "events": [] - } - } - } - } - } - ``` - - --- - - For caching purposes, `StreamToken`s and by extension, all their attributes, - must be hashable. - """ - - room_key: RoomStreamToken = attr.ib( - validator=attr.validators.instance_of(RoomStreamToken) - ) - presence_key: int - typing_key: int - receipt_key: int - account_data_key: int - push_rules_key: int - to_device_key: int - device_list_key: int - # Note that the groups key is no longer used and may have bogus values. - groups_key: int - - _SEPARATOR = "_" - START: ClassVar["StreamToken"] - - @classmethod - @cancellable - async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": - """ - Creates a RoomStreamToken from its textual representation. - """ - try: - keys = string.split(cls._SEPARATOR) - while len(keys) < len(attr.fields(cls)): - # i.e. old token from before receipt_key - keys.append("0") - return cls( - await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) - ) - except CancelledError: - raise - except Exception: - raise SynapseError(400, "Invalid stream token") - - async def to_string(self, store: "DataStore") -> str: - return self._SEPARATOR.join( - [ - await self.room_key.to_string(store), - str(self.presence_key), - str(self.typing_key), - str(self.receipt_key), - str(self.account_data_key), - str(self.push_rules_key), - str(self.to_device_key), - str(self.device_list_key), - # Note that the groups key is no longer used, but it is still - # serialized so that there will not be confusion in the future - # if additional tokens are added. - str(self.groups_key), - ] - ) - - @property - def room_stream_id(self) -> int: - return self.room_key.stream - - def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": - """Advance the given key in the token to a new value if and only if the - new value is after the old value. - - :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. - """ - if key == StreamKeyType.ROOM: - new_token = self.copy_and_replace( - StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) - ) - return new_token - - new_token = self.copy_and_replace(key, new_value) - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) - - if old_id < new_id: - return new_token - else: - return self - - def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": - return attr.evolve(self, **{key: new_value}) - - -StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class PersistedEventPosition: - """Position of a newly persisted event with instance that persisted it. - - This can be used to test whether the event is persisted before or after a - RoomStreamToken. - """ - - instance_name: str - stream: int - - def persisted_after(self, token: RoomStreamToken) -> bool: - return token.get_stream_pos_for_instance(self.instance_name) < self.stream - - def to_room_stream_token(self) -> RoomStreamToken: - """Converts the position to a room stream token such that events - persisted in the same room after this position will be after the - returned `RoomStreamToken`. - - Note: no guarantees are made about ordering w.r.t. events in other - rooms. - """ - # Doing the naive thing satisfies the desired properties described in - # the docstring. - return RoomStreamToken(None, self.stream) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ThirdPartyInstanceID: - appservice_id: Optional[str] - network_id: Optional[str] - - # Deny iteration because it will bite you if you try to create a singleton - # set by: - # users = set(user) - def __iter__(self) -> NoReturn: - raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) - - # Because this class is a frozen class, it is deeply immutable. - def __copy__(self) -> "ThirdPartyInstanceID": - return self - - def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": - return self - - @classmethod - def from_string(cls, s: str) -> "ThirdPartyInstanceID": - bits = s.split("|", 2) - if len(bits) != 2: - raise SynapseError(400, "Invalid ID %r" % (s,)) - - return cls(appservice_id=bits[0], network_id=bits[1]) - - def to_string(self) -> str: - return "%s|%s" % (self.appservice_id, self.network_id) - - __str__ = to_string - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ReadReceipt: - """Information about a read-receipt""" - - room_id: str - receipt_type: str - user_id: str - event_ids: List[str] - thread_id: Optional[str] - data: JsonDict - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceListUpdates: - """ - An object containing a diff of information regarding other users' device lists, intended for - a recipient to carry out device list tracking. - - Attributes: - changed: A set of users whose device lists have changed recently. - left: A set of users who the recipient no longer needs to track the device lists of. - Typically when those users no longer share any end-to-end encryption enabled rooms. - """ - - # We need to use a factory here, otherwise `set` is not evaluated at - # object instantiation, but instead at class definition instantiation. - # The latter happening only once, thus always giving you the same sets - # across multiple DeviceListUpdates instances. - # Also see: don't define mutable default arguments. - changed: Set[str] = attr.ib(factory=set) - left: Set[str] = attr.ib(factory=set) - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - -def get_verify_key_from_cross_signing_key( - key_info: Mapping[str, Any] -) -> Tuple[str, VerifyKey]: - """Get the key ID and signedjson verify key from a cross-signing key dict - - Args: - key_info: a cross-signing key dict, which must have a "keys" - property that has exactly one item in it - - Returns: - the key ID and verify key for the cross-signing key - """ - # make sure that a `keys` field is provided - if "keys" not in key_info: - raise ValueError("Invalid key") - keys = key_info["keys"] - # and that it contains exactly one key - if len(keys) == 1: - key_id, key_data = next(iter(keys.items())) - return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) - else: - raise ValueError("Invalid key") - - -@attr.s(auto_attribs=True, frozen=True, slots=True) -class UserInfo: - """Holds information about a user. Result of get_userinfo_by_id. - - Attributes: - user_id: ID of the user. - appservice_id: Application service ID that created this user. - consent_server_notice_sent: Version of policy documents the user has been sent. - consent_version: Version of policy documents the user has consented to. - creation_ts: Creation timestamp of the user. - is_admin: True if the user is an admin. - is_deactivated: True if the user has been deactivated. - is_guest: True if the user is a guest user. - is_shadow_banned: True if the user has been shadow-banned. - user_type: User type (None for normal user, 'support' and 'bot' other options). - """ - - user_id: UserID - appservice_id: Optional[int] - consent_server_notice_sent: Optional[str] - consent_version: Optional[str] - user_type: Optional[str] - creation_ts: int - is_admin: bool - is_deactivated: bool - is_guest: bool - is_shadow_banned: bool - - -class UserProfile(TypedDict): - user_id: str - display_name: Optional[str] - avatar_url: Optional[str] - - -@attr.s(auto_attribs=True, frozen=True, slots=True) -class RetentionPolicy: - min_lifetime: Optional[int] = None - max_lifetime: Optional[int] = None diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py new file mode 100644 index 0000000000..f2d436ddc3 --- /dev/null +++ b/synapse/types/__init__.py @@ -0,0 +1,928 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import re +import string +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Mapping, + Match, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +import attr +from frozendict import frozendict +from signedjson.key import decode_verify_key_bytes +from signedjson.types import VerifyKey +from typing_extensions import Final, TypedDict +from unpaddedbase64 import decode_base64 +from zope.interface import Interface + +from twisted.internet.defer import CancelledError +from twisted.internet.interfaces import ( + IReactorCore, + IReactorPluggableNameResolver, + IReactorSSL, + IReactorTCP, + IReactorThreads, + IReactorTime, +) + +from synapse.api.errors import Codes, SynapseError +from synapse.util.cancellation import cancellable +from synapse.util.stringutils import parse_and_validate_server_name + +if TYPE_CHECKING: + from synapse.appservice.api import ApplicationService + from synapse.storage.databases.main import DataStore, PurgeEventsStore + from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore + +# Define a state map type from type/state_key to T (usually an event ID or +# event) +T = TypeVar("T") +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] + +# JSON types. These could be made stronger, but will do for now. +# A JSON-serialisable dict. +JsonDict = Dict[str, Any] +# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. +# Useful when you have a TypedDict which isn't going to be mutated and you don't want +# to cast to JsonDict everywhere. +JsonMapping = Mapping[str, Any] +# A JSON-serialisable object. +JsonSerializable = object + + +# Note that this seems to require inheriting *directly* from Interface in order +# for mypy-zope to realize it is an interface. +class ISynapseReactor( + IReactorTCP, + IReactorSSL, + IReactorPluggableNameResolver, + IReactorTime, + IReactorCore, + IReactorThreads, + Interface, +): + """The interfaces necessary for Synapse to function.""" + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class Requester: + """ + Represents the user making a request + + Attributes: + user: id of the user making the request + access_token_id: *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request has been shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. + """ + + user: "UserID" + access_token_id: Optional[int] + is_guest: bool + shadow_banned: bool + device_id: Optional[str] + app_service: Optional["ApplicationService"] + authenticated_entity: str + + def serialize(self) -> Dict[str, Any]: + """Converts self to a type that can be serialized as JSON, and then + deserialized by `deserialize` + + Returns: + dict + """ + return { + "user_id": self.user.to_string(), + "access_token_id": self.access_token_id, + "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, + "device_id": self.device_id, + "app_server_id": self.app_service.id if self.app_service else None, + "authenticated_entity": self.authenticated_entity, + } + + @staticmethod + def deserialize( + store: "ApplicationServiceWorkerStore", input: Dict[str, Any] + ) -> "Requester": + """Converts a dict that was produced by `serialize` back into a + Requester. + + Args: + store: Used to convert AS ID to AS object + input: A dict produced by `serialize` + + Returns: + Requester + """ + appservice = None + if input["app_server_id"]: + appservice = store.get_app_service_by_id(input["app_server_id"]) + + return Requester( + user=UserID.from_string(input["user_id"]), + access_token_id=input["access_token_id"], + is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], + device_id=input["device_id"], + app_service=appservice, + authenticated_entity=input["authenticated_entity"], + ) + + +def create_requester( + user_id: Union[str, "UserID"], + access_token_id: Optional[int] = None, + is_guest: bool = False, + shadow_banned: bool = False, + device_id: Optional[str] = None, + app_service: Optional["ApplicationService"] = None, + authenticated_entity: Optional[str] = None, +) -> Requester: + """ + Create a new ``Requester`` object + + Args: + user_id: id of the user making the request + access_token_id: *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request is shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. + + Returns: + Requester + """ + if not isinstance(user_id, UserID): + user_id = UserID.from_string(user_id) + + if authenticated_entity is None: + authenticated_entity = user_id.to_string() + + return Requester( + user_id, + access_token_id, + is_guest, + shadow_banned, + device_id, + app_service, + authenticated_entity, + ) + + +def get_domain_from_id(string: str) -> str: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[idx + 1 :] + + +def get_localpart_from_id(string: str) -> str: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[1:idx] + + +DS = TypeVar("DS", bound="DomainSpecificString") + + +@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) +class DomainSpecificString(metaclass=abc.ABCMeta): + """Common base class among ID/name strings that have a local part and a + domain name, prefixed with a sigil. + + Has the fields: + + 'localpart' : The local part of the name (without the leading sigil) + 'domain' : The domain part of the name + """ + + SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore + + localpart: str + domain: str + + # Because this is a frozen class, it is deeply immutable. + def __copy__(self: DS) -> DS: + return self + + def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: + return self + + @classmethod + def from_string(cls: Type[DS], s: str) -> DS: + """Parse the string given by 's' into a structure object.""" + if len(s) < 1 or s[0:1] != cls.SIGIL: + raise SynapseError( + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + parts = s[1:].split(":", 1) + if len(parts) != 2: + raise SynapseError( + 400, + "Expected %s of the form '%slocalname:domain'" + % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + domain = parts[1] + # This code will need changing if we want to support multiple domain + # names on one HS + return cls(localpart=parts[0], domain=domain) + + def to_string(self) -> str: + """Return a string encoding the fields of the structure object.""" + return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) + + @classmethod + def is_valid(cls: Type[DS], s: str) -> bool: + """Parses the input string and attempts to ensure it is valid.""" + # TODO: this does not reject an empty localpart or an overly-long string. + # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar + try: + obj = cls.from_string(s) + # Apply additional validation to the domain. This is only done + # during is_valid (and not part of from_string) since it is + # possible for invalid data to exist in room-state, etc. + parse_and_validate_server_name(obj.domain) + return True + except Exception: + return False + + __repr__ = to_string + + +@attr.s(slots=True, frozen=True, repr=False) +class UserID(DomainSpecificString): + """Structure representing a user ID.""" + + SIGIL = "@" + + +@attr.s(slots=True, frozen=True, repr=False) +class RoomAlias(DomainSpecificString): + """Structure representing a room name.""" + + SIGIL = "#" + + +@attr.s(slots=True, frozen=True, repr=False) +class RoomID(DomainSpecificString): + """Structure representing a room id.""" + + SIGIL = "!" + + +@attr.s(slots=True, frozen=True, repr=False) +class EventID(DomainSpecificString): + """Structure representing an event id.""" + + SIGIL = "$" + + +mxid_localpart_allowed_characters = set( + "_-./=" + string.ascii_lowercase + string.digits +) + + +def contains_invalid_mxid_characters(localpart: str) -> bool: + """Check for characters not allowed in an mxid or groupid localpart + + Args: + localpart: the localpart to be checked + + Returns: + True if there are any naughty characters + """ + return any(c not in mxid_localpart_allowed_characters for c in localpart) + + +UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") + +# the following is a pattern which matches '=', and bytes which are not allowed in a mxid +# localpart. +# +# It works by: +# * building a string containing the allowed characters (excluding '=') +# * escaping every special character with a backslash (to stop '-' being interpreted as a +# range operator) +# * wrapping it in a '[^...]' regex +# * converting the whole lot to a 'bytes' sequence, so that we can use it to match +# bytes rather than strings +# +NON_MXID_CHARACTER_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( + "ascii" + ) +) + + +def map_username_to_mxid_localpart( + username: Union[str, bytes], case_sensitive: bool = False +) -> str: + """Map a username onto a string suitable for a MXID + + This follows the algorithm laid out at + https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. + + Args: + username: username to be mapped + case_sensitive: true if TEST and test should be mapped + onto different mxids + + Returns: + string suitable for a mxid localpart + """ + if not isinstance(username, bytes): + username = username.encode("utf-8") + + # first we sort out upper-case characters + if case_sensitive: + + def f1(m: Match[bytes]) -> bytes: + return b"_" + m.group().lower() + + username = UPPER_CASE_PATTERN.sub(f1, username) + else: + username = username.lower() + + # then we sort out non-ascii characters by converting to the hex equivalent. + def f2(m: Match[bytes]) -> bytes: + return b"=%02x" % (m.group()[0],) + + username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) + + # we also do the =-escaping to mxids starting with an underscore. + username = re.sub(b"^_", b"=5f", username) + + # we should now only have ascii bytes left, so can decode back to a string. + return username.decode("ascii") + + +@attr.s(frozen=True, slots=True, order=False) +class RoomStreamToken: + """Tokens are positions between events. The token "s1" comes after event 1. + + s0 s1 + | | + [0] ▼ [1] ▼ [2] + + Tokens can either be a point in the live event stream or a cursor going + through historic events. + + When traversing the live event stream, events are ordered by + `stream_ordering` (when they arrived at the homeserver). + + When traversing historic events, events are first ordered by their `depth` + (`topological_ordering` in the event graph) and tie-broken by + `stream_ordering` (when the event arrived at the homeserver). + + If you're looking for more info about what a token with all of the + underscores means, ex. + `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring + for `StreamToken` below. + + --- + + Live tokens start with an "s" followed by the `stream_ordering` of the event + that comes before the position of the token. Said another way: + `stream_ordering` uniquely identifies a persisted event. The live token + means "the position just after the event identified by `stream_ordering`". + An example token is: + + s2633508 + + --- + + Historic tokens start with a "t" followed by the `depth` + (`topological_ordering` in the event graph) of the event that comes before + the position of the token, followed by "-", followed by the + `stream_ordering` of the event that comes before the position of the token. + An example token is: + + t426-2633508 + + --- + + There is also a third mode for live tokens where the token starts with "m", + which is sometimes used when using sharded event persisters. In this case + the events stream is considered to be a set of streams (one for each writer) + and the token encodes the vector clock of positions of each writer in their + respective streams. + + The format of the token in such case is an initial integer min position, + followed by the mapping of instance ID to position separated by '.' and '~': + + m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... + + The `min_pos` corresponds to the minimum position all writers have persisted + up to, and then only writers that are ahead of that position need to be + encoded. An example token is: + + m56~2.58~3.59 + + Which corresponds to a set of three (or more writers) where instances 2 and + 3 (these are instance IDs that can be looked up in the DB to fetch the more + commonly used instance names) are at positions 58 and 59 respectively, and + all other instances are at position 56. + + Note: The `RoomStreamToken` cannot have both a topological part and an + instance map. + + --- + + For caching purposes, `RoomStreamToken`s and by extension, all their + attributes, must be hashable. + """ + + topological: Optional[int] = attr.ib( + validator=attr.validators.optional(attr.validators.instance_of(int)), + ) + stream: int = attr.ib(validator=attr.validators.instance_of(int)) + + instance_map: "frozendict[str, int]" = attr.ib( + factory=frozendict, + validator=attr.validators.deep_mapping( + key_validator=attr.validators.instance_of(str), + value_validator=attr.validators.instance_of(int), + mapping_validator=attr.validators.instance_of(frozendict), + ), + ) + + def __attrs_post_init__(self) -> None: + """Validates that both `topological` and `instance_map` aren't set.""" + + if self.instance_map and self.topological: + raise ValueError( + "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." + ) + + @classmethod + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + if string[0] == "t": + parts = string[1:].split("-", 1) + return cls(topological=int(parts[0]), stream=int(parts[1])) + if string[0] == "m": + parts = string[1:].split("~") + stream = int(parts[0]) + + instance_map = {} + for part in parts[1:]: + key, value = part.split(".") + instance_id = int(key) + pos = int(value) + + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] + instance_map[instance_name] = pos + + return cls( + topological=None, + stream=stream, + instance_map=frozendict(instance_map), + ) + except CancelledError: + raise + except Exception: + pass + raise SynapseError(400, "Invalid room stream token %r" % (string,)) + + @classmethod + def parse_stream_token(cls, string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + except Exception: + pass + raise SynapseError(400, "Invalid room stream token %r" % (string,)) + + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + instance_map = { + instance: max( + self.instance_map.get(instance, self.stream), + other.instance_map.get(instance, other.stream), + ) + for instance in set(self.instance_map).union(other.instance_map) + } + + return RoomStreamToken(None, max_stream, frozendict(instance_map)) + + def as_historical_tuple(self) -> Tuple[int, int]: + """Returns a tuple of `(topological, stream)` for historical tokens. + + Raises if not an historical token (i.e. doesn't have a topological part). + """ + if self.topological is None: + raise Exception( + "Cannot call `RoomStreamToken.as_historical_tuple` on live token" + ) + + return self.topological, self.stream + + def get_stream_pos_for_instance(self, instance_name: str) -> int: + """Get the stream position that the given writer was at at this token. + + This only makes sense for "live" tokens that may have a vector clock + component, and so asserts that this is a "live" token. + """ + assert self.topological is None + + # If we don't have an entry for the instance we can assume that it was + # at `self.stream`. + return self.instance_map.get(instance_name, self.stream) + + def get_max_stream_pos(self) -> int: + """Get the maximum stream position referenced in this token. + + The corresponding "min" position is, by definition just `self.stream`. + + This is used to handle tokens that have non-empty `instance_map`, and so + reference stream positions after the `self.stream` position. + """ + return max(self.instance_map.values(), default=self.stream) + + async def to_string(self, store: "DataStore") -> str: + if self.topological is not None: + return "t%d-%d" % (self.topological, self.stream) + elif self.instance_map: + entries = [] + for name, pos in self.instance_map.items(): + instance_id = await store.get_id_for_instance(name) + entries.append(f"{instance_id}.{pos}") + + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + else: + return "s%d" % (self.stream,) + + +class StreamKeyType: + """Known stream types. + + A stream is a list of entities ordered by an incrementing "stream token". + """ + + ROOM: Final = "room_key" + PRESENCE: Final = "presence_key" + TYPING: Final = "typing_key" + RECEIPT: Final = "receipt_key" + ACCOUNT_DATA: Final = "account_data_key" + PUSH_RULES: Final = "push_rules_key" + TO_DEVICE: Final = "to_device_key" + DEVICE_LIST: Final = "device_list_key" + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StreamToken: + """A collection of keys joined together by underscores in the following + order and which represent the position in their respective streams. + + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` + 1. `room_key`: `s2633508` which is a `RoomStreamToken` + - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` + - See the docstring for `RoomStreamToken` for more details. + 2. `presence_key`: `17` + 3. `typing_key`: `338` + 4. `receipt_key`: `6732159` + 5. `account_data_key`: `1082514` + 6. `push_rules_key`: `541479` + 7. `to_device_key`: `274711` + 8. `device_list_key`: `265584` + 9. `groups_key`: `1` (note that this key is now unused) + + You can see how many of these keys correspond to the various + fields in a "/sync" response: + ```json + { + "next_batch": "s12_4_0_1_1_1_1_4_1", + "presence": { + "events": [] + }, + "device_lists": { + "changed": [] + }, + "rooms": { + "join": { + "!QrZlfIDQLNLdZHqTnt:hs1": { + "timeline": { + "events": [], + "prev_batch": "s10_4_0_1_1_1_1_4_1", + "limited": false + }, + "state": { + "events": [] + }, + "account_data": { + "events": [] + }, + "ephemeral": { + "events": [] + } + } + } + } + } + ``` + + --- + + For caching purposes, `StreamToken`s and by extension, all their attributes, + must be hashable. + """ + + room_key: RoomStreamToken = attr.ib( + validator=attr.validators.instance_of(RoomStreamToken) + ) + presence_key: int + typing_key: int + receipt_key: int + account_data_key: int + push_rules_key: int + to_device_key: int + device_list_key: int + # Note that the groups key is no longer used and may have bogus values. + groups_key: int + + _SEPARATOR = "_" + START: ClassVar["StreamToken"] + + @classmethod + @cancellable + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": + """ + Creates a RoomStreamToken from its textual representation. + """ + try: + keys = string.split(cls._SEPARATOR) + while len(keys) < len(attr.fields(cls)): + # i.e. old token from before receipt_key + keys.append("0") + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) + except CancelledError: + raise + except Exception: + raise SynapseError(400, "Invalid stream token") + + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + # Note that the groups key is no longer used, but it is still + # serialized so that there will not be confusion in the future + # if additional tokens are added. + str(self.groups_key), + ] + ) + + @property + def room_stream_id(self) -> int: + return self.room_key.stream + + def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": + """Advance the given key in the token to a new value if and only if the + new value is after the old value. + + :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. + """ + if key == StreamKeyType.ROOM: + new_token = self.copy_and_replace( + StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + + if old_id < new_id: + return new_token + else: + return self + + def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": + return attr.evolve(self, **{key: new_value}) + + +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name: str + stream: int + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.get_stream_pos_for_instance(self.instance_name) < self.stream + + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarantees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThirdPartyInstanceID: + appservice_id: Optional[str] + network_id: Optional[str] + + # Deny iteration because it will bite you if you try to create a singleton + # set by: + # users = set(user) + def __iter__(self) -> NoReturn: + raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) + + # Because this class is a frozen class, it is deeply immutable. + def __copy__(self) -> "ThirdPartyInstanceID": + return self + + def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": + return self + + @classmethod + def from_string(cls, s: str) -> "ThirdPartyInstanceID": + bits = s.split("|", 2) + if len(bits) != 2: + raise SynapseError(400, "Invalid ID %r" % (s,)) + + return cls(appservice_id=bits[0], network_id=bits[1]) + + def to_string(self) -> str: + return "%s|%s" % (self.appservice_id, self.network_id) + + __str__ = to_string + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ReadReceipt: + """Information about a read-receipt""" + + room_id: str + receipt_type: str + user_id: str + event_ids: List[str] + thread_id: Optional[str] + data: JsonDict + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceListUpdates: + """ + An object containing a diff of information regarding other users' device lists, intended for + a recipient to carry out device list tracking. + + Attributes: + changed: A set of users whose device lists have changed recently. + left: A set of users who the recipient no longer needs to track the device lists of. + Typically when those users no longer share any end-to-end encryption enabled rooms. + """ + + # We need to use a factory here, otherwise `set` is not evaluated at + # object instantiation, but instead at class definition instantiation. + # The latter happening only once, thus always giving you the same sets + # across multiple DeviceListUpdates instances. + # Also see: don't define mutable default arguments. + changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + +def get_verify_key_from_cross_signing_key( + key_info: Mapping[str, Any] +) -> Tuple[str, VerifyKey]: + """Get the key ID and signedjson verify key from a cross-signing key dict + + Args: + key_info: a cross-signing key dict, which must have a "keys" + property that has exactly one item in it + + Returns: + the key ID and verify key for the cross-signing key + """ + # make sure that a `keys` field is provided + if "keys" not in key_info: + raise ValueError("Invalid key") + keys = key_info["keys"] + # and that it contains exactly one key + if len(keys) == 1: + key_id, key_data = next(iter(keys.items())) + return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) + else: + raise ValueError("Invalid key") + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool + + +class UserProfile(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class RetentionPolicy: + min_lifetime: Optional[int] = None + max_lifetime: Optional[int] = None diff --git a/synapse/types/state.py b/synapse/types/state.py new file mode 100644 index 0000000000..0004d955b4 --- /dev/null +++ b/synapse/types/state.py @@ -0,0 +1,567 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import ( + TYPE_CHECKING, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + TypeVar, +) + +import attr +from frozendict import frozendict + +from synapse.api.constants import EventTypes +from synapse.types import MutableStateMap, StateKey, StateMap + +if TYPE_CHECKING: + from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad + + +logger = logging.getLogger(__name__) + +# Used for generic functions below +T = TypeVar("T") + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StateFilter: + """A filter used when querying for state. + + Attributes: + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not + appear in `types`. + """ + + types: "frozendict[str, Optional[FrozenSet[str]]]" + include_others: bool = False + + def __attrs_post_init__(self) -> None: + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + # this is needed to work around the fact that StateFilter is frozen + object.__setattr__( + self, + "types", + frozendict({k: v for k, v in self.types.items() if v is not None}), + ) + + @staticmethod + def all() -> "StateFilter": + """Returns a filter that fetches everything. + + Returns: + The state filter. + """ + return _ALL_STATE_FILTER + + @staticmethod + def none() -> "StateFilter": + """Returns a filter that fetches nothing. + + Returns: + The new state filter. + """ + return _NONE_STATE_FILTER + + @staticmethod + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": + """Creates a filter that only fetches the given types + + Args: + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type + + Returns: + The new state filter. + """ + type_dict: Dict[str, Optional[Set[str]]] = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) # type: ignore + + return StateFilter( + types=frozendict( + (k, frozenset(v) if v is not None else None) + for k, v in type_dict.items() + ) + ) + + @staticmethod + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members: Set of user IDs + + Returns: + The new state filter + """ + return StateFilter( + types=frozendict({EventTypes.Member: frozenset(members)}), + include_others=True, + ) + + @staticmethod + def freeze( + types: Mapping[str, Optional[Collection[str]]], include_others: bool + ) -> "StateFilter": + """ + Returns a (frozen) StateFilter with the same contents as the parameters + specified here, which can be made of mutable types. + """ + types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} + for state_types, state_keys in types.items(): + if state_keys is not None: + types_with_frozen_values[state_types] = frozenset(state_keys) + else: + types_with_frozen_values[state_types] = None + + return StateFilter( + frozendict(types_with_frozen_values), include_others=include_others + ) + + def return_expanded(self) -> "StateFilter": + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + The new state filter. + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in self.types.items() + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + elif EventTypes.Member in self.types: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), + include_others=True, + ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER + + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). + """ + + where_clause = "" + where_args: List[str] = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in self.types.items(): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self) -> Optional[int]: + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: + """Returns the state filtered with by this StateFilter. + + Args: + state: The state map to filter + + Returns: + The filtered state map. + This is a copy, so it's safe to mutate. + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in state_dict.items(): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self) -> bool: + """Whether this filter fetches everything or not + + Returns: + True if the filter fetches everything. + """ + return self.include_others and not self.types + + def has_wildcards(self) -> bool: + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + True if the filter includes wildcards. + """ + + return self.include_others or any( + state_keys is None for state_keys in self.types.values() + ) + + def concrete_types(self) -> List[Tuple[str, str]]: + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + A list of type/state_keys tuples. + """ + return [ + (t, s) + for t, state_keys in self.types.items() + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types=frozendict( + {k: v for k, v in self.types.items() if k != EventTypes.Member} + ), + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + def _decompose_into_four_parts( + self, + ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: + """ + Decomposes this state filter into 4 constituent parts, which can be + thought of as this: + all? - minus_wildcards + plus_wildcards + plus_state_keys + + where + * all represents ALL state + * minus_wildcards represents entire state types to remove + * plus_wildcards represents entire state types to add + * plus_state_keys represents individual state keys to add + + See `recompose_from_four_parts` for the other direction of this + correspondence. + """ + is_all = self.include_others + excluded_types: Set[str] = {t for t in self.types if is_all} + wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} + concrete_keys: Set[StateKey] = set(self.concrete_types()) + + return (is_all, excluded_types), (wildcard_types, concrete_keys) + + @staticmethod + def _recompose_from_four_parts( + all_part: bool, + minus_wildcards: Set[str], + plus_wildcards: Set[str], + plus_state_keys: Set[StateKey], + ) -> "StateFilter": + """ + Recomposes a state filter from 4 parts. + + See `decompose_into_four_parts` (the other direction of this + correspondence) for descriptions on each of the parts. + """ + + # {state type -> set of state keys OR None for wildcard} + # (The same structure as that of a StateFilter.) + new_types: Dict[str, Optional[Set[str]]] = {} + + # if we start with all, insert the excluded statetypes as empty sets + # to prevent them from being included + if all_part: + new_types.update({state_type: set() for state_type in minus_wildcards}) + + # insert the plus wildcards + new_types.update({state_type: None for state_type in plus_wildcards}) + + # insert the specific state keys + for state_type, state_key in plus_state_keys: + if state_type in new_types: + entry = new_types[state_type] + if entry is not None: + entry.add(state_key) + elif not all_part: + # don't insert if the entire type is already included by + # include_others as this would actually shrink the state allowed + # by this filter. + new_types[state_type] = {state_key} + + return StateFilter.freeze(new_types, include_others=all_part) + + def approx_difference(self, other: "StateFilter") -> "StateFilter": + """ + Returns a state filter which represents `self - other`. + + This is useful for determining what state remains to be pulled out of the + database if we want the state included by `self` but already have the state + included by `other`. + + The returned state filter + - MUST include all state events that are included by this filter (`self`) + unless they are included by `other`; + - MUST NOT include state events not included by this filter (`self`); and + - MAY be an over-approximation: the returned state filter + MAY additionally include some state events from `other`. + + This implementation attempts to return the narrowest such state filter. + In the case that `self` contains wildcards for state types where + `other` contains specific state keys, an approximation must be made: + the returned state filter keeps the wildcard, as state filters are not + able to express 'all state keys except some given examples'. + e.g. + StateFilter(m.room.member -> None (wildcard)) + minus + StateFilter(m.room.member -> {'@wombat:example.org'}) + is approximated as + StateFilter(m.room.member -> None (wildcard)) + """ + + # We first transform self and other into an alternative representation: + # - whether or not they include all events to begin with ('all') + # - if so, which event types are excluded? ('excludes') + # - which entire event types to include ('wildcards') + # - which concrete state keys to include ('concrete state keys') + (self_all, self_excludes), ( + self_wildcards, + self_concrete_keys, + ) = self._decompose_into_four_parts() + (other_all, other_excludes), ( + other_wildcards, + other_concrete_keys, + ) = other._decompose_into_four_parts() + + # Start with an estimate of the difference based on self + new_all = self_all + # Wildcards from the other can be added to the exclusion filter + new_excludes = self_excludes | other_wildcards + # We remove wildcards that appeared as wildcards in the other + new_wildcards = self_wildcards - other_wildcards + # We filter out the concrete state keys that appear in the other + # as wildcards or concrete state keys. + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in self_concrete_keys + if state_type not in other_wildcards + } - other_concrete_keys + + if other_all: + if self_all: + # If self starts with all, then we add as wildcards any + # types which appear in the other's exclusion filter (but + # aren't in the self exclusion filter). This is as the other + # filter will return everything BUT the types in its exclusion, so + # we need to add those excluded types that also match the self + # filter as wildcard types in the new filter. + new_wildcards |= other_excludes.difference(self_excludes) + + # If other is an `include_others` then the difference isn't. + new_all = False + # (We have no need for excludes when we don't start with all, as there + # is nothing to exclude.) + new_excludes = set() + + # We also filter out all state types that aren't in the exclusion + # list of the other. + new_wildcards &= other_excludes + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in new_concrete_keys + if state_type in other_excludes + } + + # Transform our newly-constructed state filter from the alternative + # representation back into the normal StateFilter representation. + return StateFilter._recompose_from_four_parts( + new_all, new_excludes, new_wildcards, new_concrete_keys + ) + + def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: + """Check if we need to wait for full state to complete to calculate this state + + If we have a state filter which is completely satisfied even with partial + state, then we don't need to await_full_state before we can return it. + + Args: + is_mine_id: a callable which confirms if a given state_key matches a mxid + of a local user + """ + # if we haven't requested membership events, then it depends on the value of + # 'include_others' + if EventTypes.Member not in self.types: + return self.include_others + + # if we're looking for *all* membership events, then we have to wait + member_state_keys = self.types[EventTypes.Member] + if member_state_keys is None: + return True + + # otherwise, consider whose membership we are looking for. If it's entirely + # local users, then we don't need to wait. + for state_key in member_state_keys: + if not is_mine_id(state_key): + # remote user + return True + + # local users only + return False + + +_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) +_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/visibility.py b/synapse/visibility.py index b443857571..e442de3173 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -26,8 +26,8 @@ from synapse.events.utils import prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.storage.state import StateFilter from synapse.types import RetentionPolicy, StateMap, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import Clock logger = logging.getLogger(__name__) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d4e6d4236c..a433e70870 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -22,8 +22,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.server import HomeServer -from synapse.storage.state import StateFilter from synapse.types import JsonDict, RoomID, StateMap, UserID +from synapse.types.state import StateFilter from synapse.util import Clock from tests.unittest import HomeserverTestCase, TestCase -- cgit 1.5.1 From e2a1adbf5d11288f2134ced1f84c6ffdd91a9357 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 13 Dec 2022 00:54:46 +0000 Subject: Allow selecting "prejoin" events by state keys (#14642) * Declare new config * Parse new config * Read new config * Don't use trial/our TestCase where it's not needed Before: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m2.277s user 0m2.186s sys 0m0.083s ``` After: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m0.566s user 0m0.508s sys 0m0.056s ``` * Helper to upsert to event fields without exceeding size limits. * Use helper when adding invite/knock state Now that we allow admins to include events in prejoin room state with arbitrary state keys, be a good Matrix citizen and ensure they don't accidentally create an oversized event. * Changelog * Move StateFilter tests should have done this in #14668 * Add extra methods to StateFilter * Use StateFilter * Ensure test file enforces typed defs; alphabetise * Workaround surprising get_current_state_ids * Whoops, fix mypy --- changelog.d/14642.feature | 1 + docs/usage/configuration/config_documentation.md | 57 ++- mypy.ini | 12 +- synapse/config/_util.py | 3 + synapse/config/api.py | 63 ++- synapse/events/utils.py | 32 +- synapse/handlers/message.py | 29 +- synapse/storage/databases/main/events_worker.py | 33 +- synapse/types/state.py | 18 + tests/config/test_api.py | 145 ++++++ tests/events/test_utils.py | 35 +- tests/storage/test_state.py | 623 +--------------------- tests/types/__init__.py | 0 tests/types/test_state.py | 627 +++++++++++++++++++++++ 14 files changed, 983 insertions(+), 695 deletions(-) create mode 100644 changelog.d/14642.feature create mode 100644 tests/config/test_api.py create mode 100644 tests/types/__init__.py create mode 100644 tests/types/test_state.py (limited to 'tests') diff --git a/changelog.d/14642.feature b/changelog.d/14642.feature new file mode 100644 index 0000000000..cbc9db10c3 --- /dev/null +++ b/changelog.d/14642.feature @@ -0,0 +1 @@ +Allow selecting "prejoin" events by state keys in addition to event types. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index dc5e5ac597..4d32902fea 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2501,32 +2501,53 @@ Config settings related to the client/server API --- ### `room_prejoin_state` -Controls for the state that is shared with users who receive an invite -to a room. By default, the following state event types are shared with users who -receive invites to the room: -- m.room.join_rules -- m.room.canonical_alias -- m.room.avatar -- m.room.encryption -- m.room.name -- m.room.create -- m.room.topic +This setting controls the state that is shared with users upon receiving an +invite to a room, or in reply to a knock on a room. By default, the following +state events are shared with users: + +- `m.room.join_rules` +- `m.room.canonical_alias` +- `m.room.avatar` +- `m.room.encryption` +- `m.room.name` +- `m.room.create` +- `m.room.topic` To change the default behavior, use the following sub-options: -* `disable_default_event_types`: set to true to disable the above defaults. If this - is enabled, only the event types listed in `additional_event_types` are shared. - Defaults to false. -* `additional_event_types`: Additional state event types to share with users when they are invited - to a room. By default, this list is empty (so only the default event types are shared). +* `disable_default_event_types`: boolean. Set to `true` to disable the above + defaults. If this is enabled, only the event types listed in + `additional_event_types` are shared. Defaults to `false`. +* `additional_event_types`: A list of additional state events to include in the + events to be shared. By default, this list is empty (so only the default event + types are shared). + + Each entry in this list should be either a single string or a list of two + strings. + * A standalone string `t` represents all events with type `t` (i.e. + with no restrictions on state keys). + * A pair of strings `[t, s]` represents a single event with type `t` and + state key `s`. The same type can appear in two entries with different state + keys: in this situation, both state keys are included in prejoin state. Example configuration: ```yaml room_prejoin_state: - disable_default_event_types: true + disable_default_event_types: false additional_event_types: - - org.example.custom.event.type - - m.room.join_rules + # Share all events of type `org.example.custom.event.typeA` + - org.example.custom.event.typeA + # Share only events of type `org.example.custom.event.typeB` whose + # state_key is "foo" + - ["org.example.custom.event.typeB", "foo"] + # Share only events of type `org.example.custom.event.typeC` whose + # state_key is "bar" or "baz" + - ["org.example.custom.event.typeC", "bar"] + - ["org.example.custom.event.typeC", "baz"] ``` + +*Changed in Synapse 1.74:* admins can filter the events in prejoin state based +on their state key. + --- ### `track_puppeted_user_ips` diff --git a/mypy.ini b/mypy.ini index 727536df50..37acf589c9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -89,6 +89,12 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False +[mypy-tests.config.test_api] +disallow_untyped_defs = True + +[mypy-tests.federation.transport.test_client] +disallow_untyped_defs = True + [mypy-tests.handlers.test_sso] disallow_untyped_defs = True @@ -101,7 +107,7 @@ disallow_untyped_defs = True [mypy-tests.push.test_bulk_push_rule_evaluator] disallow_untyped_defs = True -[mypy-tests.test_server] +[mypy-tests.rest.*] disallow_untyped_defs = True [mypy-tests.state.test_profile] @@ -110,10 +116,10 @@ disallow_untyped_defs = True [mypy-tests.storage.*] disallow_untyped_defs = True -[mypy-tests.rest.*] +[mypy-tests.test_server] disallow_untyped_defs = True -[mypy-tests.federation.transport.test_client] +[mypy-tests.types.*] disallow_untyped_defs = True [mypy-tests.util.caches.*] diff --git a/synapse/config/_util.py b/synapse/config/_util.py index 3edb4b7106..d3a4b484ab 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -33,6 +33,9 @@ def validate_config( config: the configuration value to be validated config_path: the path within the config file. This will be used as a basis for the error message. + + Raises: + ConfigError, if validation fails. """ try: jsonschema.validate(config, json_schema) diff --git a/synapse/config/api.py b/synapse/config/api.py index e46728e73f..27d50d118f 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -13,12 +13,13 @@ # limitations under the License. import logging -from typing import Any, Iterable +from typing import Any, Iterable, Optional, Tuple from synapse.api.constants import EventTypes from synapse.config._base import Config, ConfigError from synapse.config._util import validate_config from synapse.types import JsonDict +from synapse.types.state import StateFilter logger = logging.getLogger(__name__) @@ -26,16 +27,20 @@ logger = logging.getLogger(__name__) class ApiConfig(Config): section = "api" + room_prejoin_state: StateFilter + track_puppetted_users_ips: bool + def read_config(self, config: JsonDict, **kwargs: Any) -> None: validate_config(_MAIN_SCHEMA, config, ()) - self.room_prejoin_state = list(self._get_prejoin_state_types(config)) + self.room_prejoin_state = StateFilter.from_types( + self._get_prejoin_state_entries(config) + ) self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) - def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: - """Get the event types to include in the prejoin state - - Parses the config and returns an iterable of the event types to be included. - """ + def _get_prejoin_state_entries( + self, config: JsonDict + ) -> Iterable[Tuple[str, Optional[str]]]: + """Get the event types and state keys to include in the prejoin state.""" room_prejoin_state_config = config.get("room_prejoin_state") or {} # backwards-compatibility support for room_invite_state_types @@ -50,33 +55,39 @@ class ApiConfig(Config): logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING) - yield from config["room_invite_state_types"] + for event_type in config["room_invite_state_types"]: + yield event_type, None return if not room_prejoin_state_config.get("disable_default_event_types"): - yield from _DEFAULT_PREJOIN_STATE_TYPES + yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS - yield from room_prejoin_state_config.get("additional_event_types", []) + for entry in room_prejoin_state_config.get("additional_event_types", []): + if isinstance(entry, str): + yield entry, None + else: + yield entry _ROOM_INVITE_STATE_TYPES_WARNING = """\ WARNING: The 'room_invite_state_types' configuration setting is now deprecated, and replaced with 'room_prejoin_state'. New features may not work correctly -unless 'room_invite_state_types' is removed. See the sample configuration file for -details of 'room_prejoin_state'. +unless 'room_invite_state_types' is removed. See the config documentation at + https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state +for details of 'room_prejoin_state'. -------------------------------------------------------------------------------- """ -_DEFAULT_PREJOIN_STATE_TYPES = [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, +_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [ + (EventTypes.JoinRules, ""), + (EventTypes.CanonicalAlias, ""), + (EventTypes.RoomAvatar, ""), + (EventTypes.RoomEncryption, ""), + (EventTypes.Name, ""), # Per MSC1772. - EventTypes.Create, + (EventTypes.Create, ""), # Per MSC3173. - EventTypes.Topic, + (EventTypes.Topic, ""), ] @@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = { "disable_default_event_types": {"type": "boolean"}, "additional_event_types": { "type": "array", - "items": {"type": "string"}, + "items": { + "oneOf": [ + {"type": "string"}, + { + "type": "array", + "items": {"type": "string"}, + "minItems": 2, + "maxItems": 2, + }, + ], + }, }, }, }, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 71853caad8..13fa93afb8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -28,8 +28,14 @@ from typing import ( ) import attr +from canonicaljson import encode_canonical_json -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + MAX_PDU_SIZE, + EventContentFields, + EventTypes, + RelationTypes, +) from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict @@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None: elif not isinstance(value, (bool, str)) and value is not None: # Other potential JSON values (bool, None, str) are safe. raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON) + + +def maybe_upsert_event_field( + event: EventBase, container: JsonDict, key: str, value: object +) -> bool: + """Upsert an event field, but only if this doesn't make the event too large. + + Returns true iff the upsert took place. + """ + if key in container: + old_value: object = container[key] + container[key] = value + # NB: here and below, we assume that passing a non-None `time_now` argument to + # get_pdu_json doesn't increase the size of the encoded result. + upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE + if not upsert_okay: + container[key] = old_value + else: + container[key] = value + upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE + if not upsert_okay: + del container[key] + + return upsert_okay diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6e90ef259..845f683358 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext +from synapse.events.utils import maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler from synapse.logging import opentracing @@ -1739,12 +1740,15 @@ class EventCreationHandler: if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: - event.unsigned[ - "invite_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, - membership_user_id=event.sender, + maybe_upsert_event_field( + event, + event.unsigned, + "invite_room_state", + await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + membership_user_id=event.sender, + ), ) invitee = UserID.from_string(event.state_key) @@ -1762,11 +1766,14 @@ class EventCreationHandler: event.signatures.update(returned_invite.signatures) if event.content["membership"] == Membership.KNOCK: - event.unsigned[ - "knock_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, + maybe_upsert_event_field( + event, + event.unsigned, + "knock_room_state", + await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + ), ) if event.type == EventTypes.Redaction: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 01e935edef..318fd7dc71 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -16,11 +16,11 @@ import logging import threading import weakref from enum import Enum, auto +from itertools import chain from typing import ( TYPE_CHECKING, Any, Collection, - Container, Dict, Iterable, List, @@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import ( ) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList @@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_types_to_include: Container[str], + state_keys_to_include: StateFilter, membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore): Args: context: The event context to retrieve state of the room from. - state_types_to_include: The type of state events to include. + state_keys_to_include: The state events to include, for each event type. membership_user_id: An optional user ID to include the stripped membership state events of. This is useful when generating the stripped state of a room for invites. We want to send membership events of the inviter, so that the @@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore): Returns: A list of dictionaries, each representing a stripped state event from the room. """ - current_state_ids = await context.get_current_state_ids() + if membership_user_id: + types = chain( + state_keys_to_include.to_types(), + [(EventTypes.Member, membership_user_id)], + ) + filter = StateFilter.from_types(types) + else: + filter = state_keys_to_include + selected_state_ids = await context.get_current_state_ids(filter) # We know this event is not an outlier, so this must be # non-None. - assert current_state_ids is not None - - # The state to include - state_to_include_ids = [ - e_id - for k, e_id in current_state_ids.items() - if k[0] in state_types_to_include - or (membership_user_id and k == (EventTypes.Member, membership_user_id)) - ] + assert selected_state_ids is not None + + # Confusingly, get_current_state_events may return events that are discarded by + # the filter, if they're in context._state_delta_due_to_event. Strip these away. + selected_state_ids = filter.filter_state(selected_state_ids) - state_to_include = await self.get_events(state_to_include_ids) + state_to_include = await self.get_events(selected_state_ids.values()) return [ { diff --git a/synapse/types/state.py b/synapse/types/state.py index 0004d955b4..743a4f9217 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -118,6 +118,15 @@ class StateFilter: ) ) + def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: + """The inverse to `from_types`.""" + for (event_type, state_keys) in self.types.items(): + if state_keys is None: + yield event_type, None + else: + for state_key in state_keys: + yield event_type, state_key + @staticmethod def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member @@ -343,6 +352,15 @@ class StateFilter: for s in state_keys ] + def wildcard_types(self) -> List[str]: + """Returns a list of event types which require us to fetch all state keys. + This will be empty unless `has_wildcards` returns True. + + Returns: + A list of event types. + """ + return [t for t, state_keys in self.types.items() if state_keys is None] + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching diff --git a/tests/config/test_api.py b/tests/config/test_api.py new file mode 100644 index 0000000000..6773c9a277 --- /dev/null +++ b/tests/config/test_api.py @@ -0,0 +1,145 @@ +from unittest import TestCase as StdlibTestCase + +import yaml + +from synapse.config import ConfigError +from synapse.config.api import ApiConfig +from synapse.types.state import StateFilter + +DEFAULT_PREJOIN_STATE_PAIRS = { + ("m.room.join_rules", ""), + ("m.room.canonical_alias", ""), + ("m.room.avatar", ""), + ("m.room.encryption", ""), + ("m.room.name", ""), + ("m.room.create", ""), + ("m.room.topic", ""), +} + + +class TestRoomPrejoinState(StdlibTestCase): + def read_config(self, source: str) -> ApiConfig: + config = ApiConfig() + config.read_config(yaml.safe_load(source)) + return config + + def test_no_prejoin_state(self) -> None: + config = self.read_config("foo: bar") + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS + ) + + def test_disable_default_event_types(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + """ + ) + self.assertEqual(config.room_prejoin_state, StateFilter.none()) + + def test_event_without_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - foo + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + def test_event_with_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + """ + ) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar")}, + ) + + def test_repeated_event_with_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + - [foo, baz] + """ + ) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar"), ("foo", "baz")}, + ) + + def test_no_specific_state_key_overrides_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + - foo + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - foo + - [foo, bar] + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + def test_bad_event_type_entry_raises(self) -> None: + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [a] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [a, b, c] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [true, 1.23] + """ + ) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index b1c47efac7..a79256846f 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest as stdlib_unittest + from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( SerializeEventConfig, copy_and_fixup_power_levels_contents, + maybe_upsert_event_field, prune_event, serialize_event, ) from synapse.util.frozenutils import freeze -from tests import unittest - def MockEvent(**kwargs): if "event_id" not in kwargs: @@ -34,7 +35,31 @@ def MockEvent(**kwargs): return make_event_from_dict(kwargs) -class PruneEventTestCase(unittest.TestCase): +class TestMaybeUpsertEventField(stdlib_unittest.TestCase): + def test_update_okay(self) -> None: + event = make_event_from_dict({"event_id": "$1234"}) + success = maybe_upsert_event_field(event, event.unsigned, "key", "value") + self.assertTrue(success) + self.assertEqual(event.unsigned["key"], "value") + + def test_update_not_okay(self) -> None: + event = make_event_from_dict({"event_id": "$1234"}) + LARGE_STRING = "a" * 100_000 + success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + self.assertFalse(success) + self.assertNotIn("key", event.unsigned) + + def test_update_not_okay_leaves_original_value(self) -> None: + event = make_event_from_dict( + {"event_id": "$1234", "unsigned": {"key": "value"}} + ) + LARGE_STRING = "a" * 100_000 + success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + self.assertFalse(success) + self.assertEqual(event.unsigned["key"], "value") + + +class PruneEventTestCase(stdlib_unittest.TestCase): def run_test(self, evdict, matchdict, **kwargs): """ Asserts that a new event constructed with `evdict` will look like @@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase): ) -class SerializeEventTestCase(unittest.TestCase): +class SerializeEventTestCase(stdlib_unittest.TestCase): def serialize(self, ev, fields): return serialize_event( ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) @@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase): ) -class CopyPowerLevelsContentTestCase(unittest.TestCase): +class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def setUp(self) -> None: self.test_content = { "ban": 50, diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index a433e70870..bad7f0bc60 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID from synapse.types.state import StateFilter from synapse.util import Clock -from tests.unittest import HomeserverTestCase, TestCase +from tests.unittest import HomeserverTestCase logger = logging.getLogger(__name__) @@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - - -class StateFilterDifferenceTestCase(TestCase): - def assert_difference( - self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter - ) -> None: - self.assertEqual( - minuend.approx_difference(subtrahend), - expected, - f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", - ) - - def test_state_filter_difference_no_include_other_minus_no_include_other( - self, - ) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), both a and b do not have the - include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - StateFilter.freeze({EventTypes.Create: None}, include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:spqr"}}, - include_others=False, - ), - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.CanonicalAlias: {""}}, - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), only a has the include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Create: None, - EventTypes.Member: set(), - EventTypes.CanonicalAlias: set(), - }, - include_others=True, - ), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - # This also shows that the resultant state filter is normalised. - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=True), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.Create: {""}, - }, - include_others=False, - ), - StateFilter(types=frozendict(), include_others=True), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter( - types=frozendict(), - include_others=True, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.CanonicalAlias: {""}, - EventTypes.Member: set(), - }, - include_others=True, - ), - ) - - # (specific state keys) - (specific state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - ) - - def test_state_filter_difference_include_other_minus_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), both a and b have the include_others - flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=True, - ), - StateFilter(types=frozendict(), include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=True), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter( - types=frozendict(), - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - EventTypes.Create: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.Create: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.Create: {""}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), only b has the include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=True, - ), - StateFilter(types=frozendict(), include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:spqr"}}, - include_others=True, - ), - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter( - types=frozendict(), - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_simple_cases(self) -> None: - """ - Tests some very simple cases of the StateFilter approx_difference, - that are not explicitly tested by the more in-depth tests. - """ - - self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) - - self.assert_difference( - StateFilter.all(), - StateFilter.none(), - StateFilter.all(), - ) - - -class StateFilterTestCase(TestCase): - def test_return_expanded(self) -> None: - """ - Tests the behaviour of the return_expanded() function that expands - StateFilters to include more state types (for the sake of cache hit rate). - """ - - self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) - - self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) - - # Concrete-only state filters stay the same - # (Case: mixed filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": {""}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": {""}, - }, - include_others=False, - ), - ) - - # Concrete-only state filters stay the same - # (Case: non-member-only filter) - self.assertEqual( - StateFilter.freeze( - {"some.other.state.type": {""}}, include_others=False - ).return_expanded(), - StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), - ) - - # Concrete-only state filters stay the same - # (Case: member-only filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - }, - include_others=False, - ), - ) - - # Wildcard member-only state filters stay the same - self.assertEqual( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # If there is a wildcard in the non-member portion of the filter, - # it's expanded to include ALL non-member events. - # (Case: mixed filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": None, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, - include_others=True, - ), - ) - - # If there is a wildcard in the non-member portion of the filter, - # it's expanded to include ALL non-member events. - # (Case: non-member-only filter) - self.assertEqual( - StateFilter.freeze( - { - "some.other.state.type": None, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze({EventTypes.Member: set()}, include_others=True), - ) - self.assertEqual( - StateFilter.freeze( - { - "some.other.state.type": None, - "yet.another.state.type": {"wombat"}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze({EventTypes.Member: set()}, include_others=True), - ) diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/types/test_state.py b/tests/types/test_state.py new file mode 100644 index 0000000000..eb809f9fb7 --- /dev/null +++ b/tests/types/test_state.py @@ -0,0 +1,627 @@ +from frozendict import frozendict + +from synapse.api.constants import EventTypes +from synapse.types.state import StateFilter + +from tests.unittest import TestCase + + +class StateFilterDifferenceTestCase(TestCase): + def assert_difference( + self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter + ) -> None: + self.assertEqual( + minuend.approx_difference(subtrahend), + expected, + f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", + ) + + def test_state_filter_difference_no_include_other_minus_no_include_other( + self, + ) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b do not have the + include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Create: None}, include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.CanonicalAlias: {""}}, + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only a has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Create: None, + EventTypes.Member: set(), + EventTypes.CanonicalAlias: set(), + }, + include_others=True, + ), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + # This also shows that the resultant state filter is normalised. + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + StateFilter(types=frozendict(), include_others=True), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter( + types=frozendict(), + include_others=True, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.CanonicalAlias: {""}, + EventTypes.Member: set(), + }, + include_others=True, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + def test_state_filter_difference_include_other_minus_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b have the include_others + flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + EventTypes.Create: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only b has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=True, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_simple_cases(self) -> None: + """ + Tests some very simple cases of the StateFilter approx_difference, + that are not explicitly tested by the more in-depth tests. + """ + + self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) + + self.assert_difference( + StateFilter.all(), + StateFilter.none(), + StateFilter.all(), + ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self) -> None: + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) -- cgit 1.5.1