From f694bb71b7ea7841a5b5db3d884dfda5a3f78023 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 9 Sep 2022 11:30:06 -0500 Subject: Strip number suffix from instance name to consolidate services that traces are spread over (#13729) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The problem with many services is that it makes it hard to find which service has the trace you want, see https://github.com/jaegertracing/jaeger-ui/issues/985 Previously, we split traces out into services based on their instance name like `matrix.org client_reader-1`, etc but there are many worker instances of the same `client_reader` so there is a lot to click through. With this PR, all of the traces are just collected under the worker type like `client_reader`, `event_persister` 😇 Note: A Synapse worker instance name is an opaque string with the number convention only being our own thing for the `matrix.org` deployment. But seems pretty sensible to group things this way. --- synapse/logging/opentracing.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 482316a1ff..adf3f54770 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -203,6 +203,9 @@ if TYPE_CHECKING: # Helper class +# Matches the number suffix in an instance name like "matrix.org client_reader-8" +STRIP_INSTANCE_NUMBER_SUFFIX_REGEX = re.compile(r"[_-]?\d+$") + class _DummyTagNames: """wrapper of opentracings tags. We need to have them if we @@ -441,9 +444,17 @@ def init_tracer(hs: "HomeServer") -> None: from jaeger_client.metrics.prometheus import PrometheusMetricsFactory + # Instance names are opaque strings but by stripping off the number suffix, + # we can get something that looks like a "worker type", e.g. + # "client_reader-1" -> "client_reader" so we don't spread the traces across + # so many services. + instance_name_by_type = re.sub( + STRIP_INSTANCE_NUMBER_SUFFIX_REGEX, "", hs.get_instance_name() + ) + config = JaegerConfig( config=hs.config.tracing.jaeger_config, - service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}", + service_name=f"{hs.config.server.server_name} {instance_name_by_type}", scope_manager=LogContextScopeManager(), metrics_factory=PrometheusMetricsFactory(), ) -- cgit 1.5.1 From a911ffb42cc88adc8084a04acf6fd651efba278f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 9 Sep 2022 11:31:37 -0500 Subject: Tag trace with instance name (#13761) We tag the Synapse instance name so that it's an easy jumping off point into the logs. Can also be used to filter for an instance that is under load. As suggested by @clokep and @reivilibre in, - https://github.com/matrix-org/synapse/pull/13729#discussion_r964719258 - https://github.com/matrix-org/synapse/pull/13729#discussion_r964733578 --- changelog.d/13761.misc | 1 + synapse/api/auth.py | 7 +++++++ synapse/logging/opentracing.py | 6 ++++-- 3 files changed, 12 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13761.misc (limited to 'synapse') diff --git a/changelog.d/13761.misc b/changelog.d/13761.misc new file mode 100644 index 0000000000..f7aa8c459a --- /dev/null +++ b/changelog.d/13761.misc @@ -0,0 +1 @@ +Tag traces with the instance name to be able to easily jump into the right logs and filter traces by instance. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 8e54ef84b2..4a75eb6b21 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -32,6 +32,7 @@ from synapse.appservice import ApplicationService from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging.opentracing import ( + SynapseTags, active_span, force_tracing, start_active_span, @@ -161,6 +162,12 @@ class Auth: parent_span.set_tag( "authenticated_entity", requester.authenticated_entity ) + # We tag the Synapse instance name so that it's an easy jumping + # off point into the logs. Can also be used to filter for an + # instance that is under load. + parent_span.set_tag( + SynapseTags.INSTANCE_NAME, self.hs.get_instance_name() + ) parent_span.set_tag("user_id", requester.user.to_string()) if requester.device_id is not None: parent_span.set_tag("device_id", requester.device_id) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index adf3f54770..ca2735dd6d 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -298,6 +298,8 @@ class SynapseTags: # Whether the sync response has new data to be returned to the client. SYNC_RESULT = "sync.new_data" + INSTANCE_NAME = "instance_name" + # incoming HTTP request ID (as written in the logs) REQUEST_ID = "request_id" @@ -1043,11 +1045,11 @@ def trace_servlet( # with JsonResource). scope.span.set_operation_name(request.request_metrics.name) - # set the tags *after* the servlet completes, in case it decided to - # prioritise the span (tags will get dropped on unprioritised spans) request_tags[ SynapseTags.REQUEST_TAG ] = request.request_metrics.start_context.tag + # set the tags *after* the servlet completes, in case it decided to + # prioritise the span (tags will get dropped on unprioritised spans) for k, v in request_tags.items(): scope.span.set_tag(k, v) -- cgit 1.5.1 From 4c4889cac0e6f7df4689287b9fddea1bf8b15b7f Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 9 Sep 2022 19:00:21 +0100 Subject: Concurrently collect room unread counts for push badges (#13765) Most of the time this function is heavily cached, but when that isn't the case fetching the counts room by room slows down push delivery on users with many (thousands) of rooms. Signed off by Nick @ Beeper. --- changelog.d/13765.misc | 1 + synapse/push/push_tools.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) create mode 100644 changelog.d/13765.misc (limited to 'synapse') diff --git a/changelog.d/13765.misc b/changelog.d/13765.misc new file mode 100644 index 0000000000..fdda5cf3b6 --- /dev/null +++ b/changelog.d/13765.misc @@ -0,0 +1 @@ +Concurrently fetch room push actions when calculating badge counts. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 6661887d9f..658bf373b7 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,6 +17,7 @@ from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore +from synapse.util.async_helpers import concurrently_execute async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: @@ -25,13 +26,19 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge = len(invites) - for room_id in joins: - notifs = await ( - store.get_unread_event_push_actions_by_room_for_user( + room_notifs = [] + + async def get_room_unread_count(room_id: str) -> None: + room_notifs.append( + await store.get_unread_event_push_actions_by_room_for_user( room_id, user_id, ) ) + + await concurrently_execute(get_room_unread_count, joins, 10) + + for notifs in room_notifs: if notifs.notify_count == 0: continue -- cgit 1.5.1 From ebfeac7c5ded851a2639911ec6adf9d0fcdb029a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 12 Sep 2022 11:03:42 +0100 Subject: Check if Rust lib needs rebuilding. (#13759) This protects against the common mistake of failing to remember to rebuild Rust code after making changes. --- changelog.d/13759.misc | 1 + rust/Cargo.toml | 4 ++ rust/build.rs | 45 ++++++++++++++++++++++ rust/src/lib.rs | 10 ++++- stubs/synapse/synapse_rust.pyi | 1 + synapse/__init__.py | 5 +++ synapse/util/rust.py | 84 ++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13759.misc create mode 100644 rust/build.rs create mode 100644 synapse/util/rust.py (limited to 'synapse') diff --git a/changelog.d/13759.misc b/changelog.d/13759.misc new file mode 100644 index 0000000000..f91c512483 --- /dev/null +++ b/changelog.d/13759.misc @@ -0,0 +1 @@ +Add a check for editable installs if the Rust library needs rebuilding. diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 0a9760cafc..deddf3cec2 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -19,3 +19,7 @@ name = "synapse.synapse_rust" [dependencies] pyo3 = { version = "0.16.5", features = ["extension-module", "macros", "abi3", "abi3-py37"] } + +[build-dependencies] +blake2 = "0.10.4" +hex = "0.4.3" diff --git a/rust/build.rs b/rust/build.rs new file mode 100644 index 0000000000..2117975e56 --- /dev/null +++ b/rust/build.rs @@ -0,0 +1,45 @@ +//! This build script calculates the hash of all files in the `src/` +//! directory and adds it as an environment variable during build time. +//! +//! This is used so that the python code can detect when the built native module +//! does not match the source in-tree, helping to detect the case where the +//! source has been updated but the library hasn't been rebuilt. + +use std::path::PathBuf; + +use blake2::{Blake2b512, Digest}; + +fn main() -> Result<(), std::io::Error> { + let mut dirs = vec![PathBuf::from("src")]; + + let mut paths = Vec::new(); + while let Some(path) = dirs.pop() { + let mut entries = std::fs::read_dir(path)? + .map(|res| res.map(|e| e.path())) + .collect::, std::io::Error>>()?; + + entries.sort(); + + for entry in entries { + if entry.is_dir() { + dirs.push(entry) + } else { + paths.push(entry.to_str().expect("valid rust paths").to_string()); + } + } + } + + paths.sort(); + + let mut hasher = Blake2b512::new(); + + for path in paths { + let bytes = std::fs::read(path)?; + hasher.update(bytes); + } + + let hex_digest = hex::encode(hasher.finalize()); + println!("cargo:rustc-env=SYNAPSE_RUST_DIGEST={hex_digest}"); + + Ok(()) +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 142fc2ed93..ba42465fb8 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,5 +1,13 @@ use pyo3::prelude::*; +/// Returns the hash of all the rust source files at the time it was compiled. +/// +/// Used by python to detect if the rust library is outdated. +#[pyfunction] +fn get_rust_file_digest() -> &'static str { + env!("SYNAPSE_RUST_DIGEST") +} + /// Formats the sum of two numbers as string. #[pyfunction] #[pyo3(text_signature = "(a, b, /)")] @@ -11,6 +19,6 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { #[pymodule] fn synapse_rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; - + m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; Ok(()) } diff --git a/stubs/synapse/synapse_rust.pyi b/stubs/synapse/synapse_rust.pyi index 5b51ba05d7..8658d3138f 100644 --- a/stubs/synapse/synapse_rust.pyi +++ b/stubs/synapse/synapse_rust.pyi @@ -1 +1,2 @@ def sum_as_string(a: int, b: int) -> str: ... +def get_rust_file_digest() -> str: ... diff --git a/synapse/__init__.py b/synapse/__init__.py index b1369aca8f..1bed6393bd 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -20,6 +20,8 @@ import json import os import sys +from synapse.util.rust import check_rust_lib_up_to_date + # Check that we're not running on an unsupported Python version. if sys.version_info < (3, 7): print("Synapse requires Python 3.7 or above.") @@ -78,3 +80,6 @@ if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): from synapse.util.patch_inline_callbacks import do_patch do_patch() + + +check_rust_lib_up_to_date() diff --git a/synapse/util/rust.py b/synapse/util/rust.py new file mode 100644 index 0000000000..30ecb9ffd9 --- /dev/null +++ b/synapse/util/rust.py @@ -0,0 +1,84 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from hashlib import blake2b + +import synapse +from synapse.synapse_rust import get_rust_file_digest + + +def check_rust_lib_up_to_date() -> None: + """For editable installs check if the rust library is outdated and needs to + be rebuilt. + """ + + if not _dist_is_editable(): + return + + synapse_dir = os.path.dirname(synapse.__file__) + synapse_root = os.path.abspath(os.path.join(synapse_dir, "..")) + + # Double check we've not gone into site-packages... + if os.path.basename(synapse_root) == "site-packages": + return + + # ... and it looks like the root of a python project. + if not os.path.exists("pyproject.toml"): + return + + # Get the hash of all Rust source files + hash = _hash_rust_files_in_directory(os.path.join(synapse_root, "rust", "src")) + + if hash != get_rust_file_digest(): + raise Exception("Rust module outdated. Please rebuild using `poetry install`") + + +def _hash_rust_files_in_directory(directory: str) -> str: + """Get the hash of all files in a directory (recursively)""" + + directory = os.path.abspath(directory) + + paths = [] + + dirs = [directory] + while dirs: + dir = dirs.pop() + with os.scandir(dir) as d: + for entry in d: + if entry.is_dir(): + dirs.append(entry.path) + else: + paths.append(entry.path) + + # We sort to make sure that we get a consistent and well-defined ordering. + paths.sort() + + hasher = blake2b() + + for path in paths: + with open(os.path.join(directory, path), "rb") as f: + hasher.update(f.read()) + + return hasher.hexdigest() + + +def _dist_is_editable() -> bool: + """Is distribution an editable install?""" + for path_item in sys.path: + egg_link = os.path.join(path_item, "matrix-synapse.egg-link") + if os.path.isfile(egg_link): + return True + return False -- cgit 1.5.1 From da41a7cd618d11b05c2c04c39068fd4b1e1b7894 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 12 Sep 2022 12:58:33 +0100 Subject: Remove check current state membership up to date (#13745) * Remove checks for membership column in current_state_events * Add schema script to force through the `current_state_events_membership` background job Contributed by Nick @ Beeper (@fizzadar). --- changelog.d/13745.misc | 1 + synapse/storage/databases/main/roommember.py | 202 +++++---------------- ...force_update_current_state_events_membership.py | 52 ++++++ 3 files changed, 100 insertions(+), 155 deletions(-) create mode 100644 changelog.d/13745.misc create mode 100644 synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py (limited to 'synapse') diff --git a/changelog.d/13745.misc b/changelog.d/13745.misc new file mode 100644 index 0000000000..e97a789c0e --- /dev/null +++ b/changelog.d/13745.misc @@ -0,0 +1 @@ +Remove old queries to join room memberships to current state events. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 6e1ff5626b..fdb4684e12 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -32,10 +32,7 @@ import attr from synapse.api.constants import EventTypes, Membership from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import ( - run_as_background_process, - wrap_as_background_process, -) +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -91,16 +88,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): # at a time. Keyed by room_id. self._joined_host_linearizer = Linearizer("_JoinedHostsCache") - # Is the current_state_events.membership up to date? Or is the - # background update still running? - self._current_state_events_membership_up_to_date = False - - txn = db_conn.cursor( - txn_name="_check_safe_current_state_events_membership_updated" - ) - self._check_safe_current_state_events_membership_updated_txn(txn) - txn.close() - if ( self.hs.config.worker.run_background_tasks and self.hs.config.metrics.metrics_flags.known_servers @@ -157,34 +144,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._known_servers_count = max([count, 1]) return self._known_servers_count - def _check_safe_current_state_events_membership_updated_txn( - self, txn: LoggingTransaction - ) -> None: - """Checks if it is safe to assume the new current_state_events - membership column is up to date - """ - - pending_update = self.db_pool.simple_select_one_txn( - txn, - table="background_updates", - keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, - retcols=["update_name"], - allow_none=True, - ) - - self._current_state_events_membership_up_to_date = not pending_update - - # If the update is still running, reschedule to run. - if pending_update: - self._clock.call_later( - 15.0, - run_as_background_process, - "_check_safe_current_state_events_membership_updated", - self.db_pool.runInteraction, - "_check_safe_current_state_events_membership_updated", - self._check_safe_current_state_events_membership_updated_txn, - ) - @cached(max_entries=100000, iterable=True) async def get_users_in_room(self, room_id: str) -> List[str]: """ @@ -212,31 +171,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): `get_current_hosts_in_room()` and so we can re-use the cache but it's not horrible to have here either. """ - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT c.state_key FROM current_state_events as c - /* Get the depth of the event from the events table */ - INNER JOIN events AS e USING (event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ? - /* Sorted by lowest depth first */ - ORDER BY e.depth ASC; - """ - else: - sql = """ - SELECT c.state_key FROM room_memberships as m - /* Get the depth of the event from the events table */ - INNER JOIN events AS e USING (event_id) - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? - /* Sorted by lowest depth first */ - ORDER BY e.depth ASC; - """ + sql = """ + SELECT c.state_key FROM current_state_events as c + /* Get the depth of the event from the events table */ + INNER JOIN events AS e USING (event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ? + /* Sorted by lowest depth first */ + ORDER BY e.depth ASC; + """ txn.execute(sql, (room_id, Membership.JOIN)) return [r[0] for r in txn] @@ -353,28 +295,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We do this all in one transaction to keep the cache small. # FIXME: get rid of this when we have room_stats - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT count(*), membership FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - GROUP BY membership - """ - else: - sql = """ - SELECT count(*), m.membership FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT count(*), membership FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + GROUP BY membership + """ txn.execute(sql, (room_id,)) res: Dict[str, MemberSummary] = {} @@ -383,30 +311,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT state_key, membership, event_id - FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - ORDER BY - CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - event_id ASC - LIMIT ? - """ - else: - sql = """ - SELECT c.state_key, m.membership, c.event_id - FROM room_memberships as m - INNER JOIN current_state_events as c USING (room_id, event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - ORDER BY - CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - c.event_id ASC - LIMIT ? - """ + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT state_key, membership, event_id + FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + ORDER BY + CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + event_id ASC + LIMIT ? + """ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) @@ -649,27 +565,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND c.state_key = ? - AND c.membership = ? - """ - else: - sql = """ - SELECT room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND c.state_key = ? - AND m.membership = ? - """ + sql = """ + SELECT room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND c.state_key = ? + AND c.membership = ? + """ txn.execute(sql, (user_id, Membership.JOIN)) return frozenset( @@ -707,27 +611,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_ids, ) - if self._current_state_events_membership_up_to_date: - sql = f""" - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND c.membership = ? - AND {clause} - """ - else: - sql = f""" - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND m.membership = ? - AND {clause} - """ + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND c.membership = ? + AND {clause} + """ txn.execute(sql, [Membership.JOIN] + args) diff --git a/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py new file mode 100644 index 0000000000..b5853d125c --- /dev/null +++ b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py @@ -0,0 +1,52 @@ +# Copyright 2022 Beeper +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Forces through the `current_state_events_membership` background job so checks +for its completion can be removed. + +Note the background job must still remain defined in the database class. +""" + + +def run_upgrade(cur, database_engine, *args, **kwargs): + cur.execute("SELECT update_name FROM background_updates") + rows = cur.fetchall() + for row in rows: + if row[0] == "current_state_events_membership": + break + # No pending background job so nothing to do here + else: + return + + # Populate membership field for all current_state_events, this may take + # a while but was originally handled via a background update in 2019. + cur.execute( + """ + UPDATE current_state_events + SET membership = ( + SELECT membership FROM room_memberships + WHERE event_id = current_state_events.event_id + ) + """ + ) + + # Finally, delete the background job because we've handled it above + cur.execute( + """ + DELETE FROM background_updates + WHERE update_name = 'current_state_events_membership' + """ + ) -- cgit 1.5.1 From cdbb6412327b542e0dead792717fe58253291131 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 13 Sep 2022 08:16:37 +0100 Subject: Add receipts event stream ordering (#13703) --- changelog.d/13703.misc | 1 + synapse/_scripts/synapse_port_db.py | 2 + synapse/storage/databases/main/receipts.py | 74 +++++++++++++++++++++- .../delta/72/05receipts_event_stream_ordering.sql | 19 ++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13703.misc create mode 100644 synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql (limited to 'synapse') diff --git a/changelog.d/13703.misc b/changelog.d/13703.misc new file mode 100644 index 0000000000..685a29b17d --- /dev/null +++ b/changelog.d/13703.misc @@ -0,0 +1 @@ +Add & populate `event_stream_ordering` column on receipts table for future optimisation of push action processing. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 543bba27c2..30983c47fb 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -67,6 +67,7 @@ from synapse.storage.databases.main.media_repository import ( ) from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, @@ -203,6 +204,7 @@ class Store( PushRuleStore, PusherWorkerStore, PresenceBackgroundUpdateStore, + ReceiptsBackgroundUpdateStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 3838409519..719a12b0ae 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -675,6 +675,7 @@ class ReceiptsWorkerStore(SQLBaseStore): values={ "stream_id": stream_id, "event_id": event_id, + "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), }, # receipts_linearized has a unique constraint on @@ -830,5 +831,76 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -class ReceiptsStore(ReceiptsWorkerStore): +class ReceiptsBackgroundUpdateStore(SQLBaseStore): + POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering" + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, + self._populate_receipt_event_stream_ordering, + ) + + async def _populate_receipt_event_stream_ordering( + self, progress: JsonDict, batch_size: int + ) -> int: + def _populate_receipt_event_stream_ordering_txn( + txn: LoggingTransaction, + ) -> bool: + + if "max_stream_id" in progress: + max_stream_id = progress["max_stream_id"] + else: + txn.execute("SELECT max(stream_id) FROM receipts_linearized") + res = txn.fetchone() + if res is None or res[0] is None: + return True + else: + max_stream_id = res[0] + + start = progress.get("stream_id", 0) + stop = start + batch_size + + sql = """ + UPDATE receipts_linearized + SET event_stream_ordering = ( + SELECT stream_ordering + FROM events + WHERE event_id = receipts_linearized.event_id + ) + WHERE stream_id >= ? AND stream_id < ? + """ + txn.execute(sql, (start, stop)) + + self.db_pool.updates._background_update_progress_txn( + txn, + self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, + { + "stream_id": stop, + "max_stream_id": max_stream_id, + }, + ) + + return stop > max_stream_id + + finished = await self.db_pool.runInteraction( + "_remove_devices_from_device_inbox_txn", + _populate_receipt_event_stream_ordering_txn, + ) + + if finished: + await self.db_pool.updates._end_background_update( + self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING + ) + + return batch_size + + +class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore): pass diff --git a/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql b/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql new file mode 100644 index 0000000000..2a822f4509 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql @@ -0,0 +1,19 @@ +/* Copyright 2022 Beeper + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE receipts_linearized ADD COLUMN event_stream_ordering BIGINT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_event_stream_ordering', '{}'); -- cgit 1.5.1 From b60d47ab2c55580fc1941497964cd33c27838231 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 13 Sep 2022 10:53:11 +0100 Subject: Updates to the schema dump script (#13770) --- changelog.d/13770.misc | 1 + scripts-dev/make_full_schema.sh | 48 ++++++++-------------- .../storage/schema/state/delta/30/state_stream.sql | 4 ++ 3 files changed, 21 insertions(+), 32 deletions(-) create mode 100644 changelog.d/13770.misc (limited to 'synapse') diff --git a/changelog.d/13770.misc b/changelog.d/13770.misc new file mode 100644 index 0000000000..36ac91400a --- /dev/null +++ b/changelog.d/13770.misc @@ -0,0 +1 @@ +Update the script which makes full schema dumps. diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh index f0e22d4ca2..61394360ce 100755 --- a/scripts-dev/make_full_schema.sh +++ b/scripts-dev/make_full_schema.sh @@ -9,8 +9,10 @@ export PGHOST="localhost" POSTGRES_DB_NAME="synapse_full_schema.$$" -SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite" -POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres" +SQLITE_SCHEMA_FILE="schema.sql.sqlite" +SQLITE_ROWS_FILE="rows.sql.sqlite" +POSTGRES_SCHEMA_FILE="full.sql.postgres" +POSTGRES_ROWS_FILE="rows.sql.postgres" REQUIRED_DEPS=("matrix-synapse" "psycopg2") @@ -22,7 +24,7 @@ usage() { echo " Username to connect to local postgres instance. The password will be requested" echo " during script execution." echo "-c" - echo " CI mode. Enables coverage tracking and prints every command that the script runs." + echo " CI mode. Prints every command that the script runs." echo "-o " echo " Directory to output full schema files to." echo "-h" @@ -37,11 +39,6 @@ while getopts "p:co:h" opt; do c) # Print all commands that are being executed set -x - - # Modify required dependencies for coverage - REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess") - - COVERAGE=1 ;; o) command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1) @@ -102,6 +99,7 @@ SQLITE_DB=$TMPDIR/homeserver.db POSTGRES_CONFIG=$TMPDIR/postgres.conf # Ensure these files are delete on script exit +# TODO: the trap should also drop the temp postgres DB trap 'rm -rf $TMPDIR' EXIT cat > "$SQLITE_CONFIG" < "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE" +echo "Dumping SQLite3 schema to '$OUTPUT_DIR/$SQLITE_SCHEMA_FILE' and '$OUTPUT_DIR/$SQLITE_ROWS_FILE'..." +sqlite3 "$SQLITE_DB" ".schema --indent" > "$OUTPUT_DIR/$SQLITE_SCHEMA_FILE" +sqlite3 "$SQLITE_DB" ".dump --data-only --nosys" > "$OUTPUT_DIR/$SQLITE_ROWS_FILE" -echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..." -pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE" +echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_SCHEMA_FILE' and '$OUTPUT_DIR/$POSTGRES_ROWS_FILE'..." +pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_DB_NAME" | sed -e '/^$/d' -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_SCHEMA_FILE" +pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_DB_NAME" | sed -e '/^$/d' -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_ROWS_FILE" echo "Cleaning up temporary Postgres database..." dropdb $POSTGRES_DB_NAME diff --git a/synapse/storage/schema/state/delta/30/state_stream.sql b/synapse/storage/schema/state/delta/30/state_stream.sql index e85699e82e..bdaf8b02d5 100644 --- a/synapse/storage/schema/state/delta/30/state_stream.sql +++ b/synapse/storage/schema/state/delta/30/state_stream.sql @@ -26,6 +26,10 @@ * (event, state) pair, we can use that stream_ordering to identify when * the new state was assigned for the event. */ + +/* NB: This table belongs to the `main` logical database; it should not be present + * in `state`. + */ CREATE TABLE IF NOT EXISTS ex_outlier_stream( event_stream_ordering BIGINT PRIMARY KEY NOT NULL, event_id TEXT NOT NULL, -- cgit 1.5.1 From 12dacecabd27680dc77c17724953ecda0801b5ea Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Tue, 13 Sep 2022 16:14:28 +0200 Subject: Make sequence `cache_invalidation_stream_seq` begin at `2` (#13766) Signed-off-by: Mathieu Velten Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/13766.bugfix | 1 + synapse/storage/schema/__init__.py | 1 + ...8begin_cache_invalidation_seq_at_2.sql.postgres | 23 ++++++++++++++++++++++ 3 files changed, 25 insertions(+) create mode 100644 changelog.d/13766.bugfix create mode 100644 synapse/storage/schema/main/delta/72/08begin_cache_invalidation_seq_at_2.sql.postgres (limited to 'synapse') diff --git a/changelog.d/13766.bugfix b/changelog.d/13766.bugfix new file mode 100644 index 0000000000..c708e54f9c --- /dev/null +++ b/changelog.d/13766.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the `cache_invalidation_stream_seq` sequence would begin at 1 instead of 2. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 256f745dc0..32cda5e3ba 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -76,6 +76,7 @@ Changes in SCHEMA_VERSION = 72: - event_edges.(room_id, is_state) are no longer written to. - Tables related to groups are dropped. - Unused column application_services_state.last_txn is dropped + - Cache invalidation stream id sequence now begins at 2 to match code expectation. """ diff --git a/synapse/storage/schema/main/delta/72/08begin_cache_invalidation_seq_at_2.sql.postgres b/synapse/storage/schema/main/delta/72/08begin_cache_invalidation_seq_at_2.sql.postgres new file mode 100644 index 0000000000..69931fe971 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/08begin_cache_invalidation_seq_at_2.sql.postgres @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The sequence needs to begin at 2 because a bunch of code assumes that +-- get_next_id_txn will return values >= 2, cf this comment: +-- https://github.com/matrix-org/synapse/blob/b93bd95e8ab64d27ae26841020f62ee61272a5f2/synapse/storage/util/id_generators.py#L344 + +SELECT setval('cache_invalidation_stream_seq', ( + SELECT COALESCE(MAX(last_value), 1) FROM cache_invalidation_stream_seq +)); -- cgit 1.5.1 From 21687ec189f404bcee98ae61b008afc8c5094400 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 14 Sep 2022 08:28:12 +0000 Subject: Fix a long-standing spec compliance bug where Synapse would accept a trailing slash on the end of `/get_missing_events` federation requests. (#13789) * Don't accept a trailing slash on the end of /get_missing_events * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/13789.bugfix | 1 + synapse/federation/transport/server/federation.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13789.bugfix (limited to 'synapse') diff --git a/changelog.d/13789.bugfix b/changelog.d/13789.bugfix new file mode 100644 index 0000000000..9e1e3e0fa7 --- /dev/null +++ b/changelog.d/13789.bugfix @@ -0,0 +1 @@ +Fix a long-standing spec compliance bug where Synapse would accept a trailing slash on the end of `/get_missing_events` federation requests. \ No newline at end of file diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index f7884bfbe0..6bb4659c4c 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -549,8 +549,7 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet): class FederationGetMissingEventsServlet(BaseFederationServerServlet): - # TODO(paul): Why does this path alone end with "/?" optional? - PATH = "/get_missing_events/(?P[^/]*)/?" + PATH = "/get_missing_events/(?P[^/]*)" async def on_POST( self, -- cgit 1.5.1 From c73774467edb04c372caecb9e843542654f7610b Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 14 Sep 2022 10:42:57 +0100 Subject: Fix bug in device list caching when remote users leave rooms (#13749) When a remote user leaves the last room shared with the homeserver, we have to mark their device list as unsubscribed, otherwise we would hold on to a stale device list in our cache. Crucially, the device list would remain cached even after the remote user rejoined the room, which could lead to E2EE failures until the next change to the remote user's device list. Fixes #13651. Signed-off-by: Sean Quah --- changelog.d/13749.bugfix | 1 + synapse/handlers/device.py | 11 ----------- synapse/handlers/e2e_keys.py | 26 ++++++++++++++++++++++++++ synapse/storage/controllers/persist_events.py | 20 +++++++++++++++++--- tests/handlers/test_e2e_keys.py | 8 +++++++- 5 files changed, 51 insertions(+), 15 deletions(-) create mode 100644 changelog.d/13749.bugfix (limited to 'synapse') diff --git a/changelog.d/13749.bugfix b/changelog.d/13749.bugfix new file mode 100644 index 0000000000..8ffafec07b --- /dev/null +++ b/changelog.d/13749.bugfix @@ -0,0 +1 @@ +Fix a long standing bug where device lists would remain cached when remote users left and rejoined the last room shared with the local homeserver. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c5ac169644..901e2310b7 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -45,7 +45,6 @@ from synapse.types import ( JsonDict, StreamKeyType, StreamToken, - UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -324,8 +323,6 @@ class DeviceHandler(DeviceWorkerHandler): self.device_list_updater.incoming_device_list_update, ) - hs.get_distributor().observe("user_left_room", self.user_left_room) - # Whether `_handle_new_device_update_async` is currently processing. self._handle_new_device_update_is_processing = False @@ -569,14 +566,6 @@ class DeviceHandler(DeviceWorkerHandler): StreamKeyType.DEVICE_LIST, position, users=[from_user_id] ) - async def user_left_room(self, user: UserID, room_id: str) -> None: - user_id = user.to_string() - room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: - # We no longer share rooms with this user, so we'll no longer - # receive device updates. Mark this in DB. - await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) - async def store_dehydrated_device( self, user_id: str, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index ec81639c78..8eed63ccf3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -175,6 +175,32 @@ class E2eKeysHandler: user_ids_not_in_cache, remote_results, ) = await self.store.get_user_devices_from_cache(query_list) + + # Check that the homeserver still shares a room with all cached users. + # Note that this check may be slightly racy when a remote user leaves a + # room after we have fetched their cached device list. In the worst case + # we will do extra federation queries for devices that we had cached. + cached_users = set(remote_results.keys()) + valid_cached_users = ( + await self.store.get_users_server_still_shares_room_with( + remote_results.keys() + ) + ) + invalid_cached_users = cached_users - valid_cached_users + if invalid_cached_users: + # Fix up results. If we get here, there is either a bug in device + # list tracking, or we hit the race mentioned above. + user_ids_not_in_cache.update(invalid_cached_users) + for invalid_user_id in invalid_cached_users: + remote_results.pop(invalid_user_id) + # This log message may be removed if it turns out it's almost + # entirely triggered by races. + logger.error( + "Devices for %s were cached, but the server no longer shares " + "any rooms with them. The cached device lists are stale.", + invalid_cached_users, + ) + for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) for device_id, device in devices.items(): diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index dad3731b9b..501dbbc990 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -598,9 +598,9 @@ class EventsPersistenceStorageController: # room state_delta_for_room: Dict[str, DeltaState] = {} - # Set of remote users which were in rooms the server has left. We - # should check if we still share any rooms and if not we mark their - # device lists as stale. + # Set of remote users which were in rooms the server has left or who may + # have left rooms the server is in. We should check if we still share any + # rooms and if not we mark their device lists as stale. potentially_left_users: Set[str] = set() if not backfilled: @@ -725,6 +725,20 @@ class EventsPersistenceStorageController: current_state = {} delta.no_longer_in_room = True + # Add all remote users that might have left rooms. + potentially_left_users.update( + user_id + for event_type, user_id in delta.to_delete + if event_type == EventTypes.Member + and not self.is_mine_id(user_id) + ) + potentially_left_users.update( + user_id + for event_type, user_id in delta.to_insert.keys() + if event_type == EventTypes.Member + and not self.is_mine_id(user_id) + ) + state_delta_for_room[room_id] = delta await self.persist_events_store._persist_events_and_state_updates( diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 1e6ad4b663..95698bc275 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -891,6 +891,12 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): new_callable=mock.MagicMock, return_value=make_awaitable(["some_room_id"]), ) + mock_get_users = mock.patch.object( + self.store, + "get_users_server_still_shares_room_with", + new_callable=mock.MagicMock, + return_value=make_awaitable({remote_user_id}), + ) mock_request = mock.patch.object( self.hs.get_federation_client(), "query_user_devices", @@ -898,7 +904,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): return_value=make_awaitable(response_body), ) - with mock_get_rooms, mock_request as mocked_federation_request: + with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request: # Make the first query and sanity check it succeeds. response_1 = self.get_success( e2e_handler.query_devices( -- cgit 1.5.1 From 51a77e990b7a59e460ab22a2788ab8c3506b9a2c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 14 Sep 2022 14:16:12 +0100 Subject: Remove incorrect migration file from `state` logical DB (#13788) * Remove incorrect migration file from `state` logical DB The table `ex_outlier_stream` is part of the `main` logical DB; it should not have been created in the `state` logical DB. We remove this migration now as a tidy-up. Note: we cannot `DROP TABLE IF EXISTS ex_outlier_stream` in a new migration, because some (most) instances of Synapse host both of these logical DBs on the same DB cluster. * Changelog --- changelog.d/13788.misc | 1 + .../storage/schema/state/delta/30/state_stream.sql | 37 ---------------------- 2 files changed, 1 insertion(+), 37 deletions(-) create mode 100644 changelog.d/13788.misc delete mode 100644 synapse/storage/schema/state/delta/30/state_stream.sql (limited to 'synapse') diff --git a/changelog.d/13788.misc b/changelog.d/13788.misc new file mode 100644 index 0000000000..7263b1ac52 --- /dev/null +++ b/changelog.d/13788.misc @@ -0,0 +1 @@ +Remove an old, incorrect migration file. diff --git a/synapse/storage/schema/state/delta/30/state_stream.sql b/synapse/storage/schema/state/delta/30/state_stream.sql deleted file mode 100644 index bdaf8b02d5..0000000000 --- a/synapse/storage/schema/state/delta/30/state_stream.sql +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/* We used to create a table called current_state_resets, but this is no - * longer used and is removed in delta 54. - */ - -/* The outlier events that have aquired a state group typically through - * backfill. This is tracked separately to the events table, as assigning a - * state group change the position of the existing event in the stream - * ordering. - * However since a stream_ordering is assigned in persist_event for the - * (event, state) pair, we can use that stream_ordering to identify when - * the new state was assigned for the event. - */ - -/* NB: This table belongs to the `main` logical database; it should not be present - * in `state`. - */ -CREATE TABLE IF NOT EXISTS ex_outlier_stream( - event_stream_ordering BIGINT PRIMARY KEY NOT NULL, - event_id TEXT NOT NULL, - state_group BIGINT NOT NULL -); -- cgit 1.5.1 From eaed4e6113f5ed40056fa02ae922cb273d02be6e Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 14 Sep 2022 16:33:54 +0200 Subject: Remove unused method in `synapse.api.auth.Auth`. (#13795) Clean-up from b19060a29b4f73897847db2aba5d03ec819086e0 (#13094) and 73af10f419346a5f2d70131ac1ed8e69942edca0 (#13093) which removed all callers. --- changelog.d/13795.misc | 1 + synapse/api/auth.py | 9 --------- 2 files changed, 1 insertion(+), 9 deletions(-) create mode 100644 changelog.d/13795.misc (limited to 'synapse') diff --git a/changelog.d/13795.misc b/changelog.d/13795.misc new file mode 100644 index 0000000000..20d90cc130 --- /dev/null +++ b/changelog.d/13795.misc @@ -0,0 +1 @@ +Remove unused method in `synapse.api.auth.Auth`. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 4a75eb6b21..3d7f986ac7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -459,15 +459,6 @@ class Auth: ) raise InvalidClientTokenError("Invalid access token passed.") - def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService: - token = self.get_access_token_from_request(request) - service = self.store.get_app_service_by_token(token) - if not service: - logger.warning("Unrecognised appservice access token.") - raise InvalidClientTokenError() - request.requester = create_requester(service.sender, app_service=service) - return service - async def is_server_admin(self, requester: Requester) -> bool: """Check if the given user is a local server admin. -- cgit 1.5.1 From cf65433de26ecce551c64e56d9ee8435c99defab Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 14 Sep 2022 15:29:05 +0000 Subject: Fix a memory leak when running the unit tests. (#13798) --- changelog.d/13798.misc | 1 + synapse/util/caches/__init__.py | 3 ++- synapse/util/metrics.py | 10 +++++----- 3 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13798.misc (limited to 'synapse') diff --git a/changelog.d/13798.misc b/changelog.d/13798.misc new file mode 100644 index 0000000000..e4ec2d77d6 --- /dev/null +++ b/changelog.d/13798.misc @@ -0,0 +1 @@ +Fix a memory leak when running the unit tests. \ No newline at end of file diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 35c0be08b0..f7c3a6794e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -205,8 +205,9 @@ def register_cache( add_resizable_cache(cache_name, resize_callback) metric = CacheMetric(cache, cache_type, cache_name, collect_callback) + metric_name = "cache_%s_%s" % (cache_type, cache_name) caches_by_name[cache_name] = cache - CACHE_METRIC_REGISTRY.register_hook(metric.collect) + CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect) return metric diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 9687120ebf..165480bdbe 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -15,7 +15,7 @@ import logging from functools import wraps from types import TracebackType -from typing import Awaitable, Callable, Generator, List, Optional, Type, TypeVar +from typing import Awaitable, Callable, Dict, Generator, Optional, Type, TypeVar from prometheus_client import CollectorRegistry, Counter, Metric from typing_extensions import Concatenate, ParamSpec, Protocol @@ -220,21 +220,21 @@ class DynamicCollectorRegistry(CollectorRegistry): def __init__(self) -> None: super().__init__() - self._pre_update_hooks: List[Callable[[], None]] = [] + self._pre_update_hooks: Dict[str, Callable[[], None]] = {} def collect(self) -> Generator[Metric, None, None]: """ Collects metrics, calling pre-update hooks first. """ - for pre_update_hook in self._pre_update_hooks: + for pre_update_hook in self._pre_update_hooks.values(): pre_update_hook() yield from super().collect() - def register_hook(self, hook: Callable[[], None]) -> None: + def register_hook(self, metric_name: str, hook: Callable[[], None]) -> None: """ Registers a hook that is called before metric collection. """ - self._pre_update_hooks.append(hook) + self._pre_update_hooks[metric_name] = hook -- cgit 1.5.1 From 6302753012927b63feddc71dd287e2d3554707d4 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 14 Sep 2022 15:53:18 +0000 Subject: Deduplicate `is_server_notices_room`. (#13780) --- changelog.d/13780.misc | 1 + synapse/handlers/message.py | 10 +--------- synapse/handlers/room_member.py | 10 +--------- synapse/storage/databases/main/roommember.py | 17 +++++++++++++++++ 4 files changed, 20 insertions(+), 18 deletions(-) create mode 100644 changelog.d/13780.misc (limited to 'synapse') diff --git a/changelog.d/13780.misc b/changelog.d/13780.misc new file mode 100644 index 0000000000..1bcac51cad --- /dev/null +++ b/changelog.d/13780.misc @@ -0,0 +1 @@ +Deduplicate `is_server_notices_room`. \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72157d5a36..e07cda133a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -752,20 +752,12 @@ class EventCreationHandler: if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - return await self._is_server_notices_room(builder.room_id) + return await self.store.is_server_notice_room(builder.room_id) elif membership == Membership.LEAVE: # the user is always allowed to leave (but not kick people) return builder.state_key == requester.user.to_string() return False - async def _is_server_notices_room(self, room_id: str) -> bool: - if self.config.servernotices.server_notices_mxid is None: - return False - is_server_notices_room = await self.store.check_local_user_in_room( - user_id=self.config.servernotices.server_notices_mxid, room_id=room_id - ) - return is_server_notices_room - async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5d4adf5bfd..8d01f4bf2b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -837,7 +837,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): old_membership == Membership.INVITE and effective_membership_state == Membership.LEAVE ): - is_blocked = await self._is_server_notice_room(room_id) + is_blocked = await self.store.is_server_notice_room(room_id) if is_blocked: raise SynapseError( HTTPStatus.FORBIDDEN, @@ -1617,14 +1617,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return False - async def _is_server_notice_room(self, room_id: str) -> bool: - if self._server_notices_mxid is None: - return False - is_server_notices_room = await self.store.check_local_user_in_room( - user_id=self._server_notices_mxid, room_id=room_id - ) - return is_server_notices_room - class RoomMemberMasterHandler(RoomMemberHandler): def __init__(self, hs: "HomeServer"): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index fdb4684e12..a8d224602a 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -88,6 +88,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): # at a time. Keyed by room_id. self._joined_host_linearizer = Linearizer("_JoinedHostsCache") + self._server_notices_mxid = hs.config.servernotices.server_notices_mxid + if ( self.hs.config.worker.run_background_tasks and self.hs.config.metrics.metrics_flags.known_servers @@ -504,6 +506,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): return membership == Membership.JOIN + async def is_server_notice_room(self, room_id: str) -> bool: + """ + Determines whether the given room is a 'Server Notices' room, used for + sending server notices to a user. + + This is determined by seeing whether the server notices user is present + in the room. + """ + if self._server_notices_mxid is None: + return False + is_server_notices_room = await self.check_local_user_in_room( + user_id=self._server_notices_mxid, room_id=room_id + ) + return is_server_notices_room + async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str ) -> Tuple[Optional[str], Optional[str]]: -- cgit 1.5.1 From f2d12ccabef17faa0bf6b34fbb6d944849afc4d4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Sep 2022 12:01:42 -0400 Subject: Use partial indices on SQLIte. (#13802) Partial indices have been supported since SQLite 3.8, but Synapse now requires >= 3.27, so we can enable support for them. This requires rebuilding previous indices which were partial on PostgreSQL, but not on SQLite. --- changelog.d/13802.misc | 1 + synapse/storage/background_updates.py | 6 +-- .../storage/databases/main/event_push_actions.py | 1 - .../main/delta/72/09partial_indices.sql.sqlite | 56 ++++++++++++++++++++++ 4 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 changelog.d/13802.misc create mode 100644 synapse/storage/schema/main/delta/72/09partial_indices.sql.sqlite (limited to 'synapse') diff --git a/changelog.d/13802.misc b/changelog.d/13802.misc new file mode 100644 index 0000000000..0d55071326 --- /dev/null +++ b/changelog.d/13802.misc @@ -0,0 +1 @@ +Use partial indices on SQLite. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 555b4e77d2..cf1eabc437 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -581,9 +581,6 @@ class BackgroundUpdater: def create_index_sqlite(conn: Connection) -> None: # Sqlite doesn't support concurrent creation of indexes. # - # We don't use partial indices on SQLite as it wasn't introduced - # until 3.8, and wheezy and CentOS 7 have 3.7 - # # We assume that sqlite doesn't give us invalid indices; however # we may still end up with the index existing but the # background_updates not having been recorded if synapse got shut @@ -591,12 +588,13 @@ class BackgroundUpdater: # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.) sql = ( "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s" - " (%(columns)s)" + " (%(columns)s) %(where_clause)s" ) % { "unique": "UNIQUE" if unique else "", "name": index_name, "table": table, "columns": ", ".join(columns), + "where_clause": "WHERE " + where_clause if where_clause else "", } c = conn.cursor() diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index f4a07de2a3..3a3fb8c507 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -1255,7 +1255,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): table="event_push_actions", columns=["highlight", "stream_ordering"], where_clause="highlight=0", - psql_only=True, ) async def get_push_actions_for_user( diff --git a/synapse/storage/schema/main/delta/72/09partial_indices.sql.sqlite b/synapse/storage/schema/main/delta/72/09partial_indices.sql.sqlite new file mode 100644 index 0000000000..c8dfdf0218 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/09partial_indices.sql.sqlite @@ -0,0 +1,56 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- SQLite needs to rebuild indices which use partial indices on Postgres, but +-- previously did not use them on SQLite. + +-- Drop each index that was added with register_background_index_update AND specified +-- a where_clause (that existed before this delta). + +-- From events_bg_updates.py +DROP INDEX IF EXISTS event_contains_url_index; +-- There is also a redactions_censored_redacts index, but that gets dropped. +DROP INDEX IF EXISTS redactions_have_censored_ts; +-- There is also a PostgreSQL only index (event_contains_url_index2) +-- which gets renamed to event_contains_url_index. + +-- From roommember.py +DROP INDEX IF EXISTS room_memberships_user_room_forgotten; + +-- From presence.py +DROP INDEX IF EXISTS presence_stream_state_not_offline_idx; + +-- From media_repository.py +DROP INDEX IF EXISTS local_media_repository_url_idx; + +-- From event_push_actions.py +DROP INDEX IF EXISTS event_push_actions_highlights_index; +-- There's also a event_push_actions_stream_highlight_index which was previously +-- PostgreSQL-only. + +-- From state.py +DROP INDEX IF EXISTS current_state_events_member_index; + +-- Re-insert the background jobs to re-create the indices. +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7209, 'event_contains_url_index', '{}', NULL), + (7209, 'redactions_have_censored_ts_idx', '{}', NULL), + (7209, 'room_membership_forgotten_idx', '{}', NULL), + (7209, 'presence_stream_not_offline_index', '{}', NULL), + (7209, 'local_media_repository_url_idx', '{}', NULL), + (7209, 'event_push_actions_highlights_index', '{}', NULL), + (7209, 'event_push_actions_stream_highlight_index', '{}', NULL), + (7209, 'current_state_members_idx', '{}', NULL) +ON CONFLICT (update_name) DO NOTHING; -- cgit 1.5.1 From 666ae877292d4747b9441105e3df8558f7a335c0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Sep 2022 13:11:16 -0400 Subject: Update event push action and receipt tables to support threads. (#13753) Adds a `thread_id` column to the `event_push_actions`, `event_push_actions_staging`, and `event_push_summary` tables. This will notifications to be segmented by the thread in a future pull request. The `thread_id` column stores the root event ID or the special value `"main"`. The `thread_id` column for `event_push_actions` and `event_push_summary` is backfilled with `"main"` for all existing rows. New entries into `event_push_actions` and `event_push_actions_staging` will get the proper thread ID. `receipts_linearized` and `receipts_graph` also gain a `thread_id` column, which is similar, except `NULL` is a special value meaning the receipt is "unthreaded". See MSC3771 and MSC3773 for where this data will be useful. --- changelog.d/13753.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 29 ++--- .../storage/databases/main/event_push_actions.py | 121 ++++++++++++++++++++- synapse/storage/databases/main/events.py | 4 +- synapse/storage/databases/main/receipts.py | 20 ++++ synapse/storage/schema/__init__.py | 6 +- .../main/delta/72/06thread_notifications.sql | 30 +++++ .../main/delta/72/07thread_receipts.sql.postgres | 30 +++++ .../main/delta/72/07thread_receipts.sql.sqlite | 70 ++++++++++++ .../schema/main/delta/72/08thread_receipts.sql | 20 ++++ tests/replication/slave/storage/test_events.py | 1 + 11 files changed, 312 insertions(+), 20 deletions(-) create mode 100644 changelog.d/13753.misc create mode 100644 synapse/storage/schema/main/delta/72/06thread_notifications.sql create mode 100644 synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres create mode 100644 synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite create mode 100644 synapse/storage/schema/main/delta/72/08thread_receipts.sql (limited to 'synapse') diff --git a/changelog.d/13753.misc b/changelog.d/13753.misc new file mode 100644 index 0000000000..63de2eb9f9 --- /dev/null +++ b/changelog.d/13753.misc @@ -0,0 +1 @@ +Prepatory work for storing thread IDs for notifications and receipts. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d1caf8a0f7..3846fbc5f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -198,7 +198,7 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level async def _get_mutual_relations( - self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]] + self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] ) -> Dict[str, Set[Tuple[str, str]]]: """ Fetch event metadata for events which related to the same event as the given event. @@ -206,7 +206,7 @@ class BulkPushRuleEvaluator: If the given event has no relation information, returns an empty dictionary. Args: - event_id: The event ID which is targeted by relations. + parent_id: The event ID which is targeted by relations. rules: The push rules which will be processed for this event. Returns: @@ -220,12 +220,6 @@ class BulkPushRuleEvaluator: if not self._relations_match_enabled: return {} - # If the event does not have a relation, then cannot have any mutual - # relations. - relation = relation_from_event(event) - if not relation: - return {} - # Pre-filter to figure out which relation types are interesting. rel_types = set() for rule, enabled in rules: @@ -246,9 +240,7 @@ class BulkPushRuleEvaluator: return {} # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations( - relation.parent_id, rel_types - ) + return await self.store.get_mutual_event_relations(parent_id, rel_types) @measure_func("action_for_event_by_user") async def action_for_event_by_user( @@ -281,9 +273,17 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) - relations = await self._get_mutual_relations( - event, itertools.chain(*rules_by_user.values()) - ) + relation = relation_from_event(event) + # If the event does not have a relation, then cannot have any mutual + # relations or thread ID. + relations = {} + thread_id = "main" + if relation: + relations = await self._get_mutual_relations( + relation.parent_id, itertools.chain(*rules_by_user.values()) + ) + if relation.rel_type == RelationTypes.THREAD: + thread_id = relation.parent_id evaluator = PushRuleEvaluatorForEvent( event, @@ -352,6 +352,7 @@ class BulkPushRuleEvaluator: event.event_id, actions_by_user, count_as_unread, + thread_id, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3a3fb8c507..6b8668d2dc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -98,6 +98,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -232,6 +233,104 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas replaces_index="event_push_summary_user_rm", ) + self.db_pool.updates.register_background_index_update( + "event_push_summary_unique_index2", + index_name="event_push_summary_unique_index2", + table="event_push_summary", + columns=["user_id", "room_id", "thread_id"], + unique=True, + ) + + self.db_pool.updates.register_background_update_handler( + "event_push_backfill_thread_id", + self._background_backfill_thread_id, + ) + + async def _background_backfill_thread_id( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Fill in the thread_id field for event_push_actions and event_push_summary. + + This is preparatory so that it can be made non-nullable in the future. + + Because all current (null) data is done in an unthreaded manner this + simply assumes it is on the "main" timeline. Since event_push_actions + are periodically cleared it is not possible to correctly re-calculate + the thread_id. + """ + event_push_actions_done = progress.get("event_push_actions_done", False) + + def add_thread_id_txn( + txn: LoggingTransaction, table_name: str, start_stream_ordering: int + ) -> int: + sql = f""" + SELECT stream_ordering + FROM {table_name} + WHERE + thread_id IS NULL + AND stream_ordering > ? + ORDER BY stream_ordering + LIMIT ? + """ + txn.execute(sql, (start_stream_ordering, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + progress[f"{table_name}_done"] = True + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + return 0 + + # Update the thread ID for any of those rows. + max_stream_ordering = rows[-1][0] + + sql = f""" + UPDATE {table_name} + SET thread_id = 'main' + WHERE stream_ordering <= ? AND thread_id IS NULL + """ + txn.execute(sql, (max_stream_ordering,)) + + # Update progress. + processed_rows = txn.rowcount + progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + + return processed_rows + + # First update the event_push_actions table, then the event_push_summary table. + # + # Note that the event_push_actions_staging table is ignored since it is + # assumed that items in that table will only exist for a short period of + # time. + if not event_push_actions_done: + result = await self.db_pool.runInteraction( + "event_push_backfill_thread_id", + add_thread_id_txn, + "event_push_actions", + progress.get("max_event_push_actions_stream_ordering", 0), + ) + else: + result = await self.db_pool.runInteraction( + "event_push_backfill_thread_id", + add_thread_id_txn, + "event_push_summary", + progress.get("max_event_push_summary_stream_ordering", 0), + ) + + # Only done after the event_push_summary table is done. + if not result: + await self.db_pool.updates._end_background_update( + "event_push_backfill_thread_id" + ) + + return result + @cached(tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( self, @@ -670,6 +769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas event_id: str, user_id_actions: Dict[str, Collection[Union[Mapping, str]]], count_as_unread: bool, + thread_id: str, ) -> None: """Add the push actions for the event to the push action staging area. @@ -678,6 +778,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id_actions: A mapping of user_id to list of push actions, where an action can either be a string or dict. count_as_unread: Whether this event should increment unread counts. + thread_id: The thread this event is parent of, if applicable. """ if not user_id_actions: return @@ -686,7 +787,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # can be used to insert into the `event_push_actions_staging` table. def _gen_entry( user_id: str, actions: Collection[Union[Mapping, str]] - ) -> Tuple[str, str, str, int, int, int]: + ) -> Tuple[str, str, str, int, int, int, str]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 return ( @@ -696,11 +797,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas notif, # notif column is_highlight, # highlight column int(count_as_unread), # unread column + thread_id, # thread_id column ) await self.db_pool.simple_insert_many( "event_push_actions_staging", - keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"), + keys=( + "event_id", + "user_id", + "actions", + "notif", + "highlight", + "unread", + "thread_id", + ), values=[ _gen_entry(user_id, actions) for user_id, actions in user_id_actions.items() @@ -981,6 +1091,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) # Replace the previous summary with the new counts. + # + # TODO(threads): Upsert per-thread instead of setting them all to main. self.db_pool.simple_upsert_txn( txn, table="event_push_summary", @@ -990,6 +1102,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "unread_count": unread_count, "stream_ordering": old_rotate_stream_ordering, "last_receipt_stream_ordering": stream_ordering, + "thread_id": "main", }, ) @@ -1138,17 +1251,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) + # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", key_names=("user_id", "room_id"), key_values=[(user_id, room_id) for user_id, room_id in summaries], - value_names=("notif_count", "unread_count", "stream_ordering"), + value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"), value_values=[ ( summary.notif_count, summary.unread_count, summary.stream_ordering, + "main", ) for summary in summaries.values() ], diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a4010ee28d..c0b4080e4b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2192,9 +2192,9 @@ class PersistEventsStore: sql = """ INSERT INTO event_push_actions ( room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight, unread + topological_ordering, notif, highlight, unread, thread_id ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id FROM event_push_actions_staging WHERE event_id = ? """ diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 719a12b0ae..ddb8e80b69 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -113,6 +113,24 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) + self.db_pool.updates.register_background_index_update( + "receipts_linearized_unique_index", + index_name="receipts_linearized_unique_index", + table="receipts_linearized", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, + ) + + self.db_pool.updates.register_background_index_update( + "receipts_graph_unique_index", + index_name="receipts_graph_unique_index", + table="receipts_graph", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, + ) + def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @@ -677,6 +695,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), + "thread_id": None, }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock @@ -824,6 +843,7 @@ class ReceiptsWorkerStore(SQLBaseStore): values={ "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), + "thread_id": None, }, # receipts_graph has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 32cda5e3ba..38c9532bfd 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 72 # remember to update the list below when updating +SCHEMA_VERSION = 73 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -77,6 +77,10 @@ Changes in SCHEMA_VERSION = 72: - Tables related to groups are dropped. - Unused column application_services_state.last_txn is dropped - Cache invalidation stream id sequence now begins at 2 to match code expectation. + +Changes in SCHEMA_VERSION = 73; + - thread_id column is added to event_push_actions, event_push_actions_staging + event_push_summary, receipts_linearized, and receipts_graph. """ diff --git a/synapse/storage/schema/main/delta/72/06thread_notifications.sql b/synapse/storage/schema/main/delta/72/06thread_notifications.sql new file mode 100644 index 0000000000..2f4f5dac7a --- /dev/null +++ b/synapse/storage/schema/main/delta/72/06thread_notifications.sql @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a nullable column for thread ID to the event push actions tables; this +-- will be filled in with a default value for any previously existing rows. +-- +-- After migration this can be made non-nullable. + +ALTER TABLE event_push_actions_staging ADD COLUMN thread_id TEXT; +ALTER TABLE event_push_actions ADD COLUMN thread_id TEXT; +ALTER TABLE event_push_summary ADD COLUMN thread_id TEXT; + +-- Update the unique index for `event_push_summary`. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7006, 'event_push_summary_unique_index2', '{}'); + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7006, 'event_push_backfill_thread_id', '{}', 'event_push_summary_unique_index2'); diff --git a/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres new file mode 100644 index 0000000000..55fff9e278 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a nullable column for thread ID to the receipts table; this allows a +-- receipt per user, per room, as well as an unthreaded receipt (corresponding +-- to a null thread ID). + +ALTER TABLE receipts_linearized ADD COLUMN thread_id TEXT; +ALTER TABLE receipts_graph ADD COLUMN thread_id TEXT; + +-- Rebuild the unique constraint with the thread_id. +ALTER TABLE receipts_linearized + ADD CONSTRAINT receipts_linearized_uniqueness_thread + UNIQUE (room_id, receipt_type, user_id, thread_id); + +ALTER TABLE receipts_graph + ADD CONSTRAINT receipts_graph_uniqueness_thread + UNIQUE (room_id, receipt_type, user_id, thread_id); diff --git a/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite new file mode 100644 index 0000000000..232f67deb4 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite @@ -0,0 +1,70 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Allow multiple receipts per user per room via a nullable thread_id column. +-- +-- SQLite doesn't support modifying constraints to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE receipts_linearized_new ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + thread_id TEXT, + event_stream_ordering BIGINT, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id), + CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +CREATE TABLE receipts_graph_new ( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + thread_id TEXT, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id), + CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +-- Drop the old indexes. +DROP INDEX IF EXISTS receipts_linearized_id; +DROP INDEX IF EXISTS receipts_linearized_room_stream; +DROP INDEX IF EXISTS receipts_linearized_user; + +-- Copy the data. +INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, event_stream_ordering, data) + SELECT stream_id, room_id, receipt_type, user_id, event_id, event_stream_ordering, data + FROM receipts_linearized; +INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data) + SELECT room_id, receipt_type, user_id, event_ids, data + FROM receipts_graph; + +-- Drop the old tables. +DROP TABLE receipts_linearized; +DROP TABLE receipts_graph; + +-- Rename the tables. +ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized; +ALTER TABLE receipts_graph_new RENAME TO receipts_graph; + +-- Create the indices. +CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); +CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); diff --git a/synapse/storage/schema/main/delta/72/08thread_receipts.sql b/synapse/storage/schema/main/delta/72/08thread_receipts.sql new file mode 100644 index 0000000000..e35b021f31 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/08thread_receipts.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7007, 'receipts_linearized_unique_index', '{}'); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7007, 'receipts_graph_unique_index', '{}'); diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 531a0db2d0..49a21e2e85 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -404,6 +404,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event.event_id, {user_id: actions for user_id, actions in push_actions}, False, + "main", ) ) return event, context -- cgit 1.5.1 From 957e3d74fc70f92bb9ed3c709f87752bf77a8c90 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 14 Sep 2022 13:57:50 -0500 Subject: Keep track when we try and fail to process a pulled event (#13589) We can follow-up this PR with: 1. Only try to backfill from an event if we haven't tried recently -> https://github.com/matrix-org/synapse/issues/13622 1. When we decide to backfill that event again, process it in the background so it doesn't block and make `/messages` slow when we know it will probably fail again -> https://github.com/matrix-org/synapse/issues/13623 1. Generally track failures everywhere we try and fail to pull an event over federation -> https://github.com/matrix-org/synapse/issues/13700 Fix https://github.com/matrix-org/synapse/issues/13621 Part of https://github.com/matrix-org/synapse/issues/13356 Mentioned in [internal doc](https://docs.google.com/document/d/1lvUoVfYUiy6UaHB6Rb4HicjaJAU40-APue9Q4vzuW3c/edit#bookmark=id.qv7cj51sv9i5) --- changelog.d/13589.feature | 1 + synapse/handlers/federation_event.py | 7 + synapse/storage/databases/main/event_federation.py | 45 +++++ synapse/storage/databases/main/events.py | 32 ++- synapse/storage/schema/__init__.py | 2 + .../main/delta/73/01event_failed_pull_attempts.sql | 29 +++ tests/handlers/test_federation_event.py | 222 +++++++++++++++++++++ 7 files changed, 329 insertions(+), 9 deletions(-) create mode 100644 changelog.d/13589.feature create mode 100644 synapse/storage/schema/main/delta/73/01event_failed_pull_attempts.sql (limited to 'synapse') diff --git a/changelog.d/13589.feature b/changelog.d/13589.feature new file mode 100644 index 0000000000..78fa1ddb52 --- /dev/null +++ b/changelog.d/13589.feature @@ -0,0 +1 @@ +Keep track when we attempt to backfill an event but fail so we can intelligently back-off in the future. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index ace7adcffb..9e065e1116 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -862,6 +862,9 @@ class FederationEventHandler: self._sanity_check_event(event) except SynapseError as err: logger.warning("Event %s failed sanity check: %s", event_id, err) + await self._store.record_event_failed_pull_attempt( + event.room_id, event_id, str(err) + ) return try: @@ -897,6 +900,10 @@ class FederationEventHandler: backfilled=backfilled, ) except FederationError as e: + await self._store.record_event_failed_pull_attempt( + event.room_id, event_id, str(e) + ) + if e.code == 403: logger.warning("Pulled event %s failed history check.", event_id) else: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ca47a22bf1..ef477978ed 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1294,6 +1294,51 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return event_id_results + @trace + async def record_event_failed_pull_attempt( + self, room_id: str, event_id: str, cause: str + ) -> None: + """ + Record when we fail to pull an event over federation. + + This information allows us to be more intelligent when we decide to + retry (we don't need to fail over and over) and we can process that + event in the background so we don't block on it each time. + + Args: + room_id: The room where the event failed to pull from + event_id: The event that failed to be fetched or processed + cause: The error message or reason that we failed to pull the event + """ + await self.db_pool.runInteraction( + "record_event_failed_pull_attempt", + self._record_event_failed_pull_attempt_upsert_txn, + room_id, + event_id, + cause, + db_autocommit=True, # Safe as it's a single upsert + ) + + def _record_event_failed_pull_attempt_upsert_txn( + self, + txn: LoggingTransaction, + room_id: str, + event_id: str, + cause: str, + ) -> None: + sql = """ + INSERT INTO event_failed_pull_attempts ( + room_id, event_id, num_attempts, last_attempt_ts, last_cause + ) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, event_id) DO UPDATE SET + num_attempts=event_failed_pull_attempts.num_attempts + 1, + last_attempt_ts=EXCLUDED.last_attempt_ts, + last_cause=EXCLUDED.last_cause; + """ + + txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + async def get_missing_events( self, room_id: str, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index c0b4080e4b..1b54a2eb57 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2435,17 +2435,31 @@ class PersistEventsStore: "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" ) + backward_extremity_tuples_to_remove = [ + (ev.event_id, ev.room_id) + for ev in events + if not ev.internal_metadata.is_outlier() + # If we encountered an event with no prev_events, then we might + # as well remove it now because it won't ever have anything else + # to backfill from. + or len(ev.prev_event_ids()) == 0 + ] txn.execute_batch( query, - [ - (ev.event_id, ev.room_id) - for ev in events - if not ev.internal_metadata.is_outlier() - # If we encountered an event with no prev_events, then we might - # as well remove it now because it won't ever have anything else - # to backfill from. - or len(ev.prev_event_ids()) == 0 - ], + backward_extremity_tuples_to_remove, + ) + + # Clear out the failed backfill attempts after we successfully pulled + # the event. Since we no longer need these events as backward + # extremities, it also means that they won't be backfilled from again so + # we no longer need to store the backfill attempts around it. + query = """ + DELETE FROM event_failed_pull_attempts + WHERE event_id = ? and room_id = ? + """ + txn.execute_batch( + query, + backward_extremity_tuples_to_remove, ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 38c9532bfd..68e055c664 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -81,6 +81,8 @@ Changes in SCHEMA_VERSION = 72: Changes in SCHEMA_VERSION = 73; - thread_id column is added to event_push_actions, event_push_actions_staging event_push_summary, receipts_linearized, and receipts_graph. + - Add table `event_failed_pull_attempts` to keep track when we fail to pull + events over federation. """ diff --git a/synapse/storage/schema/main/delta/73/01event_failed_pull_attempts.sql b/synapse/storage/schema/main/delta/73/01event_failed_pull_attempts.sql new file mode 100644 index 0000000000..d397ee1082 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/01event_failed_pull_attempts.sql @@ -0,0 +1,29 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Add a table that keeps track of when we failed to pull an event over +-- federation (via /backfill, `/event`, `/get_missing_events`, etc). This allows +-- us to be more intelligent when we decide to retry (we don't need to fail over +-- and over) and we can process that event in the background so we don't block +-- on it each time. +CREATE TABLE IF NOT EXISTS event_failed_pull_attempts( + room_id TEXT NOT NULL REFERENCES rooms (room_id), + event_id TEXT NOT NULL, + num_attempts INT NOT NULL, + last_attempt_ts BIGINT NOT NULL, + last_cause TEXT NOT NULL, + PRIMARY KEY (room_id, event_id) +); diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 51c8dd6498..b5b89405a4 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -227,3 +227,225 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): if prev_exists_as_outlier: self.mock_federation_transport_client.get_event.assert_not_called() + + def test_process_pulled_event_records_failed_backfill_attempts( + self, + ) -> None: + """ + Test to make sure that failed backfill attempts for an event are + recorded in the `event_failed_pull_attempts` table. + + In this test, we pretend we are processing a "pulled" event via + backfill. The pulled event has a fake `prev_event` which our server has + obviously never seen before so it attempts to request the state at that + `prev_event` which expectedly fails because it's a fake event. Because + the server can't fetch the state at the missing `prev_event`, the + "pulled" event fails the history check and is fails to process. + + We check that we correctly record the number of failed pull attempts + of the pulled event and as a sanity check, that the "pulled" event isn't + persisted. + """ + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # We expect an outbound request to /state_ids, so stub that out + self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable( + { + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + "pdu_ids": [], + "auth_chain_ids": [], + } + ) + # We also expect an outbound request to /state + self.mock_federation_transport_client.get_room_state.return_value = make_awaitable( + StateRequestResponse( + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + auth_events=[], + state=[], + ) + ) + + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [ + # The fake prev event will make the pulled event fail + # the history check (`Unable to get missing prev_event + # $fake_prev_event`) + "$fake_prev_event" + ], + "auth_events": [], + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled"}, + } + ), + room_version, + ) + + # The function under test: try to process the pulled event + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, pulled_event, backfilled=True + ) + ) + + # Make sure our failed pull attempt was recorded + backfill_num_attempts = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + ) + ) + self.assertEqual(backfill_num_attempts, 1) + + # The function under test: try to process the pulled event again + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, pulled_event, backfilled=True + ) + ) + + # Make sure our second failed pull attempt was recorded (`num_attempts` was incremented) + backfill_num_attempts = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + ) + ) + self.assertEqual(backfill_num_attempts, 2) + + # And as a sanity check, make sure the event was not persisted through all of this. + persisted = self.get_success( + main_store.get_event(pulled_event.event_id, allow_none=True) + ) + self.assertIsNone( + persisted, + "pulled event that fails the history check should not be persisted at all", + ) + + def test_process_pulled_event_clears_backfill_attempts_after_being_successfully_persisted( + self, + ) -> None: + """ + Test to make sure that failed pull attempts + (`event_failed_pull_attempts` table) for an event are cleared after the + event is successfully persisted. + + In this test, we pretend we are processing a "pulled" event via + backfill. The pulled event succesfully processes and the backward + extremeties are updated along with clearing out any failed pull attempts + for those old extremities. + + We check that we correctly cleared failed pull attempts of the + pulled event. + """ + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # allow the remote user to send state events + self.helper.send_state( + room_id, + "m.room.power_levels", + {"events_default": 0, "state_default": 0}, + tok=tok, + ) + + # add the remote user to the room + member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + + auth_event_ids = [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + member_event.event_id, + ] + + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [member_event.event_id], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled"}, + } + ), + room_version, + ) + + # Fake the "pulled" event failing to backfill once so we can test + # if it's cleared out later on. + self.get_success( + main_store.record_event_failed_pull_attempt( + pulled_event.room_id, pulled_event.event_id, "fake cause" + ) + ) + # Make sure we have a failed pull attempt recorded for the pulled event + backfill_num_attempts = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + ) + ) + self.assertEqual(backfill_num_attempts, 1) + + # The function under test: try to process the pulled event + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, pulled_event, backfilled=True + ) + ) + + # Make sure the failed pull attempts for the pulled event are cleared + backfill_num_attempts = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + allow_none=True, + ) + ) + self.assertIsNone(backfill_num_attempts) + + # And as a sanity check, make sure the "pulled" event was persisted. + persisted = self.get_success( + main_store.get_event(pulled_event.event_id, allow_none=True) + ) + self.assertIsNotNone(persisted, "pulled event was not persisted at all") -- cgit 1.5.1 From 918c74bfb57e3ca4d300ed9a3bfb99b99126f821 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 15 Sep 2022 13:57:16 +0100 Subject: Add a `MXCUri` class to make working with mxc uri's easier. (#13162) --- changelog.d/13162.misc | 1 + poetry.lock | 10 +-- pyproject.toml | 2 +- synapse/rest/media/v1/media_repository.py | 6 +- synapse/rest/media/v1/upload_resource.py | 6 +- tests/rest/media/test_media_retention.py | 102 +++++++++++------------------- 6 files changed, 53 insertions(+), 74 deletions(-) create mode 100644 changelog.d/13162.misc (limited to 'synapse') diff --git a/changelog.d/13162.misc b/changelog.d/13162.misc new file mode 100644 index 0000000000..b0d7c05e74 --- /dev/null +++ b/changelog.d/13162.misc @@ -0,0 +1 @@ +Bump the minimum dependency of `matrix_common` to 1.3.0 to make use of the `MXCUri` class. Use `MXCUri` to simplify media retention test code. \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index cdc69f8ea9..291f3c51e6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -524,11 +524,11 @@ python-versions = ">=3.7" [[package]] name = "matrix-common" -version = "1.2.1" +version = "1.3.0" description = "Common utilities for Synapse, Sydent and Sygnal" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] attrs = "*" @@ -1625,7 +1625,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "79cfa09d59f9f8b5ef24318fb860df1915f54328692aa56d04331ecbdd92a8cb" +content-hash = "1b14fc274d9e2a495a7f864150f3ffcf4d9f585e09a67e53301ae4ef3c2f3e48" [metadata.files] attrs = [ @@ -2113,8 +2113,8 @@ markupsafe = [ {file = "MarkupSafe-2.1.0.tar.gz", hash = "sha256:80beaf63ddfbc64a0452b841d8036ca0611e049650e20afcb882f5d3c266d65f"}, ] matrix-common = [ - {file = "matrix_common-1.2.1-py3-none-any.whl", hash = "sha256:946709c405944a0d4b1d73207b77eb064b6dbfc5d70a69471320b06d8ce98b20"}, - {file = "matrix_common-1.2.1.tar.gz", hash = "sha256:a99dcf02a6bd95b24a5a61b354888a2ac92bf2b4b839c727b8dd9da2cdfa3853"}, + {file = "matrix_common-1.3.0-py3-none-any.whl", hash = "sha256:524e2785b9b03be4d15f3a8a6b857c5b6af68791ffb1b9918f0ad299abc4db20"}, + {file = "matrix_common-1.3.0.tar.gz", hash = "sha256:62e121cccd9f243417b57ec37a76dc44aeb198a7a5c67afd6b8275992ff2abd1"}, ] matrix-synapse-ldap3 = [ {file = "matrix-synapse-ldap3-0.2.2.tar.gz", hash = "sha256:b388d95693486eef69adaefd0fd9e84463d52fe17b0214a00efcaa669b73cb74"}, diff --git a/pyproject.toml b/pyproject.toml index 157385ad8a..8e50dd2852 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,7 +164,7 @@ typing-extensions = ">=3.10.0.1" cryptography = ">=3.4.7" # ijson 3.1.4 fixes a bug with "." in property names ijson = ">=3.1.4" -matrix-common = "^1.2.1" +matrix-common = "^1.3.0" # We need packaging.requirements.Requirement, added in 16.1. packaging = ">=16.1" # At the time of writing, we only use functions from the version `importlib.metadata` diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 9dd3c8d4bb..328c0c5477 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -19,6 +19,8 @@ import shutil from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from matrix_common.types.mxc_uri import MXCUri + import twisted.internet.error import twisted.web.http from twisted.internet.defer import Deferred @@ -186,7 +188,7 @@ class MediaRepository: content: IO, content_length: int, auth_user: UserID, - ) -> str: + ) -> MXCUri: """Store uploaded content for a local user and return the mxc URL Args: @@ -219,7 +221,7 @@ class MediaRepository: await self._generate_thumbnails(None, media_id, media_id, media_type) - return "mxc://%s/%s" % (self.server_name, media_id) + return MXCUri(self.server_name, media_id) async def get_local_media( self, request: SynapseRequest, media_id: str, name: Optional[str] diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index e73e431dc9..97548b54e5 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -101,6 +101,8 @@ class UploadResource(DirectServeJsonResource): # the default 404, as that would just be confusing. raise SynapseError(400, "Bad content") - logger.info("Uploaded content with URI %r", content_uri) + logger.info("Uploaded content with URI '%s'", content_uri) - respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True) + respond_with_json( + request, 200, {"content_uri": str(content_uri)}, send_cors=True + ) diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index 14af07c5af..23f227aed6 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -13,7 +13,9 @@ # limitations under the License. import io -from typing import Iterable, Optional, Tuple +from typing import Iterable, Optional + +from matrix_common.types.mxc_uri import MXCUri from twisted.test.proto_helpers import MemoryReactor @@ -63,9 +65,9 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): last_accessed_ms: Optional[int], is_quarantined: Optional[bool] = False, is_protected: Optional[bool] = False, - ) -> str: + ) -> MXCUri: # "Upload" some media to the local media store - mxc_uri = self.get_success( + mxc_uri: MXCUri = self.get_success( media_repository.create_content( media_type="text/plain", upload_name=None, @@ -75,13 +77,11 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): ) ) - media_id = mxc_uri.split("/")[-1] - # Set the last recently accessed time for this media if last_accessed_ms is not None: self.get_success( self.store.update_cached_last_access_time( - local_media=(media_id,), + local_media=(mxc_uri.media_id,), remote_media=(), time_ms=last_accessed_ms, ) @@ -92,7 +92,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): self.get_success( self.store.quarantine_media_by_id( server_name=self.hs.config.server.server_name, - media_id=media_id, + media_id=mxc_uri.media_id, quarantined_by="@theadmin:test", ) ) @@ -101,18 +101,18 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): # Mark this media as protected from quarantine self.get_success( self.store.mark_local_media_as_safe( - media_id=media_id, + media_id=mxc_uri.media_id, safe=True, ) ) - return media_id + return mxc_uri def _cache_remote_media_and_set_attributes( media_id: str, last_accessed_ms: Optional[int], is_quarantined: Optional[bool] = False, - ) -> str: + ) -> MXCUri: # Pretend to cache some remote media self.get_success( self.store.store_cached_remote_media( @@ -146,7 +146,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): ) ) - return media_id + return MXCUri(self.remote_server_name, media_id) # Start with the local media store self.local_recently_accessed_media = _create_media_and_set_attributes( @@ -214,28 +214,16 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): # Remote media should be unaffected. self._assert_if_mxc_uris_purged( purged=[ - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_media, - ), - (self.hs.config.server.server_name, self.local_never_accessed_media), + self.local_not_recently_accessed_media, + self.local_never_accessed_media, ], not_purged=[ - (self.hs.config.server.server_name, self.local_recently_accessed_media), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_quarantined_media, - ), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_protected_media, - ), - (self.remote_server_name, self.remote_recently_accessed_media), - (self.remote_server_name, self.remote_not_recently_accessed_media), - ( - self.remote_server_name, - self.remote_not_recently_accessed_quarantined_media, - ), + self.local_recently_accessed_media, + self.local_not_recently_accessed_quarantined_media, + self.local_not_recently_accessed_protected_media, + self.remote_recently_accessed_media, + self.remote_not_recently_accessed_media, + self.remote_not_recently_accessed_quarantined_media, ], ) @@ -261,49 +249,35 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): # Remote media accessed <30 days ago should still exist. self._assert_if_mxc_uris_purged( purged=[ - (self.remote_server_name, self.remote_not_recently_accessed_media), + self.remote_not_recently_accessed_media, ], not_purged=[ - (self.remote_server_name, self.remote_recently_accessed_media), - (self.hs.config.server.server_name, self.local_recently_accessed_media), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_media, - ), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_quarantined_media, - ), - ( - self.hs.config.server.server_name, - self.local_not_recently_accessed_protected_media, - ), - ( - self.remote_server_name, - self.remote_not_recently_accessed_quarantined_media, - ), - (self.hs.config.server.server_name, self.local_never_accessed_media), + self.remote_recently_accessed_media, + self.local_recently_accessed_media, + self.local_not_recently_accessed_media, + self.local_not_recently_accessed_quarantined_media, + self.local_not_recently_accessed_protected_media, + self.remote_not_recently_accessed_quarantined_media, + self.local_never_accessed_media, ], ) def _assert_if_mxc_uris_purged( - self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]] + self, purged: Iterable[MXCUri], not_purged: Iterable[MXCUri] ) -> None: - def _assert_mxc_uri_purge_state( - server_name: str, media_id: str, expect_purged: bool - ) -> None: + def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None: """Given an MXC URI, assert whether it has been purged or not.""" - if server_name == self.hs.config.server.server_name: + if mxc_uri.server_name == self.hs.config.server.server_name: found_media_dict = self.get_success( - self.store.get_local_media(media_id) + self.store.get_local_media(mxc_uri.media_id) ) else: found_media_dict = self.get_success( - self.store.get_cached_remote_media(server_name, media_id) + self.store.get_cached_remote_media( + mxc_uri.server_name, mxc_uri.media_id + ) ) - mxc_uri = f"mxc://{server_name}/{media_id}" - if expect_purged: self.assertIsNone( found_media_dict, msg=f"{mxc_uri} unexpectedly not purged" @@ -315,7 +289,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): ) # Assert that the given MXC URIs have either been correctly purged or not. - for server_name, media_id in purged: - _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True) - for server_name, media_id in not_purged: - _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False) + for mxc_uri in purged: + _assert_mxc_uri_purge_state(mxc_uri, expect_purged=True) + for mxc_uri in not_purged: + _assert_mxc_uri_purge_state(mxc_uri, expect_purged=False) -- cgit 1.5.1 From 742f9f9d78490f7f16bdb607a8f61ca258d520ef Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 15 Sep 2022 18:36:02 +0100 Subject: A third batch of Pydantic validation for rest/client/account.py (#13736) --- changelog.d/13736.feature | 1 + synapse/rest/client/account.py | 65 ++++++++++++++++++++++------------------ synapse/rest/client/models.py | 28 +++++++++-------- tests/rest/client/test_models.py | 29 ++++++++++++++++-- 4 files changed, 78 insertions(+), 45 deletions(-) create mode 100644 changelog.d/13736.feature (limited to 'synapse') diff --git a/changelog.d/13736.feature b/changelog.d/13736.feature new file mode 100644 index 0000000000..60a63c1009 --- /dev/null +++ b/changelog.d/13736.feature @@ -0,0 +1 @@ +Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind). diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index a09aaf3448..2db2a04f95 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple from urllib.parse import urlparse from pydantic import StrictBool, StrictStr, constr +from typing_extensions import Literal from twisted.web.server import Request @@ -43,6 +44,7 @@ from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.rest.client.models import ( AuthenticationData, + ClientSecretStr, EmailRequestTokenBody, MsisdnRequestTokenBody, ) @@ -627,6 +629,11 @@ class ThreepidAddRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + class PostBody(RequestBodyModel): + auth: Optional[AuthenticationData] = None + client_secret: ClientSecretStr + sid: StrictStr + @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_3pid_changes: @@ -636,22 +643,17 @@ class ThreepidAddRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "sid"]) - sid = body["sid"] - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) + body = parse_and_validate_json_object_from_request(request, self.PostBody) await self.auth_handler.validate_user_via_ui_auth( requester, request, - body, + body.dict(exclude_unset=True), "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( - client_secret, sid + body.client_secret, body.sid ) if validation_session: await self.auth_handler.add_threepid( @@ -676,23 +678,20 @@ class ThreepidBindRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - body = parse_json_object_from_request(request) + class PostBody(RequestBodyModel): + client_secret: ClientSecretStr + id_access_token: StrictStr + id_server: StrictStr + sid: StrictStr - assert_params_in_dict( - body, ["id_server", "sid", "id_access_token", "client_secret"] - ) - id_server = body["id_server"] - sid = body["sid"] - id_access_token = body["id_access_token"] - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + body = parse_and_validate_json_object_from_request(request, self.PostBody) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() await self.identity_handler.bind_threepid( - client_secret, sid, user_id, id_server, id_access_token + body.client_secret, body.sid, user_id, body.id_server, body.id_access_token ) return 200, {} @@ -708,23 +707,27 @@ class ThreepidUnbindRestServlet(RestServlet): self.auth = hs.get_auth() self.datastore = self.hs.get_datastores().main + class PostBody(RequestBodyModel): + address: StrictStr + id_server: Optional[StrictStr] = None + medium: Literal["email", "msisdn"] + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ requester = await self.auth.get_user_by_req(request) - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["medium", "address"]) - - medium = body.get("medium") - address = body.get("address") - id_server = body.get("id_server") + body = parse_and_validate_json_object_from_request(request, self.PostBody) # Attempt to unbind the threepid from an identity server. If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past result = await self.identity_handler.try_unbind_threepid( requester.user.to_string(), - {"address": address, "medium": medium, "id_server": id_server}, + { + "address": body.address, + "medium": body.medium, + "id_server": body.id_server, + }, ) return 200, {"id_server_unbind_result": "success" if result else "no-support"} @@ -738,21 +741,25 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + class PostBody(RequestBodyModel): + address: StrictStr + id_server: Optional[StrictStr] = None + medium: Literal["email", "msisdn"] + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["medium", "address"]) + body = parse_and_validate_json_object_from_request(request, self.PostBody) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: ret = await self.auth_handler.delete_threepid( - user_id, body["medium"], body["address"], body.get("id_server") + user_id, body.medium, body.address, body.id_server ) except Exception: # NB. This endpoint should succeed if there is nothing to diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py index 6278450c70..3d7940b0fc 100644 --- a/synapse/rest/client/models.py +++ b/synapse/rest/client/models.py @@ -36,18 +36,20 @@ class AuthenticationData(RequestBodyModel): type: Optional[StrictStr] = None -class ThreePidRequestTokenBody(RequestBodyModel): - if TYPE_CHECKING: - client_secret: StrictStr - else: - # See also assert_valid_client_secret() - client_secret: constr( - regex="[0-9a-zA-Z.=_-]", # noqa: F722 - min_length=0, - max_length=255, - strict=True, - ) +if TYPE_CHECKING: + ClientSecretStr = StrictStr +else: + # See also assert_valid_client_secret() + ClientSecretStr = constr( + regex="[0-9a-zA-Z.=_-]", # noqa: F722 + min_length=1, + max_length=255, + strict=True, + ) + +class ThreepidRequestTokenBody(RequestBodyModel): + client_secret: ClientSecretStr id_server: Optional[StrictStr] id_access_token: Optional[StrictStr] next_link: Optional[StrictStr] @@ -62,7 +64,7 @@ class ThreePidRequestTokenBody(RequestBodyModel): return token -class EmailRequestTokenBody(ThreePidRequestTokenBody): +class EmailRequestTokenBody(ThreepidRequestTokenBody): email: StrictStr # Canonicalise the email address. The addresses are all stored canonicalised @@ -80,6 +82,6 @@ else: ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True) -class MsisdnRequestTokenBody(ThreePidRequestTokenBody): +class MsisdnRequestTokenBody(ThreepidRequestTokenBody): country: ISO3116_1_Alpha_2 phone_number: StrictStr diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py index a9da00665e..0b8fcb0c47 100644 --- a/tests/rest/client/test_models.py +++ b/tests/rest/client/test_models.py @@ -11,14 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest +import unittest as stdlib_unittest -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError +from typing_extensions import Literal from synapse.rest.client.models import EmailRequestTokenBody -class EmailRequestTokenBodyTestCase(unittest.TestCase): +class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase): + class Model(BaseModel): + medium: Literal["email", "msisdn"] + + def test_accepts_valid_medium_string(self) -> None: + """Sanity check that Pydantic behaves sensibly with an enum-of-str + + This is arguably more of a test of a class that inherits from str and Enum + simultaneously. + """ + model = self.Model.parse_obj({"medium": "email"}) + self.assertEqual(model.medium, "email") + + def test_rejects_invalid_medium_value(self) -> None: + with self.assertRaises(ValidationError): + self.Model.parse_obj({"medium": "interpretive_dance"}) + + def test_rejects_invalid_medium_type(self) -> None: + with self.assertRaises(ValidationError): + self.Model.parse_obj({"medium": 123}) + + +class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase): base_request = { "client_secret": "hunter2", "email": "alice@wonderland.com", -- cgit 1.5.1 From b2b0c8527957d89b36c0eafea70347c200c1d294 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 15 Sep 2022 14:28:48 -0400 Subject: Support providing an index predicate for upserts. (#13822) This is useful to upsert against a table which has a unique partial index while avoiding conflicts. --- changelog.d/13822.misc | 1 + synapse/storage/background_updates.py | 1 + synapse/storage/database.py | 30 +++++++++++++++++++++++------- 3 files changed, 25 insertions(+), 7 deletions(-) create mode 100644 changelog.d/13822.misc (limited to 'synapse') diff --git a/changelog.d/13822.misc b/changelog.d/13822.misc new file mode 100644 index 0000000000..dbc77cbcfa --- /dev/null +++ b/changelog.d/13822.misc @@ -0,0 +1 @@ +Support providing an index predicate clause when doing upserts. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index cf1eabc437..bf5e7ee7be 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -533,6 +533,7 @@ class BackgroundUpdater: index_name: name of index to add table: table to add index to columns: columns/expressions to include in index + where_clause: A WHERE clause to specify a partial unique index. unique: true to make a UNIQUE index psql_only: true to only create this index on psql databases (useful for virtual sqlite tables) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e881bff7fb..921cd4dc5e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1191,6 +1191,7 @@ class DatabasePool: keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, + where_clause: Optional[str] = None, lock: bool = True, ) -> bool: """ @@ -1203,6 +1204,7 @@ class DatabasePool: keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + where_clause: An index predicate to apply to the upsert. lock: True to lock the table when doing the upsert. Unused when performing a native upsert. Returns: @@ -1213,7 +1215,12 @@ class DatabasePool: if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_txn_native_upsert( - txn, table, keyvalues, values, insertion_values=insertion_values + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + where_clause=where_clause, ) else: return self.simple_upsert_txn_emulated( @@ -1222,6 +1229,7 @@ class DatabasePool: keyvalues, values, insertion_values=insertion_values, + where_clause=where_clause, lock=lock, ) @@ -1232,6 +1240,7 @@ class DatabasePool: keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, + where_clause: Optional[str] = None, lock: bool = True, ) -> bool: """ @@ -1240,6 +1249,7 @@ class DatabasePool: keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + where_clause: An index predicate to apply to the upsert. lock: True to lock the table when doing the upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is @@ -1259,14 +1269,17 @@ class DatabasePool: else: return "%s = ?" % (key,) + # Generate a where clause of each keyvalue and optionally the provided + # index predicate. + where = [_getwhere(k) for k in keyvalues] + if where_clause: + where.append(where_clause) + if not values: # If `values` is empty, then all of the values we care about are in # the unique key, so there is nothing to UPDATE. We can just do a # SELECT instead to see if it exists. - sql = "SELECT 1 FROM %s WHERE %s" % ( - table, - " AND ".join(_getwhere(k) for k in keyvalues), - ) + sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where)) sqlargs = list(keyvalues.values()) txn.execute(sql, sqlargs) if txn.fetchall(): @@ -1277,7 +1290,7 @@ class DatabasePool: sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues), + " AND ".join(where), ) sqlargs = list(values.values()) + list(keyvalues.values()) @@ -1307,6 +1320,7 @@ class DatabasePool: keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, + where_clause: Optional[str] = None, ) -> bool: """ Use the native UPSERT functionality in PostgreSQL. @@ -1316,6 +1330,7 @@ class DatabasePool: keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + where_clause: An index predicate to apply to the upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is @@ -1331,11 +1346,12 @@ class DatabasePool: allvalues.update(values) latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), ", ".join(k for k in keyvalues), + f"WHERE {where_clause}" if where_clause else "", latter, ) txn.execute(sql, list(allvalues.values())) -- cgit 1.5.1 From 140af0cdb653bc2fef9474af06a5c5b525073998 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 15 Sep 2022 14:40:49 -0500 Subject: Record any exception when processing a pulled event (#13814) Part of https://github.com/matrix-org/synapse/issues/13700 and https://github.com/matrix-org/synapse/issues/13356 Follow-up to https://github.com/matrix-org/synapse/pull/13589 --- changelog.d/13589.feature | 2 +- changelog.d/13814.feature | 1 + synapse/handlers/federation_event.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13814.feature (limited to 'synapse') diff --git a/changelog.d/13589.feature b/changelog.d/13589.feature index 78fa1ddb52..a5ea2bc82e 100644 --- a/changelog.d/13589.feature +++ b/changelog.d/13589.feature @@ -1 +1 @@ -Keep track when we attempt to backfill an event but fail so we can intelligently back-off in the future. +Keep track when we fail to process a pulled event over federation so we can intelligently back-off in the future. diff --git a/changelog.d/13814.feature b/changelog.d/13814.feature new file mode 100644 index 0000000000..a5ea2bc82e --- /dev/null +++ b/changelog.d/13814.feature @@ -0,0 +1 @@ +Keep track when we fail to process a pulled event over federation so we can intelligently back-off in the future. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9e065e1116..efcdb84057 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -866,6 +866,11 @@ class FederationEventHandler: event.room_id, event_id, str(err) ) return + except Exception as exc: + await self._store.record_event_failed_pull_attempt( + event.room_id, event_id, str(exc) + ) + raise exc try: try: @@ -908,6 +913,11 @@ class FederationEventHandler: logger.warning("Pulled event %s failed history check.", event_id) else: raise + except Exception as exc: + await self._store.record_event_failed_pull_attempt( + event.room_id, event_id, str(exc) + ) + raise exc @trace async def _compute_event_context_with_maybe_missing_prevs( -- cgit 1.5.1 From 5093cbf88da1c439f5bf16b7a4cf19246781bd93 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 15 Sep 2022 15:32:25 -0500 Subject: Be able to correlate timeouts in reverse-proxy layer in front of Synapse (pull request ID from header) (#13801) Fix https://github.com/matrix-org/synapse/issues/13685 New config: ```diff listeners: - port: 8008 tls: false type: http x_forwarded: true + request_id_header: "cf-ray" bind_addresses: ['::1', '127.0.0.1', '0.0.0.0'] ``` --- changelog.d/13801.feature | 1 + docs/reverse_proxy.md | 4 ++++ docs/usage/configuration/config_documentation.md | 11 ++++++++++- synapse/config/server.py | 13 ++++++++++--- synapse/http/site.py | 14 +++++++++++++- 5 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 changelog.d/13801.feature (limited to 'synapse') diff --git a/changelog.d/13801.feature b/changelog.d/13801.feature new file mode 100644 index 0000000000..d7cedfd302 --- /dev/null +++ b/changelog.d/13801.feature @@ -0,0 +1 @@ +Add `listeners[x].request_id_header` config to specify which request header to extract and use as the request ID in order to correlate requests from a reverse-proxy. diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index d1618e8155..4e7a1d4435 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -45,6 +45,10 @@ listens to traffic on localhost. (Do not change `bind_addresses` to `127.0.0.1` when using a containerized Synapse, as that will prevent it from responding to proxied traffic.) +Optionally, you can also set +[`request_id_header`](../usage/configuration/config_documentation.md#listeners) +so that the server extracts and re-uses the same request ID format that the +reverse proxy is using. ## Reverse-proxy configuration examples diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index cd546041b2..69d305b62e 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -434,7 +434,16 @@ Sub-options for each listener include: * `tls`: set to true to enable TLS for this listener. Will use the TLS key/cert specified in tls_private_key_path / tls_certificate_path. * `x_forwarded`: Only valid for an 'http' listener. Set to true to use the X-Forwarded-For header as the client IP. Useful when Synapse is - behind a reverse-proxy. + behind a [reverse-proxy](../../reverse_proxy.md). + +* `request_id_header`: The header extracted from each incoming request that is + used as the basis for the request ID. The request ID is used in + [logs](../administration/request_log.md#request-log-format) and tracing to + correlate and match up requests. When unset, Synapse will automatically + generate sequential request IDs. This option is useful when Synapse is behind + a [reverse-proxy](../../reverse_proxy.md). + + _Added in Synapse 1.68.0._ * `resources`: Only valid for an 'http' listener. A list of resources to host on this port. Sub-options for each resource are: diff --git a/synapse/config/server.py b/synapse/config/server.py index c91df636d9..f2353ce5fb 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -206,6 +206,7 @@ class HttpListenerConfig: resources: List[HttpResourceConfig] = attr.Factory(list) additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None + request_id_header: Optional[str] = None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -520,9 +521,11 @@ class ServerConfig(Config): ): raise ConfigError("allowed_avatar_mimetypes must be a list") - self.listeners = [ - parse_listener_def(i, x) for i, x in enumerate(config.get("listeners", [])) - ] + listeners = config.get("listeners", []) + if not isinstance(listeners, list): + raise ConfigError("Expected a list", ("listeners",)) + + self.listeners = [parse_listener_def(i, x) for i, x in enumerate(listeners)] # no_tls is not really supported any more, but let's grandfather it in # here. @@ -889,6 +892,9 @@ def read_gc_thresholds( def parse_listener_def(num: int, listener: Any) -> ListenerConfig: """parse a listener config from the config file""" + if not isinstance(listener, dict): + raise ConfigError("Expected a dictionary", ("listeners", str(num))) + listener_type = listener["type"] # Raise a helpful error if direct TCP replication is still configured. if listener_type == "replication": @@ -928,6 +934,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: resources=resources, additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), + request_id_header=listener.get("request_id_header"), ) return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) diff --git a/synapse/http/site.py b/synapse/http/site.py index 1155f3f610..55a6afce35 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -72,10 +72,12 @@ class SynapseRequest(Request): site: "SynapseSite", *args: Any, max_request_body_size: int = 1024, + request_id_header: Optional[str] = None, **kw: Any, ): super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size + self.request_id_header = request_id_header self.synapse_site = site self.reactor = site.reactor self._channel = channel # this is used by the tests @@ -172,7 +174,14 @@ class SynapseRequest(Request): self._opentracing_span = span def get_request_id(self) -> str: - return "%s-%i" % (self.get_method(), self.request_seq) + request_id_value = None + if self.request_id_header: + request_id_value = self.getHeader(self.request_id_header) + + if request_id_value is None: + request_id_value = str(self.request_seq) + + return "%s-%s" % (self.get_method(), request_id_value) def get_redacted_uri(self) -> str: """Gets the redacted URI associated with the request (or placeholder if the URI @@ -611,12 +620,15 @@ class SynapseSite(Site): proxied = config.http_options.x_forwarded request_class = XForwardedForRequest if proxied else SynapseRequest + request_id_header = config.http_options.request_id_header + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, self, max_request_body_size=max_request_body_size, queued=queued, + request_id_header=request_id_header, ) self.requestFactory = request_factory # type: ignore -- cgit 1.5.1 From b73cbb82157d9666e8d667733afebc0d09ed858c Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 16 Sep 2022 12:45:04 +0100 Subject: Avoid putting rejected events in room state (#13723) Signed-off-by: Sean Quah --- changelog.d/13723.bugfix | 1 + synapse/state/v2.py | 15 ++ tests/handlers/test_federation_event.py | 399 ++++++++++++++++++++++++++++++++ 3 files changed, 415 insertions(+) create mode 100644 changelog.d/13723.bugfix (limited to 'synapse') diff --git a/changelog.d/13723.bugfix b/changelog.d/13723.bugfix new file mode 100644 index 0000000000..a23174d31d --- /dev/null +++ b/changelog.d/13723.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where previously rejected events could end up in room state because they pass auth checks given the current state of the room. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index af03851c71..1b9d7d8457 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -577,6 +577,21 @@ async def _iterative_auth_checks( if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] + if event.rejected_reason is not None: + # Do not admit previously rejected events into state. + # TODO: This isn't spec compliant. Events that were previously rejected due + # to failing auth checks at their state, but pass auth checks during + # state resolution should be accepted. Synapse does not handle the + # change of rejection status well, so we preserve the previous + # rejection status for now. + # + # Note that events rejected for non-state reasons, such as having the + # wrong auth events, should remain rejected. + # + # https://spec.matrix.org/v1.2/rooms/v9/#rejected-events + # https://github.com/matrix-org/synapse/issues/13797 + continue + try: event_auth.check_state_dependent_auth_rules( event, diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index b5b89405a4..918010cddb 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -11,14 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from unittest import mock +from synapse.api.errors import AuthError +from synapse.api.room_versions import RoomVersion +from synapse.event_auth import ( + check_state_dependent_auth_rules, + check_state_independent_auth_rules, +) from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext from synapse.federation.transport.client import StateRequestResponse from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort +from synapse.types import JsonDict from tests import unittest from tests.test_utils import event_injection, make_awaitable @@ -449,3 +458,393 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): main_store.get_event(pulled_event.event_id, allow_none=True) ) self.assertIsNotNone(persisted, "pulled event was not persisted at all") + + def test_process_pulled_event_with_rejected_missing_state(self) -> None: + """Ensure that we correctly handle pulled events with missing state containing a + rejected state event + + In this test, we pretend we are processing a "pulled" event (eg, via backfill + or get_missing_events). The pulled event has a prev_event we haven't previously + seen, so the server requests the state at that prev_event. We expect the server + to make a /state request. + + We simulate a remote server whose /state includes a rejected kick event for a + local user. Notably, the kick event is rejected only because it cites a rejected + auth event and would otherwise be accepted based on the room state. During state + resolution, we re-run auth and can potentially introduce such rejected events + into the state if we are not careful. + + We check that the pulled event is correctly persisted, and that the state + afterwards does not include the rejected kick. + """ + # The DAG we are testing looks like: + # + # ... + # | + # v + # remote admin user joins + # | | + # +-------+ +-------+ + # | | + # | rejected power levels + # | from remote server + # | | + # | v + # | rejected kick of local user + # v from remote server + # new power levels | + # | v + # | missing event + # | from remote server + # | | + # +-------+ +-------+ + # | | + # v v + # pulled event + # from remote server + # + # (arrows are in the opposite direction to prev_events.) + + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room. + kermit_user_id = self.register_user("kermit", "test") + kermit_tok = self.login("kermit", "test") + room_id = self.helper.create_room_as( + room_creator=kermit_user_id, tok=kermit_tok + ) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # Add another local user to the room. This user is going to be kicked in a + # rejected event. + bert_user_id = self.register_user("bert", "test") + bert_tok = self.login("bert", "test") + self.helper.join(room_id, user=bert_user_id, tok=bert_tok) + + # Allow the remote user to kick bert. + # The remote user is going to send a rejected power levels event later on and we + # need state resolution to order it before another power levels event kermit is + # going to send later on. Hence we give both users the same power level, so that + # ties are broken by `origin_server_ts`. + self.helper.send_state( + room_id, + "m.room.power_levels", + {"users": {kermit_user_id: 100, OTHER_USER: 100}}, + tok=kermit_tok, + ) + + # Add the remote user to the room. + other_member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + create_event = self.get_success( + main_store.get_event(initial_state_map[("m.room.create", "")]) + ) + bert_member_event = self.get_success( + main_store.get_event(initial_state_map[("m.room.member", bert_user_id)]) + ) + power_levels_event = self.get_success( + main_store.get_event(initial_state_map[("m.room.power_levels", "")]) + ) + + # We now need a rejected state event that will fail + # `check_state_independent_auth_rules` but pass + # `check_state_dependent_auth_rules`. + + # First, we create a power levels event that we pretend the remote server has + # accepted, but the local homeserver will reject. + next_depth = 100 + next_timestamp = other_member_event.origin_server_ts + 100 + rejected_power_levels_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "m.room.power_levels", + "state_key": "", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [other_member_event.event_id], + "auth_events": [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + # The event will be rejected because of the duplicated auth + # event. + other_member_event.event_id, + other_member_event.event_id, + ], + "origin_server_ts": next_timestamp, + "depth": next_depth, + "content": power_levels_event.content, + } + ), + room_version, + ) + next_depth += 1 + next_timestamp += 100 + + with LoggingContext("send_rejected_power_levels_event"): + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, + rejected_power_levels_event, + backfilled=False, + ) + ) + self.assertEqual( + self.get_success( + main_store.get_rejection_reason( + rejected_power_levels_event.event_id + ) + ), + "auth_error", + ) + + # Then we create a kick event for a local user that cites the rejected power + # levels event in its auth events. The kick event will be rejected solely + # because of the rejected auth event and would otherwise be accepted. + rejected_kick_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "m.room.member", + "state_key": bert_user_id, + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [rejected_power_levels_event.event_id], + "auth_events": [ + initial_state_map[("m.room.create", "")], + rejected_power_levels_event.event_id, + initial_state_map[("m.room.member", bert_user_id)], + initial_state_map[("m.room.member", OTHER_USER)], + ], + "origin_server_ts": next_timestamp, + "depth": next_depth, + "content": {"membership": "leave"}, + } + ), + room_version, + ) + next_depth += 1 + next_timestamp += 100 + + # The kick event must fail the state-independent auth rules, but pass the + # state-dependent auth rules, so that it has a chance of making it through state + # resolution. + self.get_failure( + check_state_independent_auth_rules(main_store, rejected_kick_event), + AuthError, + ) + check_state_dependent_auth_rules( + rejected_kick_event, + [create_event, power_levels_event, other_member_event, bert_member_event], + ) + + # The kick event must also win over the original member event during state + # resolution. + self.assertEqual( + self.get_success( + _mainline_sort( + self.clock, + room_id, + event_ids=[ + bert_member_event.event_id, + rejected_kick_event.event_id, + ], + resolved_power_event_id=power_levels_event.event_id, + event_map={ + bert_member_event.event_id: bert_member_event, + rejected_kick_event.event_id: rejected_kick_event, + }, + state_res_store=main_store, + ) + ), + [bert_member_event.event_id, rejected_kick_event.event_id], + "The rejected kick event will not be applied after bert's join event " + "during state resolution. The test setup is incorrect.", + ) + + with LoggingContext("send_rejected_kick_event"): + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, rejected_kick_event, backfilled=False + ) + ) + self.assertEqual( + self.get_success( + main_store.get_rejection_reason(rejected_kick_event.event_id) + ), + "auth_error", + ) + + # We need another power levels event which will win over the rejected one during + # state resolution, otherwise we hit other issues where we end up with rejected + # a power levels event during state resolution. + self.reactor.advance(100) # ensure the `origin_server_ts` is larger + new_power_levels_event = self.get_success( + main_store.get_event( + self.helper.send_state( + room_id, + "m.room.power_levels", + {"users": {kermit_user_id: 100, OTHER_USER: 100, bert_user_id: 1}}, + tok=kermit_tok, + )["event_id"] + ) + ) + self.assertEqual( + self.get_success( + _reverse_topological_power_sort( + self.clock, + room_id, + event_ids=[ + new_power_levels_event.event_id, + rejected_power_levels_event.event_id, + ], + event_map={}, + state_res_store=main_store, + full_conflicted_set=set(), + ) + ), + [rejected_power_levels_event.event_id, new_power_levels_event.event_id], + "The power levels events will not have the desired ordering during state " + "resolution. The test setup is incorrect.", + ) + + # Create a missing event, so that the local homeserver has to do a `/state` or + # `/state_ids` request to pull state from the remote homeserver. + missing_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "m.room.message", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [rejected_kick_event.event_id], + "auth_events": [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + initial_state_map[("m.room.member", OTHER_USER)], + ], + "origin_server_ts": next_timestamp, + "depth": next_depth, + "content": {"msgtype": "m.text", "body": "foo"}, + } + ), + room_version, + ) + next_depth += 1 + next_timestamp += 100 + + # The pulled event has two prev events, one of which is missing. We will make a + # `/state` or `/state_ids` request to the remote homeserver to ask it for the + # state before the missing prev event. + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "m.room.message", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [ + new_power_levels_event.event_id, + missing_event.event_id, + ], + "auth_events": [ + initial_state_map[("m.room.create", "")], + new_power_levels_event.event_id, + initial_state_map[("m.room.member", OTHER_USER)], + ], + "origin_server_ts": next_timestamp, + "depth": next_depth, + "content": {"msgtype": "m.text", "body": "bar"}, + } + ), + room_version, + ) + next_depth += 1 + next_timestamp += 100 + + # Prepare the response for the `/state` or `/state_ids` request. + # The remote server believes bert has been kicked, while the local server does + # not. + state_before_missing_event = self.get_success( + main_store.get_events_as_list(initial_state_map.values()) + ) + state_before_missing_event = [ + event + for event in state_before_missing_event + if event.event_id != bert_member_event.event_id + ] + state_before_missing_event.append(rejected_kick_event) + + # We have to bump the clock a bit, to keep the retry logic in + # `FederationClient.get_pdu` happy + self.reactor.advance(60000) + with LoggingContext("send_pulled_event"): + + async def get_event( + destination: str, event_id: str, timeout: Optional[int] = None + ) -> JsonDict: + self.assertEqual(destination, self.OTHER_SERVER_NAME) + self.assertEqual(event_id, missing_event.event_id) + return {"pdus": [missing_event.get_pdu_json()]} + + async def get_room_state_ids( + destination: str, room_id: str, event_id: str + ) -> JsonDict: + self.assertEqual(destination, self.OTHER_SERVER_NAME) + self.assertEqual(event_id, missing_event.event_id) + return { + "pdu_ids": [event.event_id for event in state_before_missing_event], + "auth_chain_ids": [], + } + + async def get_room_state( + room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> StateRequestResponse: + self.assertEqual(destination, self.OTHER_SERVER_NAME) + self.assertEqual(event_id, missing_event.event_id) + return StateRequestResponse( + state=state_before_missing_event, + auth_events=[], + ) + + self.mock_federation_transport_client.get_event.side_effect = get_event + self.mock_federation_transport_client.get_room_state_ids.side_effect = ( + get_room_state_ids + ) + self.mock_federation_transport_client.get_room_state.side_effect = ( + get_room_state + ) + + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, pulled_event, backfilled=False + ) + ) + self.assertIsNone( + self.get_success( + main_store.get_rejection_reason(pulled_event.event_id) + ), + "Pulled event was unexpectedly rejected, likely due to a problem with " + "the test setup.", + ) + self.assertEqual( + {pulled_event.event_id}, + self.get_success( + main_store.have_events_in_timeline([pulled_event.event_id]) + ), + "Pulled event was not persisted, likely due to a problem with the test " + "setup.", + ) + + # We must not accept rejected events into the room state, so we expect bert + # to not be kicked, even if the remote server believes so. + new_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + self.assertEqual( + new_state_map[("m.room.member", bert_user_id)], + bert_member_event.event_id, + "Rejected kick event unexpectedly became part of room state.", + ) -- cgit 1.5.1 From 74f60cec92c5aff87d6e74d177e95ec5f1a69f2b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 16 Sep 2022 14:29:03 +0200 Subject: Add an admin API endpoint to find a user based on its external ID in an auth provider. (#13810) --- changelog.d/13810.feature | 1 + docs/admin_api/user_admin_api.md | 38 ++++++++++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 27 +++++++++++++ tests/rest/admin/test_user.py | 87 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 155 insertions(+) create mode 100644 changelog.d/13810.feature (limited to 'synapse') diff --git a/changelog.d/13810.feature b/changelog.d/13810.feature new file mode 100644 index 0000000000..f0258af661 --- /dev/null +++ b/changelog.d/13810.feature @@ -0,0 +1 @@ +Add an admin API endpoint to find a user based on its external ID in an auth provider. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 975f05c929..3625c7b6c5 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -1155,3 +1155,41 @@ GET /_synapse/admin/v1/username_available?username=$localpart The request and response format is the same as the [/_matrix/client/r0/register/available](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) API. + +### Find a user based on their ID in an auth provider + +The API is: + +``` +GET /_synapse/admin/v1/auth_providers/$provider/users/$external_id +``` + +When a user matched the given ID for the given provider, an HTTP code `200` with a response body like the following is returned: + +```json +{ + "user_id": "@hello:example.org" +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `provider` - The ID of the authentication provider, as advertised by the [`GET /_matrix/client/v3/login`](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3login) API in the `m.login.sso` authentication method. +- `external_id` - The user ID from the authentication provider. Usually corresponds to the `sub` claim for OIDC providers, or to the `uid` attestation for SAML2 providers. + +The `external_id` may have characters that are not URL-safe (typically `/`, `:` or `@`), so it is advised to URL-encode those parameters. + +**Errors** + +Returns a `404` HTTP status code if no user was found, with a response body like this: + +```json +{ + "errcode":"M_NOT_FOUND", + "error":"User not found" +} +``` + +_Added in Synapse 1.68.0._ diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index bac754e1b1..885669f9c7 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -80,6 +80,7 @@ from synapse.rest.admin.users import ( SearchUsersRestServlet, ShadowBanRestServlet, UserAdminServlet, + UserByExternalId, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -275,6 +276,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListDestinationsRestServlet(hs).register(http_server) RoomMessagesRestServlet(hs).register(http_server) RoomTimestampToEventRestServlet(hs).register(http_server) + UserByExternalId(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 78ee9b6532..2ca6b2d08a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1156,3 +1156,30 @@ class AccountDataRestServlet(RestServlet): "rooms": by_room_data, }, } + + +class UserByExternalId(RestServlet): + """Find a user based on an external ID from an auth provider""" + + PATTERNS = admin_patterns( + "/auth_providers/(?P[^/]*)/users/(?P[^/]*)" + ) + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, + request: SynapseRequest, + provider: str, + external_id: str, + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + user_id = await self._store.get_user_by_external_id(provider, external_id) + + if user_id is None: + raise NotFoundError("User not found") + + return HTTPStatus.OK, {"user_id": user_id} diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ec5ccf6fca..9f536ceeb3 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -4140,3 +4140,90 @@ class AccountDataTestCase(unittest.HomeserverTestCase): {"b": 2}, channel.json_body["account_data"]["rooms"]["test_room"]["m.per_room"], ) + + +class UsersByExternalIdTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.get_success( + self.store.record_user_external_id( + "the-auth-provider", "the-external-id", self.other_user + ) + ) + self.get_success( + self.store.record_user_external_id( + "another-auth-provider", "a:complex@external/id", self.other_user + ) + ) + + def test_no_auth(self) -> None: + """Try to lookup a user without authentication.""" + url = ( + "/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id" + ) + + channel = self.make_request( + "GET", + url, + ) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_binding_does_not_exist(self) -> None: + """Tests that a lookup for an external ID that does not exist returns a 404""" + url = "/_synapse/admin/v1/auth_providers/the-auth-provider/users/unknown-id" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_success(self) -> None: + """Tests a successful external ID lookup""" + url = ( + "/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id" + ) + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) + + def test_success_urlencoded(self) -> None: + """Tests a successful external ID lookup with an url-encoded ID""" + url = "/_synapse/admin/v1/auth_providers/another-auth-provider/users/a%3Acomplex%40external%2Fid" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) -- cgit 1.5.1 From d64e85197af31f5642f64ae1d86f5a0c74050fec Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 16 Sep 2022 16:16:05 +0100 Subject: Remove error spam when users query the keys of departed remote users (#13826) The error message introduced in #13749 has turned out to be very spammy. Remove it for now. --- changelog.d/13826.bugfix | 1 + synapse/handlers/e2e_keys.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) create mode 100644 changelog.d/13826.bugfix (limited to 'synapse') diff --git a/changelog.d/13826.bugfix b/changelog.d/13826.bugfix new file mode 100644 index 0000000000..8ffafec07b --- /dev/null +++ b/changelog.d/13826.bugfix @@ -0,0 +1 @@ +Fix a long standing bug where device lists would remain cached when remote users left and rejoined the last room shared with the local homeserver. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 8eed63ccf3..09a2492afc 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -188,18 +188,21 @@ class E2eKeysHandler: ) invalid_cached_users = cached_users - valid_cached_users if invalid_cached_users: - # Fix up results. If we get here, there is either a bug in device - # list tracking, or we hit the race mentioned above. + # Fix up results. If we get here, it means there was either a bug in + # device list tracking, or we hit the race mentioned above. + # TODO: In practice, this path is hit fairly often in existing + # deployments when clients query the keys of departed remote + # users. A background update to mark the appropriate device + # lists as unsubscribed is needed. + # https://github.com/matrix-org/synapse/issues/13651 + # Note that this currently introduces a failure mode when clients + # are trying to decrypt old messages from a remote user whose + # homeserver is no longer available. We may want to consider falling + # back to the cached data when we fail to retrieve a device list + # over federation for such remote users. user_ids_not_in_cache.update(invalid_cached_users) for invalid_user_id in invalid_cached_users: remote_results.pop(invalid_user_id) - # This log message may be removed if it turns out it's almost - # entirely triggered by races. - logger.error( - "Devices for %s were cached, but the server no longer shares " - "any rooms with them. The cached device lists are stale.", - invalid_cached_users, - ) for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) -- cgit 1.5.1 From 44be42338e032a50e5fc3d6c69be4055f33cb26c Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 16 Sep 2022 10:56:56 -0500 Subject: Add support to purge rows from MSC2716 and other tables when purging a room (#13825) `event_failed_pull_attempts` added in https://github.com/matrix-org/synapse/pull/13589 MSC2716 related tables added in: - https://github.com/matrix-org/synapse/pull/10245/files#diff-3d42dfb44d02f7de3aada105e0bdc1cc9dd7f953cbf0f36c5d0f50827bf0320aR1 - Renamed in https://github.com/matrix-org/synapse/pull/10838/files#diff-2730bfbe9e688b55e46f9371aefe67dac2bd2b2b7d9d6b92774eea1fcfae156dR1 - https://github.com/matrix-org/synapse/pull/10498/files#diff-c52bbfbb5921a3f6f023b24343668479d966fac164f13b7c39d2197ce3afa7a5R1 --- changelog.d/13825.bugfix | 1 + synapse/storage/databases/main/purge_events.py | 5 +++++ synapse/storage/schema/__init__.py | 2 ++ .../delta/73/02room_id_indexes_for_purging.sql | 22 ++++++++++++++++++++++ 4 files changed, 30 insertions(+) create mode 100644 changelog.d/13825.bugfix create mode 100644 synapse/storage/schema/main/delta/73/02room_id_indexes_for_purging.sql (limited to 'synapse') diff --git a/changelog.d/13825.bugfix b/changelog.d/13825.bugfix new file mode 100644 index 0000000000..626fc6349f --- /dev/null +++ b/changelog.d/13825.bugfix @@ -0,0 +1 @@ +Delete associated data from `event_failed_pull_attempts`, `insertion_events`, `insertion_event_extremities`, `insertion_event_extremities`, `insertion_event_extremities` when purging the room. diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index f6822707e4..9213ce0b5a 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -419,6 +419,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_forward_extremities", "event_push_actions", "event_search", + "event_failed_pull_attempts", "partial_state_events", "events", "federation_inbound_events_staging", @@ -441,6 +442,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "e2e_room_keys", "event_push_summary", "pusher_throttle", + "insertion_events", + "insertion_event_extremities", + "insertion_event_edges", + "batch_events", "room_account_data", "room_tags", # "rooms" happens last, to keep the foreign keys in the other tables diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 68e055c664..f29424d17a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -83,6 +83,8 @@ Changes in SCHEMA_VERSION = 73; event_push_summary, receipts_linearized, and receipts_graph. - Add table `event_failed_pull_attempts` to keep track when we fail to pull events over federation. + - Add indexes to various tables (`event_failed_pull_attempts`, `insertion_events`, + `batch_events`) to make it easy to delete all associated rows when purging a room. """ diff --git a/synapse/storage/schema/main/delta/73/02room_id_indexes_for_purging.sql b/synapse/storage/schema/main/delta/73/02room_id_indexes_for_purging.sql new file mode 100644 index 0000000000..6d38bdd430 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/02room_id_indexes_for_purging.sql @@ -0,0 +1,22 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add index so we can easily purge all rows from a given `room_id` +CREATE INDEX IF NOT EXISTS event_failed_pull_attempts_room_id ON event_failed_pull_attempts(room_id); + +-- MSC2716 related tables: +-- Add indexes so we can easily purge all rows from a given `room_id` +CREATE INDEX IF NOT EXISTS insertion_events_room_id ON insertion_events(room_id); +CREATE INDEX IF NOT EXISTS batch_events_room_id ON batch_events(room_id); -- cgit 1.5.1 From c802ef14119b21cfdf8f5a9c246b695c98c0f718 Mon Sep 17 00:00:00 2001 From: Denis Date: Tue, 20 Sep 2022 10:44:38 +0200 Subject: Don't include redundant prev_state in new events (#13791) --- changelog.d/13791.removal | 1 + synapse/events/builder.py | 1 - synapse/federation/federation_client.py | 3 --- 3 files changed, 1 insertion(+), 4 deletions(-) create mode 100644 changelog.d/13791.removal (limited to 'synapse') diff --git a/changelog.d/13791.removal b/changelog.d/13791.removal new file mode 100644 index 0000000000..283226b63e --- /dev/null +++ b/changelog.d/13791.removal @@ -0,0 +1 @@ +Don't include redundant `prev_state` in new events. Contributed by Denis Kariakin (@dakariakin). diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 746bd3978d..e2ee10dd3d 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -167,7 +167,6 @@ class EventBuilder: "content": self.content, "unsigned": self.unsigned, "depth": depth, - "prev_state": [], } if self.is_state(): diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4a4289ee7c..464672a3da 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -906,9 +906,6 @@ class FederationClient(FederationBase): # The protoevent received over the JSON wire may not have all # the required fields. Lets just gloss over that because # there's some we never care about - if "prev_state" not in pdu_dict: - pdu_dict["prev_state"] = [] - ev = builder.create_local_event_from_event_dict( self._clock, self.hostname, -- cgit 1.5.1 From 42d261c32f13e2de7494a0ade77c1f7b646af1fe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Sep 2022 12:10:31 +0100 Subject: Port the push rule classes to Rust. (#13768) --- .rustfmt.toml | 1 + changelog.d/13768.misc | 1 + rust/Cargo.toml | 10 +- rust/src/lib.rs | 9 +- rust/src/push/base_rules.rs | 335 ++++++++++++++++ rust/src/push/mod.rs | 502 ++++++++++++++++++++++++ stubs/synapse/synapse_rust.pyi | 2 - stubs/synapse/synapse_rust/__init__.pyi | 2 + stubs/synapse/synapse_rust/push.pyi | 37 ++ synapse/handlers/push_rules.py | 5 +- synapse/push/baserules.py | 583 ---------------------------- synapse/push/bulk_push_rule_evaluator.py | 7 +- synapse/push/clientformat.py | 5 +- synapse/storage/databases/main/push_rule.py | 23 +- tests/handlers/test_deactivate_account.py | 27 +- 15 files changed, 932 insertions(+), 617 deletions(-) create mode 100644 .rustfmt.toml create mode 100644 changelog.d/13768.misc create mode 100644 rust/src/push/base_rules.rs create mode 100644 rust/src/push/mod.rs delete mode 100644 stubs/synapse/synapse_rust.pyi create mode 100644 stubs/synapse/synapse_rust/__init__.pyi create mode 100644 stubs/synapse/synapse_rust/push.pyi delete mode 100644 synapse/push/baserules.py (limited to 'synapse') diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000000..bf96e7743d --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1 @@ +group_imports = "StdExternalCrate" diff --git a/changelog.d/13768.misc b/changelog.d/13768.misc new file mode 100644 index 0000000000..28bddb7059 --- /dev/null +++ b/changelog.d/13768.misc @@ -0,0 +1 @@ +Port push rules to using Rust. diff --git a/rust/Cargo.toml b/rust/Cargo.toml index deddf3cec2..8dc5f93ff1 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -18,7 +18,15 @@ crate-type = ["cdylib"] name = "synapse.synapse_rust" [dependencies] -pyo3 = { version = "0.16.5", features = ["extension-module", "macros", "abi3", "abi3-py37"] } +anyhow = "1.0.63" +lazy_static = "1.4.0" +log = "0.4.17" +pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] } +pyo3-log = "0.7.0" +pythonize = "0.17.0" +regex = "1.6.0" +serde = { version = "1.0.144", features = ["derive"] } +serde_json = "1.0.85" [build-dependencies] blake2 = "0.10.4" diff --git a/rust/src/lib.rs b/rust/src/lib.rs index ba42465fb8..c7b60e58a7 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,5 +1,7 @@ use pyo3::prelude::*; +pub mod push; + /// Returns the hash of all the rust source files at the time it was compiled. /// /// Used by python to detect if the rust library is outdated. @@ -17,8 +19,13 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { /// The entry point for defining the Python module. #[pymodule] -fn synapse_rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { + pyo3_log::init(); + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; + + push::register_module(py, m)?; + Ok(()) } diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs new file mode 100644 index 0000000000..7c62bc4849 --- /dev/null +++ b/rust/src/push/base_rules.rs @@ -0,0 +1,335 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Contains the definitions of the "base" push rules. + +use std::borrow::Cow; +use std::collections::HashMap; + +use lazy_static::lazy_static; +use serde_json::Value; + +use super::KnownCondition; +use crate::push::Action; +use crate::push::Condition; +use crate::push::EventMatchCondition; +use crate::push::PushRule; +use crate::push::SetTweak; +use crate::push::TweakValue; + +const HIGHLIGHT_ACTION: Action = Action::SetTweak(SetTweak { + set_tweak: Cow::Borrowed("highlight"), + value: None, + other_keys: Value::Null, +}); + +const HIGHLIGHT_FALSE_ACTION: Action = Action::SetTweak(SetTweak { + set_tweak: Cow::Borrowed("highlight"), + value: Some(TweakValue::Other(Value::Bool(false))), + other_keys: Value::Null, +}); + +const SOUND_ACTION: Action = Action::SetTweak(SetTweak { + set_tweak: Cow::Borrowed("sound"), + value: Some(TweakValue::String(Cow::Borrowed("default"))), + other_keys: Value::Null, +}); + +const RING_ACTION: Action = Action::SetTweak(SetTweak { + set_tweak: Cow::Borrowed("sound"), + value: Some(TweakValue::String(Cow::Borrowed("ring"))), + other_keys: Value::Null, +}); + +pub const BASE_PREPEND_OVERRIDE_RULES: &[PushRule] = &[PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.master"), + priority_class: 5, + conditions: Cow::Borrowed(&[]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: false, +}]; + +pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.suppress_notices"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("content.msgtype"), + pattern: Some(Cow::Borrowed("m.notice")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.invite_for_me"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.member")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("content.membership"), + pattern: Some(Cow::Borrowed("invite")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("state_key"), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_id")), + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.member_event"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.member")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::ContainsDisplayName)]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.roomnotif"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::SenderNotificationPermission { + key: Cow::Borrowed("room"), + }), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("content.body"), + pattern: Some(Cow::Borrowed("@room")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.tombstone"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.tombstone")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("state_key"), + pattern: Some(Cow::Borrowed("")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.reaction"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.reaction")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.org.matrix.msc3786.rule.room.server_acl"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.server_acl")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("state_key"), + pattern: Some(Cow::Borrowed("")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[]), + default: true, + default_enabled: true, + }, +]; + +pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule { + rule_id: Cow::Borrowed("global/content/.m.rule.contains_user_name"), + priority_class: 4, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("content.body"), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_localpart")), + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, +}]; + +pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.call"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.call.invite")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, RING_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.room_one_to_one"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.message")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.encrypted_room_one_to_one"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.encrypted")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3772.thread_reply"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { + rel_type: Cow::Borrowed("m.thread"), + sender: None, + sender_type: Some(Cow::Borrowed("user_id")), + })]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.message"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.message")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.encrypted"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.encrypted")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.im.vector.jitsi"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("im.vector.modular.widgets")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("content.type"), + pattern: Some(Cow::Borrowed("jitsi")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("state_key"), + pattern: Some(Cow::Borrowed("*")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, +]; + +lazy_static! { + pub static ref BASE_RULES_BY_ID: HashMap<&'static str, &'static PushRule> = + BASE_PREPEND_OVERRIDE_RULES + .iter() + .chain(BASE_APPEND_OVERRIDE_RULES.iter()) + .chain(BASE_APPEND_CONTENT_RULES.iter()) + .chain(BASE_APPEND_UNDERRIDE_RULES.iter()) + .map(|rule| { (&*rule.rule_id, rule) }) + .collect(); +} diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs new file mode 100644 index 0000000000..de6764e7c5 --- /dev/null +++ b/rust/src/push/mod.rs @@ -0,0 +1,502 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! An implementation of Matrix push rules. +//! +//! The `Cow<_>` type is used extensively within this module to allow creating +//! the base rules as constants (in Rust constants can't require explicit +//! allocation atm). +//! +//! --- +//! +//! Push rules is the system used to determine which events trigger a push (and a +//! bump in notification counts). +//! +//! This consists of a list of "push rules" for each user, where a push rule is a +//! pair of "conditions" and "actions". When a user receives an event Synapse +//! iterates over the list of push rules until it finds one where all the conditions +//! match the event, at which point "actions" describe the outcome (e.g. notify, +//! highlight, etc). +//! +//! Push rules are split up into 5 different "kinds" (aka "priority classes"), which +//! are run in order: +//! 1. Override — highest priority rules, e.g. always ignore notices +//! 2. Content — content specific rules, e.g. @ notifications +//! 3. Room — per room rules, e.g. enable/disable notifications for all messages +//! in a room +//! 4. Sender — per sender rules, e.g. never notify for messages from a given +//! user +//! 5. Underride — the lowest priority "default" rules, e.g. notify for every +//! message. +//! +//! The set of "base rules" are the list of rules that every user has by default. A +//! user can modify their copy of the push rules in one of three ways: +//! +//! 1. Adding a new push rule of a certain kind +//! 2. Changing the actions of a base rule +//! 3. Enabling/disabling a base rule. +//! +//! The base rules are split into whether they come before or after a particular +//! kind, so the order of push rule evaluation would be: base rules for before +//! "override" kind, user defined "override" rules, base rules after "override" +//! kind, etc, etc. + +use std::borrow::Cow; +use std::collections::{BTreeMap, HashMap, HashSet}; + +use anyhow::{Context, Error}; +use log::warn; +use pyo3::prelude::*; +use pythonize::pythonize; +use serde::de::Error as _; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +mod base_rules; + +/// Called when registering modules with python. +pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "push")?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import push` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.push", child_module)?; + + Ok(()) +} + +#[pyfunction] +fn get_base_rule_ids() -> HashSet<&'static str> { + base_rules::BASE_RULES_BY_ID.keys().copied().collect() +} + +/// A single push rule for a user. +#[derive(Debug, Clone)] +#[pyclass(frozen)] +pub struct PushRule { + /// A unique ID for this rule + pub rule_id: Cow<'static, str>, + /// The "kind" of push rule this is (see `PRIORITY_CLASS_MAP` in Python) + #[pyo3(get)] + pub priority_class: i32, + /// The conditions that must all match for actions to be applied + pub conditions: Cow<'static, [Condition]>, + /// The actions to apply if all conditions are met + pub actions: Cow<'static, [Action]>, + /// Whether this is a base rule + #[pyo3(get)] + pub default: bool, + /// Whether this is enabled by default + #[pyo3(get)] + pub default_enabled: bool, +} + +#[pymethods] +impl PushRule { + #[staticmethod] + pub fn from_db( + rule_id: String, + priority_class: i32, + conditions: &str, + actions: &str, + ) -> Result { + let conditions = serde_json::from_str(conditions).context("parsing conditions")?; + let actions = serde_json::from_str(actions).context("parsing actions")?; + + Ok(PushRule { + rule_id: Cow::Owned(rule_id), + priority_class, + conditions, + actions, + default: false, + default_enabled: true, + }) + } + + #[getter] + fn rule_id(&self) -> &str { + &self.rule_id + } + + #[getter] + fn actions(&self) -> Vec { + self.actions.clone().into_owned() + } + + #[getter] + fn conditions(&self) -> Vec { + self.conditions.clone().into_owned() + } + + fn __repr__(&self) -> String { + format!( + "", + self.rule_id, self.conditions, self.actions + ) + } +} + +/// The "action" Synapse should perform for a matching push rule. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Action { + DontNotify, + Notify, + Coalesce, + SetTweak(SetTweak), + + // An unrecognized custom action. + Unknown(Value), +} + +impl IntoPy for Action { + fn into_py(self, py: Python<'_>) -> PyObject { + // When we pass the `Action` struct to Python we want it to be converted + // to a dict. We use `pythonize`, which converts the struct using the + // `serde` serialization. + pythonize(py, &self).expect("valid action") + } +} + +/// The body of a `SetTweak` push action. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct SetTweak { + set_tweak: Cow<'static, str>, + + #[serde(skip_serializing_if = "Option::is_none")] + value: Option, + + // This picks up any other fields that may have been added by clients. + // These get added when we convert the `Action` to a python object. + #[serde(flatten)] + other_keys: Value, +} + +/// The value of a `set_tweak`. +/// +/// We need this (rather than using `TweakValue` directly) so that we can use +/// `&'static str` in the value when defining the constant base rules. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum TweakValue { + String(Cow<'static, str>), + Other(Value), +} + +impl Serialize for Action { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Action::DontNotify => serializer.serialize_str("dont_notify"), + Action::Notify => serializer.serialize_str("notify"), + Action::Coalesce => serializer.serialize_str("coalesce"), + Action::SetTweak(tweak) => tweak.serialize(serializer), + Action::Unknown(value) => value.serialize(serializer), + } + } +} + +/// Simple helper class for deserializing Action from JSON. +#[derive(Deserialize)] +#[serde(untagged)] +enum ActionDeserializeHelper { + Str(String), + SetTweak(SetTweak), + Unknown(Value), +} + +impl<'de> Deserialize<'de> for Action { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let helper: ActionDeserializeHelper = Deserialize::deserialize(deserializer)?; + match helper { + ActionDeserializeHelper::Str(s) => match &*s { + "dont_notify" => Ok(Action::DontNotify), + "notify" => Ok(Action::Notify), + "coalesce" => Ok(Action::Coalesce), + _ => Err(D::Error::custom("unrecognized action")), + }, + ActionDeserializeHelper::SetTweak(set_tweak) => Ok(Action::SetTweak(set_tweak)), + ActionDeserializeHelper::Unknown(value) => Ok(Action::Unknown(value)), + } + } +} + +/// A condition used in push rules to match against an event. +/// +/// We need this split as `serde` doesn't give us the ability to have a +/// "catchall" variant in tagged enums. +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum Condition { + /// A recognized condition that we can match against + Known(KnownCondition), + /// An unrecognized condition that we ignore. + Unknown(Value), +} + +/// The set of "known" conditions that we can handle. +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "kind")] +pub enum KnownCondition { + EventMatch(EventMatchCondition), + ContainsDisplayName, + RoomMemberCount { + #[serde(skip_serializing_if = "Option::is_none")] + is: Option>, + }, + SenderNotificationPermission { + key: Cow<'static, str>, + }, + #[serde(rename = "org.matrix.msc3772.relation_match")] + RelationMatch { + rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + sender: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + sender_type: Option>, + }, +} + +impl IntoPy for Condition { + fn into_py(self, py: Python<'_>) -> PyObject { + pythonize(py, &self).expect("valid condition") + } +} + +/// The body of a [`Condition::EventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct EventMatchCondition { + key: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + pattern: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pattern_type: Option>, +} + +/// The collection of push rules for a user. +#[derive(Debug, Clone, Default)] +#[pyclass(frozen)] +struct PushRules { + /// Custom push rules that override a base rule. + overridden_base_rules: HashMap, PushRule>, + + /// Custom rules that come between the prepend/append override base rules. + override_rules: Vec, + /// Custom rules that come before the base content rules. + content: Vec, + /// Custom rules that come before the base room rules. + room: Vec, + /// Custom rules that come before the base sender rules. + sender: Vec, + /// Custom rules that come before the base underride rules. + underride: Vec, +} + +#[pymethods] +impl PushRules { + #[new] + fn new(rules: Vec) -> PushRules { + let mut push_rules: PushRules = Default::default(); + + for rule in rules { + if let Some(&o) = base_rules::BASE_RULES_BY_ID.get(&*rule.rule_id) { + push_rules.overridden_base_rules.insert( + rule.rule_id.clone(), + PushRule { + actions: rule.actions.clone(), + ..o.clone() + }, + ); + + continue; + } + + match rule.priority_class { + 5 => push_rules.override_rules.push(rule), + 4 => push_rules.content.push(rule), + 3 => push_rules.room.push(rule), + 2 => push_rules.sender.push(rule), + 1 => push_rules.underride.push(rule), + _ => { + warn!( + "Unrecognized priority class for rule {}: {}", + rule.rule_id, rule.priority_class + ); + } + } + } + + push_rules + } + + /// Returns the list of all rules, including base rules, in the order they + /// should be executed in. + fn rules(&self) -> Vec { + self.iter().cloned().collect() + } +} + +impl PushRules { + /// Iterates over all the rules, including base rules, in the order they + /// should be executed in. + pub fn iter(&self) -> impl Iterator { + base_rules::BASE_PREPEND_OVERRIDE_RULES + .iter() + .chain(self.override_rules.iter()) + .chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter()) + .chain(self.content.iter()) + .chain(base_rules::BASE_APPEND_CONTENT_RULES.iter()) + .chain(self.room.iter()) + .chain(self.sender.iter()) + .chain(self.underride.iter()) + .chain(base_rules::BASE_APPEND_UNDERRIDE_RULES.iter()) + .map(|rule| { + self.overridden_base_rules + .get(&*rule.rule_id) + .unwrap_or(rule) + }) + } +} + +/// A wrapper around `PushRules` that checks the enabled state of rules and +/// filters out disabled experimental rules. +#[derive(Debug, Clone, Default)] +#[pyclass(frozen)] +pub struct FilteredPushRules { + push_rules: PushRules, + enabled_map: BTreeMap, + msc3786_enabled: bool, + msc3772_enabled: bool, +} + +#[pymethods] +impl FilteredPushRules { + #[new] + fn py_new( + push_rules: PushRules, + enabled_map: BTreeMap, + msc3786_enabled: bool, + msc3772_enabled: bool, + ) -> Self { + Self { + push_rules, + enabled_map, + msc3786_enabled, + msc3772_enabled, + } + } + + /// Returns the list of all rules and their enabled state, including base + /// rules, in the order they should be executed in. + fn rules(&self) -> Vec<(PushRule, bool)> { + self.iter().map(|(r, e)| (r.clone(), e)).collect() + } +} + +impl FilteredPushRules { + /// Iterates over all the rules and their enabled state, including base + /// rules, in the order they should be executed in. + fn iter(&self) -> impl Iterator { + self.push_rules + .iter() + .filter(|rule| { + // Ignore disabled experimental push rules + if !self.msc3786_enabled + && rule.rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl" + { + return false; + } + + if !self.msc3772_enabled + && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply" + { + return false; + } + + true + }) + .map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) + } +} + +#[test] +fn test_serialize_condition() { + let condition = Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: "content.body".into(), + pattern: Some("coffee".into()), + pattern_type: None, + })); + + let json = serde_json::to_string(&condition).unwrap(); + assert_eq!( + json, + r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"# + ) +} + +#[test] +fn test_deserialize_condition() { + let json = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#; + + let _: Condition = serde_json::from_str(json).unwrap(); +} + +#[test] +fn test_deserialize_custom_condition() { + let json = r#"{"kind":"custom_tag"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!(condition, Condition::Unknown(_))); + + let new_json = serde_json::to_string(&condition).unwrap(); + assert_eq!(json, new_json); +} + +#[test] +fn test_deserialize_action() { + let _: Action = serde_json::from_str(r#""notify""#).unwrap(); + let _: Action = serde_json::from_str(r#""dont_notify""#).unwrap(); + let _: Action = serde_json::from_str(r#""coalesce""#).unwrap(); + let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap(); +} + +#[test] +fn test_custom_action() { + let json = r#"{"some_custom":"action_fields"}"#; + + let action: Action = serde_json::from_str(json).unwrap(); + assert!(matches!(action, Action::Unknown(_))); + + let new_json = serde_json::to_string(&action).unwrap(); + assert_eq!(json, new_json); +} diff --git a/stubs/synapse/synapse_rust.pyi b/stubs/synapse/synapse_rust.pyi deleted file mode 100644 index 8658d3138f..0000000000 --- a/stubs/synapse/synapse_rust.pyi +++ /dev/null @@ -1,2 +0,0 @@ -def sum_as_string(a: int, b: int) -> str: ... -def get_rust_file_digest() -> str: ... diff --git a/stubs/synapse/synapse_rust/__init__.pyi b/stubs/synapse/synapse_rust/__init__.pyi new file mode 100644 index 0000000000..8658d3138f --- /dev/null +++ b/stubs/synapse/synapse_rust/__init__.pyi @@ -0,0 +1,2 @@ +def sum_as_string(a: int, b: int) -> str: ... +def get_rust_file_digest() -> str: ... diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi new file mode 100644 index 0000000000..93c4e69d42 --- /dev/null +++ b/stubs/synapse/synapse_rust/push.pyi @@ -0,0 +1,37 @@ +from typing import Any, Collection, Dict, Mapping, Sequence, Tuple, Union + +from synapse.types import JsonDict + +class PushRule: + @property + def rule_id(self) -> str: ... + @property + def priority_class(self) -> int: ... + @property + def conditions(self) -> Sequence[Mapping[str, str]]: ... + @property + def actions(self) -> Sequence[Union[Mapping[str, Any], str]]: ... + @property + def default(self) -> bool: ... + @property + def default_enabled(self) -> bool: ... + @staticmethod + def from_db( + rule_id: str, priority_class: int, conditions: str, actions: str + ) -> "PushRule": ... + +class PushRules: + def __init__(self, rules: Collection[PushRule]): ... + def rules(self) -> Collection[PushRule]: ... + +class FilteredPushRules: + def __init__( + self, + push_rules: PushRules, + enabled_map: Dict[str, bool], + msc3786_enabled: bool, + msc3772_enabled: bool, + ): ... + def rules(self) -> Collection[Tuple[PushRule, bool]]: ... + +def get_base_rule_ids() -> Collection[str]: ... diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py index 2599160bcc..1219672a59 100644 --- a/synapse/handlers/push_rules.py +++ b/synapse/handlers/push_rules.py @@ -16,14 +16,17 @@ from typing import TYPE_CHECKING, List, Optional, Union import attr from synapse.api.errors import SynapseError, UnrecognizedRequestError -from synapse.push.baserules import BASE_RULE_IDS from synapse.storage.push_rule import RuleNotFoundException +from synapse.synapse_rust.push import get_base_rule_ids from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer +BASE_RULE_IDS = get_base_rule_ids() + + @attr.s(slots=True, frozen=True, auto_attribs=True) class RuleSpec: scope: str diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py deleted file mode 100644 index 440205e80c..0000000000 --- a/synapse/push/baserules.py +++ /dev/null @@ -1,583 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2017 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Push rules is the system used to determine which events trigger a push (and a -bump in notification counts). - -This consists of a list of "push rules" for each user, where a push rule is a -pair of "conditions" and "actions". When a user receives an event Synapse -iterates over the list of push rules until it finds one where all the conditions -match the event, at which point "actions" describe the outcome (e.g. notify, -highlight, etc). - -Push rules are split up into 5 different "kinds" (aka "priority classes"), which -are run in order: - 1. Override — highest priority rules, e.g. always ignore notices - 2. Content — content specific rules, e.g. @ notifications - 3. Room — per room rules, e.g. enable/disable notifications for all messages - in a room - 4. Sender — per sender rules, e.g. never notify for messages from a given - user - 5. Underride — the lowest priority "default" rules, e.g. notify for every - message. - -The set of "base rules" are the list of rules that every user has by default. A -user can modify their copy of the push rules in one of three ways: - - 1. Adding a new push rule of a certain kind - 2. Changing the actions of a base rule - 3. Enabling/disabling a base rule. - -The base rules are split into whether they come before or after a particular -kind, so the order of push rule evaluation would be: base rules for before -"override" kind, user defined "override" rules, base rules after "override" -kind, etc, etc. -""" - -import itertools -import logging -from typing import Dict, Iterator, List, Mapping, Sequence, Tuple, Union - -import attr - -from synapse.config.experimental import ExperimentalConfig -from synapse.push.rulekinds import PRIORITY_CLASS_MAP - -logger = logging.getLogger(__name__) - - -@attr.s(auto_attribs=True, slots=True, frozen=True) -class PushRule: - """A push rule - - Attributes: - rule_id: a unique ID for this rule - priority_class: what "kind" of push rule this is (see - `PRIORITY_CLASS_MAP` for mapping between int and kind) - conditions: the sequence of conditions that all need to match - actions: the actions to apply if all conditions are met - default: is this a base rule? - default_enabled: is this enabled by default? - """ - - rule_id: str - priority_class: int - conditions: Sequence[Mapping[str, str]] - actions: Sequence[Union[str, Mapping]] - default: bool = False - default_enabled: bool = True - - -@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False) -class PushRules: - """A collection of push rules for an account. - - Can be iterated over, producing push rules in priority order. - """ - - # A mapping from rule ID to push rule that overrides a base rule. These will - # be returned instead of the base rule. - overriden_base_rules: Dict[str, PushRule] = attr.Factory(dict) - - # The following stores the custom push rules at each priority class. - # - # We keep these separate (rather than combining into one big list) to avoid - # copying the base rules around all the time. - override: List[PushRule] = attr.Factory(list) - content: List[PushRule] = attr.Factory(list) - room: List[PushRule] = attr.Factory(list) - sender: List[PushRule] = attr.Factory(list) - underride: List[PushRule] = attr.Factory(list) - - def __iter__(self) -> Iterator[PushRule]: - # When iterating over the push rules we need to return the base rules - # interspersed at the correct spots. - for rule in itertools.chain( - BASE_PREPEND_OVERRIDE_RULES, - self.override, - BASE_APPEND_OVERRIDE_RULES, - self.content, - BASE_APPEND_CONTENT_RULES, - self.room, - self.sender, - self.underride, - BASE_APPEND_UNDERRIDE_RULES, - ): - # Check if a base rule has been overriden by a custom rule. If so - # return that instead. - override_rule = self.overriden_base_rules.get(rule.rule_id) - if override_rule: - yield override_rule - else: - yield rule - - def __len__(self) -> int: - # The length is mostly used by caches to get a sense of "size" / amount - # of memory this object is using, so we only count the number of custom - # rules. - return ( - len(self.overriden_base_rules) - + len(self.override) - + len(self.content) - + len(self.room) - + len(self.sender) - + len(self.underride) - ) - - -@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False) -class FilteredPushRules: - """A wrapper around `PushRules` that filters out disabled experimental push - rules, and includes the "enabled" state for each rule when iterated over. - """ - - push_rules: PushRules - enabled_map: Dict[str, bool] - experimental_config: ExperimentalConfig - - def __iter__(self) -> Iterator[Tuple[PushRule, bool]]: - for rule in self.push_rules: - if not _is_experimental_rule_enabled( - rule.rule_id, self.experimental_config - ): - continue - - enabled = self.enabled_map.get(rule.rule_id, rule.default_enabled) - - yield rule, enabled - - def __len__(self) -> int: - return len(self.push_rules) - - -DEFAULT_EMPTY_PUSH_RULES = PushRules() - - -def compile_push_rules(rawrules: List[PushRule]) -> PushRules: - """Given a set of custom push rules return a `PushRules` instance (which - includes the base rules). - """ - - if not rawrules: - # Fast path to avoid allocating empty lists when there are no custom - # rules for the user. - return DEFAULT_EMPTY_PUSH_RULES - - rules = PushRules() - - for rule in rawrules: - # We need to decide which bucket each custom push rule goes into. - - # If it has the same ID as a base rule then it overrides that... - overriden_base_rule = BASE_RULES_BY_ID.get(rule.rule_id) - if overriden_base_rule: - rules.overriden_base_rules[rule.rule_id] = attr.evolve( - overriden_base_rule, actions=rule.actions - ) - continue - - # ... otherwise it gets added to the appropriate priority class bucket - collection: List[PushRule] - if rule.priority_class == 5: - collection = rules.override - elif rule.priority_class == 4: - collection = rules.content - elif rule.priority_class == 3: - collection = rules.room - elif rule.priority_class == 2: - collection = rules.sender - elif rule.priority_class == 1: - collection = rules.underride - elif rule.priority_class <= 0: - logger.info( - "Got rule with priority class less than zero, but doesn't override a base rule: %s", - rule, - ) - continue - else: - # We log and continue here so as not to break event sending - logger.error("Unknown priority class: %", rule.priority_class) - continue - - collection.append(rule) - - return rules - - -def _is_experimental_rule_enabled( - rule_id: str, experimental_config: ExperimentalConfig -) -> bool: - """Used by `FilteredPushRules` to filter out experimental rules when they - have not been enabled. - """ - if ( - rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl" - and not experimental_config.msc3786_enabled - ): - return False - if ( - rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - and not experimental_config.msc3772_enabled - ): - return False - return True - - -BASE_APPEND_CONTENT_RULES = [ - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["content"], - rule_id="global/content/.m.rule.contains_user_name", - conditions=[ - { - "kind": "event_match", - "key": "content.body", - # Match the localpart of the requester's MXID. - "pattern_type": "user_localpart", - } - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, - ], - ) -] - - -BASE_PREPEND_OVERRIDE_RULES = [ - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.master", - default_enabled=False, - conditions=[], - actions=["dont_notify"], - ) -] - - -BASE_APPEND_OVERRIDE_RULES = [ - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.suppress_notices", - conditions=[ - { - "kind": "event_match", - "key": "content.msgtype", - "pattern": "m.notice", - "_cache_key": "_suppress_notices", - } - ], - actions=["dont_notify"], - ), - # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event - # otherwise invites will be matched by .m.rule.member_event - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.invite_for_me", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.member", - "_cache_key": "_member", - }, - { - "kind": "event_match", - "key": "content.membership", - "pattern": "invite", - "_cache_key": "_invite_member", - }, - # Match the requester's MXID. - {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight", "value": False}, - ], - ), - # Will we sometimes want to know about people joining and leaving? - # Perhaps: if so, this could be expanded upon. Seems the most usual case - # is that we don't though. We add this override rule so that even if - # the room rule is set to notify, we don't get notifications about - # join/leave/avatar/displayname events. - # See also: https://matrix.org/jira/browse/SYN-607 - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.member_event", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.member", - "_cache_key": "_member", - } - ], - actions=["dont_notify"], - ), - # This was changed from underride to override so it's closer in priority - # to the content rules where the user name highlight rule lives. This - # way a room rule is lower priority than both but a custom override rule - # is higher priority than both. - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.contains_display_name", - conditions=[{"kind": "contains_display_name"}], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, - ], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.roomnotif", - conditions=[ - { - "kind": "event_match", - "key": "content.body", - "pattern": "@room", - "_cache_key": "_roomnotif_content", - }, - { - "kind": "sender_notification_permission", - "key": "room", - "_cache_key": "_roomnotif_pl", - }, - ], - actions=["notify", {"set_tweak": "highlight", "value": True}], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.tombstone", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.tombstone", - "_cache_key": "_tombstone", - }, - { - "kind": "event_match", - "key": "state_key", - "pattern": "", - "_cache_key": "_tombstone_statekey", - }, - ], - actions=["notify", {"set_tweak": "highlight", "value": True}], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.reaction", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.reaction", - "_cache_key": "_reaction", - } - ], - actions=["dont_notify"], - ), - # XXX: This is an experimental rule that is only enabled if msc3786_enabled - # is enabled, if it is not the rule gets filtered out in _load_rules() in - # PushRulesWorkerStore - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.org.matrix.msc3786.rule.room.server_acl", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.server_acl", - "_cache_key": "_room_server_acl", - }, - { - "kind": "event_match", - "key": "state_key", - "pattern": "", - "_cache_key": "_room_server_acl_state_key", - }, - ], - actions=[], - ), -] - - -BASE_APPEND_UNDERRIDE_RULES = [ - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.call", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.call.invite", - "_cache_key": "_call", - } - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "ring"}, - {"set_tweak": "highlight", "value": False}, - ], - ), - # XXX: once m.direct is standardised everywhere, we should use it to detect - # a DM from the user's perspective rather than this heuristic. - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.room_one_to_one", - conditions=[ - {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"}, - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.message", - "_cache_key": "_message", - }, - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight", "value": False}, - ], - ), - # XXX: this is going to fire for events which aren't m.room.messages - # but are encrypted (e.g. m.call.*)... - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.encrypted_room_one_to_one", - conditions=[ - {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"}, - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.encrypted", - "_cache_key": "_encrypted", - }, - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight", "value": False}, - ], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.org.matrix.msc3772.thread_reply", - conditions=[ - { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.thread", - # Match the requester's MXID. - "sender_type": "user_id", - } - ], - actions=["notify", {"set_tweak": "highlight", "value": False}], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.message", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.message", - "_cache_key": "_message", - } - ], - actions=["notify", {"set_tweak": "highlight", "value": False}], - ), - # XXX: this is going to fire for events which aren't m.room.messages - # but are encrypted (e.g. m.call.*)... - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.encrypted", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.encrypted", - "_cache_key": "_encrypted", - } - ], - actions=["notify", {"set_tweak": "highlight", "value": False}], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.im.vector.jitsi", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "im.vector.modular.widgets", - "_cache_key": "_type_modular_widgets", - }, - { - "kind": "event_match", - "key": "content.type", - "pattern": "jitsi", - "_cache_key": "_content_type_jitsi", - }, - { - "kind": "event_match", - "key": "state_key", - "pattern": "*", - "_cache_key": "_is_state_event", - }, - ], - actions=["notify", {"set_tweak": "highlight", "value": False}], - ), -] - - -BASE_RULE_IDS = set() - -BASE_RULES_BY_ID: Dict[str, PushRule] = {} - -for r in BASE_APPEND_CONTENT_RULES: - BASE_RULE_IDS.add(r.rule_id) - BASE_RULES_BY_ID[r.rule_id] = r - -for r in BASE_PREPEND_OVERRIDE_RULES: - BASE_RULE_IDS.add(r.rule_id) - BASE_RULES_BY_ID[r.rule_id] = r - -for r in BASE_APPEND_OVERRIDE_RULES: - BASE_RULE_IDS.add(r.rule_id) - BASE_RULES_BY_ID[r.rule_id] = r - -for r in BASE_APPEND_UNDERRIDE_RULES: - BASE_RULE_IDS.add(r.rule_id) - BASE_RULES_BY_ID[r.rule_id] = r diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 3846fbc5f0..404379ef67 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -37,11 +37,11 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter +from synapse.synapse_rust.push import FilteredPushRules, PushRule from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state -from .baserules import FilteredPushRules, PushRule from .push_rule_evaluator import PushRuleEvaluatorForEvent if TYPE_CHECKING: @@ -280,7 +280,8 @@ class BulkPushRuleEvaluator: thread_id = "main" if relation: relations = await self._get_mutual_relations( - relation.parent_id, itertools.chain(*rules_by_user.values()) + relation.parent_id, + itertools.chain(*(r.rules() for r in rules_by_user.values())), ) if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id @@ -333,7 +334,7 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - for rule, enabled in rules: + for rule, enabled in rules.rules(): if not enabled: continue diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 73618d9234..ebc13beda1 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -16,10 +16,9 @@ import copy from typing import Any, Dict, List, Optional from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP +from synapse.synapse_rust.push import FilteredPushRules, PushRule from synapse.types import UserID -from .baserules import FilteredPushRules, PushRule - def format_push_rules_for_user( user: UserID, ruleslist: FilteredPushRules @@ -34,7 +33,7 @@ def format_push_rules_for_user( rules["global"] = _add_empty_priority_class_arrays(rules["global"]) - for r, enabled in ruleslist: + for r, enabled in ruleslist.rules(): template_name = _priority_class_to_template_name(r.priority_class) rulearray = rules["global"][template_name] diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 5079edd1e0..ed17b2e70c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,9 +30,8 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -51,6 +50,7 @@ from synapse.storage.util.id_generators import ( IdGenerator, StreamIdGenerator, ) +from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -72,18 +72,25 @@ def _load_rules( """ ruleslist = [ - PushRule( + PushRule.from_db( rule_id=rawrule["rule_id"], priority_class=rawrule["priority_class"], - conditions=db_to_json(rawrule["conditions"]), - actions=db_to_json(rawrule["actions"]), + conditions=rawrule["conditions"], + actions=rawrule["actions"], ) for rawrule in rawrules ] - push_rules = compile_push_rules(ruleslist) + push_rules = PushRules( + ruleslist, + ) - filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config) + filtered_rules = FilteredPushRules( + push_rules, + enabled_map, + msc3786_enabled=experimental_config.msc3786_enabled, + msc3772_enabled=experimental_config.msc3772_enabled, + ) return filtered_rules @@ -845,7 +852,7 @@ class PushRuleStore(PushRulesWorkerStore): user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room - for rule, enabled in user_push_rules: + for rule, enabled in user_push_rules.rules(): if not enabled: continue diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 7b9b711521..bce65fab7d 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -15,11 +15,11 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import AccountDataTypes -from synapse.push.baserules import PushRule from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest import admin from synapse.rest.client import account, login from synapse.server import HomeServer +from synapse.synapse_rust.push import PushRule from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -161,20 +161,15 @@ class DeactivateAccountTestCase(HomeserverTestCase): self._store.get_push_rules_for_user(self.user) ) # Filter out default rules; we don't care - push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)] + push_rules = [ + r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r) + ] # Check our rule made it - self.assertEqual( - push_rules, - [ - PushRule( - rule_id="personal.override.rule1", - priority_class=5, - conditions=[], - actions=[], - ) - ], - push_rules, - ) + self.assertEqual(len(push_rules), 1) + self.assertEqual(push_rules[0].rule_id, "personal.override.rule1") + self.assertEqual(push_rules[0].priority_class, 5) + self.assertEqual(push_rules[0].conditions, []) + self.assertEqual(push_rules[0].actions, []) # Request the deactivation of our account self._deactivate_my_account() @@ -183,7 +178,9 @@ class DeactivateAccountTestCase(HomeserverTestCase): self._store.get_push_rules_for_user(self.user) ) # Filter out default rules; we don't care - push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)] + push_rules = [ + r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r) + ] # Check our rule no longer exists self.assertEqual(push_rules, [], push_rules) -- cgit 1.5.1 From fff9b955fa39bda2cca1fa726b561c7886e746a1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 20 Sep 2022 14:14:12 +0100 Subject: Generate separate snapshots for logical databases (#13792) * Generate separate snapshots for sqlite, postgres and common * Cleanup postgres dbs in the TRAP * Say which logical DB we're applying updates to * Run background updates on the state DB * Add new option for accepting a SCHEMA_NUMBER --- changelog.d/13792.misc | 1 + scripts-dev/make_full_schema.sh | 166 +++++++++++++++++++++------- synapse/_scripts/update_synapse_database.py | 14 ++- synapse/storage/background_updates.py | 5 +- 4 files changed, 140 insertions(+), 46 deletions(-) create mode 100644 changelog.d/13792.misc (limited to 'synapse') diff --git a/changelog.d/13792.misc b/changelog.d/13792.misc new file mode 100644 index 0000000000..36ac91400a --- /dev/null +++ b/changelog.d/13792.misc @@ -0,0 +1 @@ +Update the script which makes full schema dumps. diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh index 61394360ce..d8cd06ee4f 100755 --- a/scripts-dev/make_full_schema.sh +++ b/scripts-dev/make_full_schema.sh @@ -2,23 +2,16 @@ # # This script generates SQL files for creating a brand new Synapse DB with the latest # schema, on both SQLite3 and Postgres. -# -# It does so by having Synapse generate an up-to-date SQLite DB, then running -# synapse_port_db to convert it to Postgres. It then dumps the contents of both. export PGHOST="localhost" -POSTGRES_DB_NAME="synapse_full_schema.$$" - -SQLITE_SCHEMA_FILE="schema.sql.sqlite" -SQLITE_ROWS_FILE="rows.sql.sqlite" -POSTGRES_SCHEMA_FILE="full.sql.postgres" -POSTGRES_ROWS_FILE="rows.sql.postgres" - +POSTGRES_MAIN_DB_NAME="synapse_full_schema_main.$$" +POSTGRES_COMMON_DB_NAME="synapse_full_schema_common.$$" +POSTGRES_STATE_DB_NAME="synapse_full_schema_state.$$" REQUIRED_DEPS=("matrix-synapse" "psycopg2") usage() { echo - echo "Usage: $0 -p -o [-c] [-n] [-h]" + echo "Usage: $0 -p -o [-c] [-n ] [-h]" echo echo "-p " echo " Username to connect to local postgres instance. The password will be requested" @@ -27,11 +20,16 @@ usage() { echo " CI mode. Prints every command that the script runs." echo "-o " echo " Directory to output full schema files to." + echo "-n " + echo " Schema number for the new snapshot. Used to set the location of files within " + echo " the output directory, mimicking that of synapse/storage/schemas." + echo " Defaults to 9999." echo "-h" echo " Display this help text." } -while getopts "p:co:h" opt; do +SCHEMA_NUMBER="9999" +while getopts "p:co:hn:" opt; do case $opt in p) export PGUSER=$OPTARG @@ -48,6 +46,9 @@ while getopts "p:co:h" opt; do usage exit ;; + n) + SCHEMA_NUMBER="$OPTARG" + ;; \?) echo "ERROR: Invalid option: -$OPTARG" >&2 usage @@ -95,12 +96,21 @@ cd "$(dirname "$0")/.." TMPDIR=$(mktemp -d) KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path SQLITE_CONFIG=$TMPDIR/sqlite.conf -SQLITE_DB=$TMPDIR/homeserver.db +SQLITE_MAIN_DB=$TMPDIR/main.db +SQLITE_STATE_DB=$TMPDIR/state.db +SQLITE_COMMON_DB=$TMPDIR/common.db POSTGRES_CONFIG=$TMPDIR/postgres.conf # Ensure these files are delete on script exit -# TODO: the trap should also drop the temp postgres DB -trap 'rm -rf $TMPDIR' EXIT +cleanup() { + echo "Cleaning up temporary sqlite database and config files..." + rm -r "$TMPDIR" + echo "Cleaning up temporary Postgres database..." + dropdb --if-exists "$POSTGRES_COMMON_DB_NAME" + dropdb --if-exists "$POSTGRES_MAIN_DB_NAME" + dropdb --if-exists "$POSTGRES_STATE_DB_NAME" +} +trap 'cleanup' EXIT cat > "$SQLITE_CONFIG" < "$OUTPUT_DIR/$SQLITE_SCHEMA_FILE" -sqlite3 "$SQLITE_DB" ".dump --data-only --nosys" > "$OUTPUT_DIR/$SQLITE_ROWS_FILE" +echo "Dumping SQLite3 schema..." + +mkdir -p "$OUTPUT_DIR/"{common,main,state}"/full_schema/$SCHEMA_NUMBER" +sqlite3 "$SQLITE_COMMON_DB" ".schema --indent" > "$OUTPUT_DIR/common/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_COMMON_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/common/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_MAIN_DB" ".schema --indent" > "$OUTPUT_DIR/main/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_MAIN_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/main/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_STATE_DB" ".schema --indent" > "$OUTPUT_DIR/state/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_STATE_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/state/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" + +cleanup_pg_schema() { + sed -e '/^$/d' -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' +} -echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_SCHEMA_FILE' and '$OUTPUT_DIR/$POSTGRES_ROWS_FILE'..." -pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_DB_NAME" | sed -e '/^$/d' -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_SCHEMA_FILE" -pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_DB_NAME" | sed -e '/^$/d' -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_ROWS_FILE" +echo "Dumping Postgres schema..." -echo "Cleaning up temporary Postgres database..." -dropdb $POSTGRES_DB_NAME +pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_COMMON_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/common/full_schema/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_COMMON_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/common/full_schema/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_MAIN_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/main/full_schema/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_MAIN_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/main/full_schema/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/state/full_schema/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/state/full_schema/$SCHEMA_NUMBER/full.sql.postgres" echo "Done! Files dumped to: $OUTPUT_DIR" diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py index b4aeae6dd5..fb1fb83f50 100755 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -48,10 +48,13 @@ class MockHomeserver(HomeServer): def run_background_updates(hs: HomeServer) -> None: - store = hs.get_datastores().main + main = hs.get_datastores().main + state = hs.get_datastores().state async def run_background_updates() -> None: - await store.db_pool.updates.run_background_updates(sleep=False) + await main.db_pool.updates.run_background_updates(sleep=False) + if state: + await state.db_pool.updates.run_background_updates(sleep=False) # Stop the reactor to exit the script once every background update is run. reactor.stop() @@ -97,8 +100,11 @@ def main() -> None: # Load, process and sanity-check the config. hs_config = yaml.safe_load(args.database_config) - if "database" not in hs_config: - sys.stderr.write("The configuration file must have a 'database' section.\n") + if "database" not in hs_config and "databases" not in hs_config: + sys.stderr.write( + "The configuration file must have a 'database' or 'databases' section. " + "See https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#database" + ) sys.exit(4) config = HomeServerConfig() diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index bf5e7ee7be..2056ecb2c3 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -285,7 +285,10 @@ class BackgroundUpdater: back_to_back_failures = 0 try: - logger.info("Starting background schema updates") + logger.info( + "Starting background schema updates for database %s", + self._database_name, + ) while self.enabled: try: result = await self.do_next_background_update(sleep) -- cgit 1.5.1 From 85fc7ea1a1fb38424923dd1ff117405aea04c33c Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 20 Sep 2022 15:18:07 +0200 Subject: Remove the `complete_sso_login` method from the Module API which was deprecated in Synapse 1.13.0. (#13843) Signed-off-by: Quentin Gliech --- changelog.d/13843.removal | 1 + synapse/handlers/auth.py | 34 +--------------------------------- synapse/module_api/__init__.py | 25 ------------------------- 3 files changed, 2 insertions(+), 58 deletions(-) create mode 100644 changelog.d/13843.removal (limited to 'synapse') diff --git a/changelog.d/13843.removal b/changelog.d/13843.removal new file mode 100644 index 0000000000..f6caaa8895 --- /dev/null +++ b/changelog.d/13843.removal @@ -0,0 +1 @@ +Remove the `complete_sso_login` method from the Module API which was deprecated in Synapse 1.13.0. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0327fc57a4..eacd631ee0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -63,7 +63,6 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import delay_cancellation, maybe_awaitable @@ -1687,41 +1686,10 @@ class AuthHandler: respond_with_html(request, 403, self._sso_account_deactivated_template) return - profile = await self.store.get_profileinfo( + user_profile_data = await self.store.get_profileinfo( UserID.from_string(registered_user_id).localpart ) - self._complete_sso_login( - registered_user_id, - auth_provider_id, - request, - client_redirect_url, - extra_attributes, - new_user=new_user, - user_profile_data=profile, - auth_provider_session_id=auth_provider_session_id, - ) - - def _complete_sso_login( - self, - registered_user_id: str, - auth_provider_id: str, - request: Request, - client_redirect_url: str, - extra_attributes: Optional[JsonDict] = None, - new_user: bool = False, - user_profile_data: Optional[ProfileInfo] = None, - auth_provider_session_id: Optional[str] = None, - ) -> None: - """ - The synchronous portion of complete_sso_login. - - This exists purely for backwards compatibility of synapse.module_api.ModuleApi. - """ - - if user_profile_data is None: - user_profile_data = ProfileInfo(None, None) - # Store any extra attributes which will be passed in the login response. # Note that this is per-user so it may overwrite a previous value, this # is considered OK since the newest SSO attributes should be most valid. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 87ba154cb7..9287c0fb8d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -836,31 +836,6 @@ class ModuleApi: self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type] ) - def complete_sso_login( - self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str - ) -> None: - """Complete a SSO login by redirecting the user to a page to confirm whether they - want their access token sent to `client_redirect_url`, or redirect them to that - URL with a token directly if the URL matches with one of the whitelisted clients. - - This is deprecated in favor of complete_sso_login_async. - - Added in Synapse v1.11.1. - - Args: - registered_user_id: The MXID that has been registered as a previous step of - of this SSO login. - request: The request to respond to. - client_redirect_url: The URL to which to offer to redirect the user (or to - redirect them directly if whitelisted). - """ - self._auth_handler._complete_sso_login( - registered_user_id, - "", - request, - client_redirect_url, - ) - async def complete_sso_login_async( self, registered_user_id: str, -- cgit 1.5.1 From 16e1a9d9a7884967da390ef967b942a5e35e8f6c Mon Sep 17 00:00:00 2001 From: Peter Scheu <32014443+peterscheu-aceart@users.noreply.github.com> Date: Wed, 21 Sep 2022 15:08:16 +0200 Subject: Correct documentation for map_user_attributes of OpenID Mapping Providers (#13836) Co-authored-by: David Robertson --- changelog.d/13836.doc | 1 + docs/sso_mapping_providers.md | 12 +++++++++--- synapse/handlers/sso.py | 3 +++ 3 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 changelog.d/13836.doc (limited to 'synapse') diff --git a/changelog.d/13836.doc b/changelog.d/13836.doc new file mode 100644 index 0000000000..f2edab00f4 --- /dev/null +++ b/changelog.d/13836.doc @@ -0,0 +1 @@ +Fix a mistake in sso_mapping_providers.md: `map_user_attributes` is expected to return `display_name` not `displayname`. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 817499149f..9f5e5fbbe1 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -73,8 +73,8 @@ A custom mapping provider must specify the following methods: * `async def map_user_attributes(self, userinfo, token, failures)` - This method must be async. - Arguments: - - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user - information from. + - `userinfo` - An [`authlib.oidc.core.claims.UserInfo`](https://docs.authlib.org/en/latest/specs/oidc.html#authlib.oidc.core.UserInfo) + object to extract user information from. - `token` - A dictionary which includes information necessary to make further requests to the OpenID provider. - `failures` - An `int` that represents the amount of times the returned @@ -91,7 +91,13 @@ A custom mapping provider must specify the following methods: `None`, the user is prompted to pick their own username. This is only used during a user's first login. Once a localpart has been associated with a remote user ID (see `get_remote_user_id`) it cannot be updated. - - `displayname`: An optional string, the display name for the user. + - `confirm_localpart`: A boolean. If set to `True`, when a `localpart` + string is returned from this method, Synapse will prompt the user to + either accept this localpart or pick their own username. Otherwise this + option has no effect. If omitted, defaults to `False`. + - `display_name`: An optional string, the display name for the user. + - `emails`: A list of strings, the email address(es) to associate with + this user. If omitted, defaults to an empty list. * `async def get_extra_attributes(self, userinfo, token)` - This method must be async. - Arguments: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 1e171f3f71..6bc1cbd787 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -128,6 +128,9 @@ class SsoIdentityProvider(Protocol): @attr.s(auto_attribs=True) class UserAttributes: + # NB: This struct is documented in docs/sso_mapping_providers.md so that users can + # populate it with data from their own mapping providers. + # the localpart of the mxid that the mapper has assigned to the user. # if `None`, the mapper has not picked a userid, and the user should be prompted to # enter one. -- cgit 1.5.1 From 6bd8763804dc0987c7ecd37bcb5ebff465fffa29 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Wed, 21 Sep 2022 15:32:01 +0200 Subject: Add cache invalidation across workers to module API (#13667) Signed-off-by: Mathieu Velten --- changelog.d/13667.feature | 1 + scripts-dev/mypy_synapse_plugin.py | 4 +- synapse/module_api/__init__.py | 33 ++++++++- synapse/storage/_base.py | 23 +++++-- synapse/storage/databases/main/cache.py | 20 ++++-- synapse/util/caches/descriptors.py | 14 ++-- .../replication/test_module_cache_invalidation.py | 79 ++++++++++++++++++++++ 7 files changed, 153 insertions(+), 21 deletions(-) create mode 100644 changelog.d/13667.feature create mode 100644 tests/replication/test_module_cache_invalidation.py (limited to 'synapse') diff --git a/changelog.d/13667.feature b/changelog.d/13667.feature new file mode 100644 index 0000000000..a0b3cfe18c --- /dev/null +++ b/changelog.d/13667.feature @@ -0,0 +1 @@ +Add cache invalidation across workers to module API. diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index d08517a953..2c377533c0 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -29,7 +29,7 @@ class SynapsePlugin(Plugin): self, fullname: str ) -> Optional[Callable[[MethodSigContext], CallableType]]: if fullname.startswith( - "synapse.util.caches.descriptors._CachedFunction.__call__" + "synapse.util.caches.descriptors.CachedFunction.__call__" ) or fullname.startswith( "synapse.util.caches.descriptors._LruCachedFunction.__call__" ): @@ -38,7 +38,7 @@ class SynapsePlugin(Plugin): def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: - """Fixes the `_CachedFunction.__call__` signature to be correct. + """Fixes the `CachedFunction.__call__` signature to be correct. It already has *almost* the correct signature, except: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 9287c0fb8d..59755bff6d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -125,7 +125,7 @@ from synapse.types import ( ) from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import CachedFunction, cached from synapse.util.frozenutils import freeze if TYPE_CHECKING: @@ -836,6 +836,37 @@ class ModuleApi: self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type] ) + def register_cached_function(self, cached_func: CachedFunction) -> None: + """Register a cached function that should be invalidated across workers. + Invalidation local to a worker can be done directly using `cached_func.invalidate`, + however invalidation that needs to go to other workers needs to call `invalidate_cache` + on the module API instead. + + Args: + cached_function: The cached function that will be registered to receive invalidation + locally and from other workers. + """ + self._store.register_external_cached_function( + f"{cached_func.__module__}.{cached_func.__name__}", cached_func + ) + + async def invalidate_cache( + self, cached_func: CachedFunction, keys: Tuple[Any, ...] + ) -> None: + """Invalidate a cache entry of a cached function across workers. The cached function + needs to be registered on all workers first with `register_cached_function`. + + Args: + cached_function: The cached function that needs an invalidation + keys: keys of the entry to invalidate, usually matching the arguments of the + cached function. + """ + cached_func.invalidate(keys) + await self._store.send_invalidation_to_replication( + f"{cached_func.__module__}.{cached_func.__name__}", + keys, + ) + async def complete_sso_login_async( self, registered_user_id: str, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e30f9c76d4..303a5d5298 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,12 +15,13 @@ # limitations under the License. import logging from abc import ABCMeta -from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.types import get_domain_from_id from synapse.util import json_decoder +from synapse.util.caches.descriptors import CachedFunction if TYPE_CHECKING: from synapse.server import HomeServer @@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta): self.database_engine = database.engine self.db_pool = database + self.external_cached_functions: Dict[str, CachedFunction] = {} + def process_replication_rows( self, stream_name: str, @@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta): def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] - ) -> None: + ) -> bool: """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. @@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta): try: cache = getattr(self, cache_name) except AttributeError: - # We probably haven't pulled in the cache in this worker, - # which is fine. - return + # Check if an externally defined module cache has been registered + cache = self.external_cached_functions.get(cache_name) + if not cache: + # We probably haven't pulled in the cache in this worker, + # which is fine. + return False if key is None: cache.invalidate_all() @@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta): invalidate_method = getattr(cache, "invalidate_local", cache.invalidate) invalidate_method(tuple(key)) + return True + + def register_external_cached_function( + self, cache_name: str, func: CachedFunction + ) -> None: + self.external_cached_functions[cache_name] = func + def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: """ diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 12e9a42382..2c421151c1 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -33,7 +33,7 @@ from synapse.storage.database import ( ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.util.caches.descriptors import _CachedFunction +from synapse.util.caches.descriptors import CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): return cache_func.invalidate(keys) - await self.db_pool.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, + await self.send_invalidation_to_replication( cache_func.__name__, keys, ) @@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_cache_and_stream( self, txn: LoggingTransaction, - cache_func: _CachedFunction, + cache_func: CachedFunction, keys: Tuple[Any, ...], ) -> None: """Invalidates the cache and adds it to the cache stream so slaves @@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._send_invalidation_to_replication(txn, cache_func.__name__, keys) def _invalidate_all_cache_and_stream( - self, txn: LoggingTransaction, cache_func: _CachedFunction + self, txn: LoggingTransaction, cache_func: CachedFunction ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn, CURRENT_STATE_CACHE_NAME, [room_id] ) + async def send_invalidation_to_replication( + self, cache_name: str, keys: Optional[Collection[Any]] + ) -> None: + await self.db_pool.runInteraction( + "send_invalidation_to_replication", + self._send_invalidation_to_replication, + cache_name, + keys, + ) + def _send_invalidation_to_replication( self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] ) -> None: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 10aff4d04a..3909f1caea 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) -class _CachedFunction(Generic[F]): +class CachedFunction(Generic[F]): invalidate: Any = None invalidate_all: Any = None prefill: Any = None @@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): return ret2 - wrapped = cast(_CachedFunction, _wrapped) + wrapped = cast(CachedFunction, _wrapped) wrapped.cache = cache obj.__dict__[self.name] = wrapped @@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): return make_deferred_yieldable(ret) - wrapped = cast(_CachedFunction, _wrapped) + wrapped = cast(CachedFunction, _wrapped) if self.num_args == 1: assert not self.tree @@ -572,7 +572,7 @@ def cached( iterable: bool = False, prune_unread_entries: bool = True, name: Optional[str] = None, -) -> Callable[[F], _CachedFunction[F]]: +) -> Callable[[F], CachedFunction[F]]: func = lambda orig: DeferredCacheDescriptor( orig, max_entries=max_entries, @@ -585,7 +585,7 @@ def cached( name=name, ) - return cast(Callable[[F], _CachedFunction[F]], func) + return cast(Callable[[F], CachedFunction[F]], func) def cachedList( @@ -594,7 +594,7 @@ def cachedList( list_name: str, num_args: Optional[int] = None, name: Optional[str] = None, -) -> Callable[[F], _CachedFunction[F]]: +) -> Callable[[F], CachedFunction[F]]: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. Used to do batch lookups for an already created cache. One of the arguments @@ -631,7 +631,7 @@ def cachedList( name=name, ) - return cast(Callable[[F], _CachedFunction[F]], func) + return cast(Callable[[F], CachedFunction[F]], func) def _get_cache_key_builder( diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py new file mode 100644 index 0000000000..b93cae67d3 --- /dev/null +++ b/tests/replication/test_module_cache_invalidation.py @@ -0,0 +1,79 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import synapse +from synapse.module_api import cached + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + +FIRST_VALUE = "one" +SECOND_VALUE = "two" + +KEY = "mykey" + + +class TestCache: + current_value = FIRST_VALUE + + @cached() + async def cached_function(self, user_id: str) -> str: + return self.current_value + + +class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + ] + + def test_module_cache_full_invalidation(self): + main_cache = TestCache() + self.hs.get_module_api().register_cached_function(main_cache.cached_function) + + worker_hs = self.make_worker_hs("synapse.app.generic_worker") + + worker_cache = TestCache() + worker_hs.get_module_api().register_cached_function( + worker_cache.cached_function + ) + + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual( + FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) + + main_cache.current_value = SECOND_VALUE + worker_cache.current_value = SECOND_VALUE + # No invalidation yet, should return the cached value on both the main process and the worker + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual( + FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) + + # Full invalidation on the main process, should be replicated on the worker that + # should returned the updated value too + self.get_success( + self.hs.get_module_api().invalidate_cache( + main_cache.cached_function, (KEY,) + ) + ) + + self.assertEqual( + SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)) + ) + self.assertEqual( + SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) -- cgit 1.5.1 From 8ae42ab8fa3c6b52d74c24daa7ca75a478fa4fbb Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 15:39:01 +0100 Subject: Support enabling/disabling pushers (from MSC3881) (#13799) Partial implementation of MSC3881 --- changelog.d/13799.feature | 1 + synapse/_scripts/synapse_port_db.py | 1 + synapse/config/experimental.py | 3 + synapse/handlers/register.py | 4 +- synapse/push/__init__.py | 2 + synapse/push/pusherpool.py | 81 ++++++++--- synapse/replication/tcp/client.py | 10 +- synapse/rest/admin/users.py | 4 +- synapse/rest/client/pusher.py | 18 ++- synapse/storage/databases/main/pusher.py | 69 ++++++---- .../schema/main/delta/73/02add_pusher_enabled.sql | 16 +++ tests/push/test_email.py | 4 +- tests/push/test_http.py | 148 +++++++++++++++++++-- tests/replication/test_pusher_shard.py | 2 +- tests/rest/admin/test_user.py | 2 +- 15 files changed, 294 insertions(+), 71 deletions(-) create mode 100644 changelog.d/13799.feature create mode 100644 synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql (limited to 'synapse') diff --git a/changelog.d/13799.feature b/changelog.d/13799.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13799.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 30983c47fb..450ba462ba 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = { "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], "device_lists_changes_in_room": ["converted_to_destinations"], + "pushers": ["enabled"], } diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702b81e636..f4541a8db0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -93,3 +93,6 @@ class ExperimentalConfig(Config): # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) + + # MSC3881: Remotely toggle push notifications for another client + self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 20ec22105a..cfcadb34db 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -997,7 +997,7 @@ class RegistrationHandler: assert user_tuple token_id = user_tuple.token_id - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=token_id, kind="email", @@ -1005,7 +1005,7 @@ class RegistrationHandler: app_display_name="Email Notifications", device_display_name=threepid["address"], pushkey=threepid["address"], - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 57c4d70466..ac99d35a7e 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -116,6 +116,7 @@ class PusherConfig: last_stream_ordering: int last_success: Optional[int] failing_since: Optional[int] + enabled: bool def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" @@ -128,6 +129,7 @@ class PusherConfig: "lang": self.lang, "profile_tag": self.profile_tag, "pushkey": self.pushkey, + "enabled": self.enabled, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 1e0ef44fc7..2597898cf4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -94,7 +94,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - async def add_pusher( + async def add_or_update_pusher( self, user_id: str, access_token: Optional[int], @@ -106,6 +106,7 @@ class PusherPool: lang: Optional[str], data: JsonDict, profile_tag: str = "", + enabled: bool = True, ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -147,9 +148,20 @@ class PusherPool: last_stream_ordering=last_stream_ordering, last_success=None, failing_since=None, + enabled=enabled, ) ) + # Before we actually persist the pusher, we check if the user already has one + # for this app ID and pushkey. If so, we want to keep the access token in place, + # since this could be one device modifying (e.g. enabling/disabling) another + # device's pusher. + existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) + if existing_config: + access_token = existing_config.access_token + await self.store.add_pusher( user_id=user_id, access_token=access_token, @@ -163,8 +175,9 @@ class PusherPool: data=data, last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, + enabled=enabled, ) - pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) return pusher @@ -276,10 +289,25 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - async def start_pusher_by_id( + async def _get_pusher_config_for_user_by_app_id_and_pushkey( + self, user_id: str, app_id: str, pushkey: str + ) -> Optional[PusherConfig]: + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + + pusher_config = None + for r in resultlist: + if r.user_name == user_id: + pusher_config = r + + return pusher_config + + async def process_pusher_change_by_id( self, app_id: str, pushkey: str, user_id: str ) -> Optional[Pusher]: - """Look up the details for the given pusher, and start it + """Look up the details for the given pusher, and either start it if its + "enabled" flag is True, or try to stop it otherwise. + + If the pusher is new and its "enabled" flag is False, the stop is a noop. Returns: The pusher started, if any @@ -290,12 +318,13 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return None - resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) - pusher_config = None - for r in resultlist: - if r.user_name == user_id: - pusher_config = r + if pusher_config and not pusher_config.enabled: + self.maybe_stop_pusher(app_id, pushkey, user_id) + return None pusher = None if pusher_config: @@ -305,7 +334,7 @@ class PusherPool: async def _start_pushers(self) -> None: """Start all the pushers""" - pushers = await self.store.get_all_pushers() + pushers = await self.store.get_enabled_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -363,6 +392,8 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc() + logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey) + # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. @@ -382,16 +413,7 @@ class PusherPool: return pusher async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: - appid_pushkey = "%s:%s" % (app_id, pushkey) - - byuser = self.pushers.get(user_id, {}) - - if appid_pushkey in byuser: - logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - pusher = byuser.pop(appid_pushkey) - pusher.on_stop() - - synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() + self.maybe_stop_pusher(app_id, pushkey, user_id) # We can only delete pushers on master. if self._remove_pusher_client: @@ -402,3 +424,22 @@ class PusherPool: await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) + + def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: + """Stops a pusher with the given app ID and push key if one is running. + + Args: + app_id: the pusher's app ID. + pushkey: the pusher's push key. + user_id: the user the pusher belongs to. Only used for logging. + """ + appid_pushkey = "%s:%s" % (app_id, pushkey) + + byuser = self.pushers.get(user_id, {}) + + if appid_pushkey in byuser: + logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e4f2201c92..cf9cd6833b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,7 +189,9 @@ class ReplicationDataHandler: if row.deleted: self.stop_pusher(row.user_id, row.app_id, row.pushkey) else: - await self.start_pusher(row.user_id, row.app_id, row.pushkey) + await self.process_pusher_change( + row.user_id, row.app_id, row.pushkey + ) elif stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -334,13 +336,15 @@ class ReplicationDataHandler: logger.info("Stopping pusher %r / %r", user_id, key) pusher.on_stop() - async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: + async def process_pusher_change( + self, user_id: str, app_id: str, pushkey: str + ) -> None: if not self._notify_pushers: return key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id) class FederationSenderHandler: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2ca6b2d08a..1274773d7e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet): and self.hs.config.email.email_notif_for_new_users and medium == "email" ): - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=None, kind="email", @@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet): app_display_name="Email Notifications", device_display_name=address, pushkey=address, - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 9a1f10f4be..c9f76125dc 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet): user.to_string() ) - filtered_pushers = [p.as_dict() for p in pushers] + pusher_dicts = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers} + for pusher in pusher_dicts: + if self._msc3881_enabled: + pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] + del pusher["enabled"] + + return 200, {"pushers": pusher_dicts} class PushersSetRestServlet(RestServlet): @@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet): self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet): if "append" in content: append = content["append"] + enabled = True + if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content: + enabled = content["org.matrix.msc3881.enabled"] + if not append: await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], @@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet): ) try: - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet): lang=content["lang"], data=content["data"], profile_tag=content.get("profile_tag", ""), + enabled=enabled, ) except PusherConfigException as pce: raise SynapseError( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index bd0cfa7f32..ee55b8c4a9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore): ) continue + # If we're using SQLite, then boolean values are integers. This is + # troublesome since some code using the return value of this method might + # expect it to be a boolean, or will expose it to clients (in responses). + r["enabled"] = bool(r["enabled"]) + yield PusherConfig(**r) async def get_pushers_by_app_id_and_pushkey( @@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore): return await self.get_pushers_by({"user_name": user_id}) async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: - ret = await self.db_pool.simple_select_list( - "pushers", - keyvalues, - [ - "id", - "user_name", - "access_token", - "profile_tag", - "kind", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "ts", - "lang", - "data", - "last_stream_ordering", - "last_success", - "failing_since", - ], + """Retrieve pushers that match the given criteria. + + Args: + keyvalues: A {column: value} dictionary. + + Returns: + The pushers for which the given columns have the given values. + """ + + def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]: + # We could technically use simple_select_list here, but we need to call + # COALESCE on the 'enabled' column. While it is technically possible to give + # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it + # feels a bit hacky, so it's probably better to just inline the query. + sql = """ + SELECT + id, user_name, access_token, profile_tag, kind, app_id, + app_display_name, device_display_name, pushkey, ts, lang, data, + last_stream_ordering, last_success, failing_since, + COALESCE(enabled, TRUE) AS enabled + FROM pushers + """ + + sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),) + + txn.execute(sql, list(keyvalues.values())) + + return self.db_pool.cursor_to_dict(txn) + + ret = await self.db_pool.runInteraction( desc="get_pushers_by", + func=get_pushers_by_txn, ) + return self._decode_pushers_rows(ret) - async def get_all_pushers(self) -> Iterator[PusherConfig]: - def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: - txn.execute("SELECT * FROM pushers") + async def get_enabled_pushers(self) -> Iterator[PusherConfig]: + def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]: + txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)") rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - return await self.db_pool.runInteraction("get_all_pushers", get_pushers) + return await self.db_pool.runInteraction( + "get_enabled_pushers", get_enabled_pushers_txn + ) async def get_all_updated_pushers_rows( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore): data: Optional[JsonDict], last_stream_ordering: int, profile_tag: str = "", + enabled: bool = True, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore): "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, + "enabled": enabled, }, desc="add_pusher", lock=False, diff --git a/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql new file mode 100644 index 0000000000..dba3b4900b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql @@ -0,0 +1,16 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE pushers ADD COLUMN enabled BOOLEAN; \ No newline at end of file diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7a3b0d6755..fd14568f55 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.pusher = self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", @@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase): """ with self.assertRaises(SynapseError) as cm: self.get_success_or_raise( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", diff --git a/tests/push/test_http.py b/tests/push/test_http.py index d9c68cdd2d..af67d84463 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable -from synapse.push import PusherConfigException -from synapse.rest.client import login, push_rule, receipts, room +from synapse.push import PusherConfig, PusherConfigException +from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase): login.register_servlets, receipts.register_servlets, push_rule.register_servlets, + pusher.register_servlets, ] user_id = True hijack_auth = False @@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase): def test_data(data: Optional[JsonDict]) -> None: self.get_failure( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.json_body) - def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: + def _make_user_with_pusher( + self, username: str, enabled: bool = True + ) -> Tuple[str, str]: + """Registers a user and creates a pusher for them. + + Args: + username: the localpart of the new user's Matrix ID. + enabled: whether to create the pusher in an enabled or disabled state. + """ user_id = self.register_user(username, "pass") access_token = self.login(username, "pass") # Register the pusher + self._set_pusher(user_id, access_token, enabled) + + return user_id, access_token + + def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None: + """Creates or updates the pusher for the given user. + + Args: + user_id: the user's Matrix ID. + access_token: the access token associated with the pusher. + enabled: whether to enable or disable the pusher. + """ user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase): pushkey="a@example.com", lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, + enabled=enabled, ) ) - return user_id, access_token - def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification @@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase): # The user sends a message back (sends a notification) self.helper.send(room, body="Hello", tok=access_token) self.assertEqual(len(self.push_attempts), 1) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_disable(self) -> None: + """Tests that disabling a pusher means it's not pushed to anymore.""" + user_id, access_token = self._make_user_with_pusher("user") + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it generated a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Disable the pusher. + self._set_pusher(user_id, access_token, enabled=False) + + # Send another message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as disabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertFalse(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_enable(self) -> None: + """Tests that enabling a disabled pusher means it gets pushed to.""" + # Create the user with the pusher already disabled. + user_id, access_token = self._make_user_with_pusher("user", enabled=False) + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 0) + + # Enable the pusher. + self._set_pusher(user_id, access_token, enabled=True) + + # Send another message and check that it did generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as enabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertTrue(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_null_enabled(self) -> None: + """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers + created before the column was introduced) is considered enabled. + """ + # We intentionally set 'enabled' to None so that it's stored as NULL in the + # database. + user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type] + + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) + + def test_update_different_device_access_token(self) -> None: + """Tests that if we create a pusher from one device, the update it from another + device, the access token associated with the pusher stays the same. + """ + # Create a user with a pusher. + user_id, access_token = self._make_user_with_pusher("user") + + # Get the token ID for the current access token, since that's what we store in + # the pushers table. + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + # Generate a new access token, and update the pusher with it. + new_token = self.login("user", "pass") + self._set_pusher(user_id, new_token, enabled=False) + + # Get the current list of pushers for the user. + ret = self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) + pushers: List[PusherConfig] = list(ret) + + # Check that we still have one pusher, and that the access token associated with + # it didn't change. + self.assertEqual(len(pushers), 1) + self.assertEqual(pushers[0].access_token, token_id) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 8f4f6688ce..59fea93e49 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): token_id = user_dict.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9f536ceeb3..1847e6ad6b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.other_user, access_token=token_id, kind="http", -- cgit 1.5.1 From 0fd2f2d46064efd37284a36d5b478815d69ddd96 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Wed, 21 Sep 2022 16:12:29 +0100 Subject: Implementation of MSC3882 login token request (#13722) --- changelog.d/13722.feature | 1 + synapse/config/experimental.py | 7 ++ synapse/rest/__init__.py | 2 + synapse/rest/client/login_token_request.py | 94 ++++++++++++++++++ synapse/rest/client/versions.py | 2 + tests/rest/client/test_login_token_request.py | 132 ++++++++++++++++++++++++++ 6 files changed, 238 insertions(+) create mode 100644 changelog.d/13722.feature create mode 100644 synapse/rest/client/login_token_request.py create mode 100644 tests/rest/client/test_login_token_request.py (limited to 'synapse') diff --git a/changelog.d/13722.feature b/changelog.d/13722.feature new file mode 100644 index 0000000000..588d143c0f --- /dev/null +++ b/changelog.d/13722.feature @@ -0,0 +1 @@ +Experimental implementation of MSC3882 to allow an existing device/session to generate a login token for use on a new device/session. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f4541a8db0..bf27f6c101 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -96,3 +96,10 @@ class ExperimentalConfig(Config): # MSC3881: Remotely toggle push notifications for another client self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) + + # MSC3882: Allow an existing session to sign in a new session + self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False) + self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True) + self.msc3882_token_timeout = self.parse_duration( + experimental.get("msc3882_token_timeout", "5m") + ) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index b712215112..9a2ab99ede 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.client import ( keys, knock, login as v1_login, + login_token_request, logout, mutual_rooms, notifications, @@ -130,3 +131,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) + login_token_request.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py new file mode 100644 index 0000000000..ca5c54bf17 --- /dev/null +++ b/synapse/rest/client/login_token_request.py @@ -0,0 +1,94 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns, interactive_auth_handler +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class LoginTokenRequestServlet(RestServlet): + """ + Get a token that can be used with `m.login.token` to log in a second device. + + Request: + + POST /login/token HTTP/1.1 + Content-Type: application/json + + {} + + Response: + + HTTP/1.1 200 OK + { + "login_token": "ABDEFGH", + "expires_in": 3600, + } + """ + + PATTERNS = client_patterns("/login/token$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + self.server_name = hs.config.server.server_name + self.macaroon_gen = hs.get_macaroon_generator() + self.auth_handler = hs.get_auth_handler() + self.token_timeout = hs.config.experimental.msc3882_token_timeout + self.ui_auth = hs.config.experimental.msc3882_ui_auth + + @interactive_auth_handler + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + body = parse_json_object_from_request(request) + + if self.ui_auth: + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "issue a new access token for your account", + can_skip_ui_auth=False, # Don't allow skipping of UI auth + ) + + login_token = self.macaroon_gen.generate_short_term_login_token( + user_id=requester.user.to_string(), + auth_provider_id="org.matrix.msc3882.login_token_request", + duration_in_ms=self.token_timeout, + ) + + return ( + 200, + { + "login_token": login_token, + "expires_in": self.token_timeout // 1000, + }, + ) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3882_enabled: + LoginTokenRequestServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c516cda95d..c3488f4330 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -105,6 +105,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, + # Adds support for login token requests as per MSC3882 + "org.matrix.msc3882": self.config.experimental.msc3882_enabled, }, }, ) diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py new file mode 100644 index 0000000000..d5bb16c98d --- /dev/null +++ b/tests/rest/client/test_login_token_request.py @@ -0,0 +1,132 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, login_token_request +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + + +class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + admin.register_servlets, + login_token_request.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + self.hs.config.registration.enable_registration = True + self.hs.config.registration.registrations_require_3pid = [] + self.hs.config.registration.auto_join_rooms = [] + self.hs.config.captcha.enable_registration_captcha = False + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = "user123" + self.password = "password" + + def test_disabled(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 400) + + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_require_auth(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 401) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_uia_on(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 401) + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + session = channel.json_body["session"] + + uia = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.password, + "session": session, + }, + } + + channel = self.make_request("POST", "/login/token", uia, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}} + ) + def test_uia_off(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + { + "experimental_features": { + "msc3882_enabled": True, + "msc3882_ui_auth": False, + "msc3882_token_timeout": "15s", + } + } + ) + def test_expires_in(self) -> None: + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 15) -- cgit 1.5.1 From ccca14140a019c2e0430f95d78fa075efd8d535f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 16:31:53 +0100 Subject: Track device IDs for pushers (#13831) Second half of the MSC3881 implementation --- changelog.d/13831.feature | 1 + synapse/push/__init__.py | 2 + synapse/push/pusherpool.py | 10 ++- synapse/rest/client/pusher.py | 3 + synapse/storage/databases/main/pusher.py | 73 +++++++++++++++++++++- .../schema/main/delta/73/03pusher_device_id.sql | 20 ++++++ tests/push/test_http.py | 55 ++++++++++++++-- 7 files changed, 154 insertions(+), 10 deletions(-) create mode 100644 changelog.d/13831.feature create mode 100644 synapse/storage/schema/main/delta/73/03pusher_device_id.sql (limited to 'synapse') diff --git a/changelog.d/13831.feature b/changelog.d/13831.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13831.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index ac99d35a7e..a0c760239d 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -117,6 +117,7 @@ class PusherConfig: last_success: Optional[int] failing_since: Optional[int] enabled: bool + device_id: Optional[str] def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" @@ -130,6 +131,7 @@ class PusherConfig: "profile_tag": self.profile_tag, "pushkey": self.pushkey, "enabled": self.enabled, + "device_id": self.device_id, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 2597898cf4..e2648cbc93 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -107,6 +107,7 @@ class PusherPool: data: JsonDict, profile_tag: str = "", enabled: bool = True, + device_id: Optional[str] = None, ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -149,18 +150,20 @@ class PusherPool: last_success=None, failing_since=None, enabled=enabled, + device_id=device_id, ) ) # Before we actually persist the pusher, we check if the user already has one - # for this app ID and pushkey. If so, we want to keep the access token in place, - # since this could be one device modifying (e.g. enabling/disabling) another - # device's pusher. + # this app ID and pushkey. If so, we want to keep the access token and device ID + # in place, since this could be one device modifying (e.g. enabling/disabling) + # another device's pusher. existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( user_id, app_id, pushkey ) if existing_config: access_token = existing_config.access_token + device_id = existing_config.device_id await self.store.add_pusher( user_id=user_id, @@ -176,6 +179,7 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, enabled=enabled, + device_id=device_id, ) pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index c9f76125dc..975eef2144 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -57,7 +57,9 @@ class PushersRestServlet(RestServlet): for pusher in pusher_dicts: if self._msc3881_enabled: pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] + pusher["org.matrix.msc3881.device_id"] = pusher["device_id"] del pusher["enabled"] + del pusher["device_id"] return 200, {"pushers": pusher_dicts} @@ -134,6 +136,7 @@ class PushersSetRestServlet(RestServlet): data=content["data"], profile_tag=content.get("profile_tag", ""), enabled=enabled, + device_id=requester.device_id, ) except PusherConfigException as pce: raise SynapseError( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index ee55b8c4a9..01206950a9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -124,7 +124,7 @@ class PusherWorkerStore(SQLBaseStore): id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_stream_ordering, last_success, failing_since, - COALESCE(enabled, TRUE) AS enabled + COALESCE(enabled, TRUE) AS enabled, device_id FROM pushers """ @@ -477,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore): return number_deleted -class PusherStore(PusherWorkerStore): +class PusherBackgroundUpdatesStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "set_device_id_for_pushers", self._set_device_id_for_pushers + ) + + async def _set_device_id_for_pushers( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update to populate the device_id column of the pushers table.""" + last_pusher_id = progress.get("pusher_id", 0) + + def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int: + txn.execute( + """ + SELECT p.id, at.device_id + FROM pushers AS p + INNER JOIN access_tokens AS at + ON p.access_token = at.id + WHERE + p.access_token IS NOT NULL + AND at.device_id IS NOT NULL + AND p.id > ? + ORDER BY p.id + LIMIT ? + """, + (last_pusher_id, batch_size), + ) + + rows = self.db_pool.cursor_to_dict(txn) + if len(rows) == 0: + return 0 + + self.db_pool.simple_update_many_txn( + txn=txn, + table="pushers", + key_names=("id",), + key_values=[(row["id"],) for row in rows], + value_names=("device_id",), + value_values=[(row["device_id"],) for row in rows], + ) + + self.db_pool.updates._background_update_progress_txn( + txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]} + ) + + return len(rows) + + nb_processed = await self.db_pool.runInteraction( + "set_device_id_for_pushers", set_device_id_for_pushers_txn + ) + + if nb_processed < batch_size: + await self.db_pool.updates._end_background_update( + "set_device_id_for_pushers" + ) + + return nb_processed + + +class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() @@ -496,6 +563,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering: int, profile_tag: str = "", enabled: bool = True, + device_id: Optional[str] = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -515,6 +583,7 @@ class PusherStore(PusherWorkerStore): "profile_tag": profile_tag, "id": stream_id, "enabled": enabled, + "device_id": device_id, }, desc="add_pusher", lock=False, diff --git a/synapse/storage/schema/main/delta/73/03pusher_device_id.sql b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql new file mode 100644 index 0000000000..1b4ffbeebe --- /dev/null +++ b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a device_id column to track the device ID that created the pusher. It's NULLable +-- on purpose, because a) it might not be possible to track down the device that created +-- old pushers (pushers.access_token and access_tokens.device_id are both NULLable), and +-- b) access tokens retrieved via the admin API don't have a device associated to them. +ALTER TABLE pushers ADD COLUMN device_id TEXT; \ No newline at end of file diff --git a/tests/push/test_http.py b/tests/push/test_http.py index af67d84463..b383b8401f 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -22,6 +22,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfig, PusherConfigException from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import JsonDict from synapse.util import Clock @@ -771,6 +772,7 @@ class HTTPPusherTests(HomeserverTestCase): lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, enabled=enabled, + device_id=user_tuple.device_id, ) ) @@ -885,19 +887,21 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(channel.json_body["pushers"]), 1) self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) - def test_update_different_device_access_token(self) -> None: + def test_update_different_device_access_token_device_id(self) -> None: """Tests that if we create a pusher from one device, the update it from another - device, the access token associated with the pusher stays the same. + device, the access token and device ID associated with the pusher stays the + same. """ # Create a user with a pusher. user_id, access_token = self._make_user_with_pusher("user") # Get the token ID for the current access token, since that's what we store in - # the pushers table. + # the pushers table. Also get the device ID from it. user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id + device_id = user_tuple.device_id # Generate a new access token, and update the pusher with it. new_token = self.login("user", "pass") @@ -909,7 +913,48 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers: List[PusherConfig] = list(ret) - # Check that we still have one pusher, and that the access token associated with - # it didn't change. + # Check that we still have one pusher, and that the access token and device ID + # associated with it didn't change. self.assertEqual(len(pushers), 1) self.assertEqual(pushers[0].access_token, token_id) + self.assertEqual(pushers[0].device_id, device_id) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_device_id(self) -> None: + """Tests that a pusher created with a given device ID shows that device ID in + GET /pushers requests. + """ + self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # We create the pusher with an HTTP request rather than with + # _make_user_with_pusher so that we can test the device ID is correctly set when + # creating a pusher via an API call. + self.make_request( + method="POST", + path="/pushers/set", + content={ + "kind": "http", + "app_id": "m.http", + "app_display_name": "HTTP Push Notifications", + "device_display_name": "pushy push", + "pushkey": "a@example.com", + "lang": "en", + "data": {"url": "http://example.com/_matrix/push/v1/notify"}, + }, + access_token=access_token, + ) + + # Look up the user info for the access token so we can compare the device ID. + lookup_result: TokenLookupResult = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + + # Get the user's devices and check it has the correct device ID. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertEqual( + channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"], + lookup_result.device_id, + ) -- cgit 1.5.1 From efabf44c7652095a0e3d9d9083fc8359cdde3854 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 17:18:44 +0100 Subject: Add version flag for MSC3881 (#13860) --- changelog.d/13860.feature | 1 + synapse/rest/client/versions.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/13860.feature (limited to 'synapse') diff --git a/changelog.d/13860.feature b/changelog.d/13860.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13860.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c3488f4330..b3917a5abc 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -107,6 +107,8 @@ class VersionsRestServlet(RestServlet): "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 "org.matrix.msc3882": self.config.experimental.msc3882_enabled, + # Adds support for remotely enabling/disabling pushers, as per MSC3881 + "org.matrix.msc3881": self.config.experimental.msc3881_enabled, }, }, ) -- cgit 1.5.1 From 1a1abdda42551dad3aadc04a169c25f4cc651a2c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 21 Sep 2022 22:23:44 +0100 Subject: Last batch of Pydantic for synapse/rest/client/account.py (#13832) * Validation for `/add_threepid/msisdn/submit_token` * Don't validate deprecated endpoint * Changelog --- changelog.d/13832.feature | 1 + synapse/rest/client/account.py | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13832.feature (limited to 'synapse') diff --git a/changelog.d/13832.feature b/changelog.d/13832.feature new file mode 100644 index 0000000000..1dc1d66efe --- /dev/null +++ b/changelog.d/13832.feature @@ -0,0 +1 @@ +Improve validation for the unspecced, internal-only `_matrix/client/unstable/add_threepid/msisdn/submit_token` endpoint. diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 2db2a04f95..44f622bcce 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -534,6 +534,11 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): "/add_threepid/msisdn/submit_token$", releases=(), unstable=True ) + class PostBody(RequestBodyModel): + client_secret: ClientSecretStr + sid: StrictStr + token: StrictStr + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config @@ -549,16 +554,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): "instead.", ) - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["client_secret", "sid", "token"]) - assert_valid_client_secret(body["client_secret"]) + body = parse_and_validate_json_object_from_request(request, self.PostBody) # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( self.config.registration.account_threepid_delegate_msisdn, - body["client_secret"], - body["sid"], - body["token"], + body.client_secret, + body.sid, + body.token, ) return 200, response @@ -581,6 +584,10 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} + # NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because + # the endpoint is deprecated. (If you really want to, you could do this by reusing + # ThreePidBindRestServelet.PostBody with an `alias_generator` to handle + # `threePidCreds` versus `three_pid_creds`. async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( -- cgit 1.5.1 From b7272b73aa38dcb19c9b075514f963390358113d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 22 Sep 2022 08:47:49 -0400 Subject: Properly paginate forward in the /relations API. (#13840) This fixes a bug where the `/relations` API with `dir=f` would skip the first item of each page (except the first page), causing incomplete data to be returned to the client. --- changelog.d/13840.bugfix | 1 + synapse/storage/databases/main/relations.py | 38 +++++++++++++++++++++-------- synapse/storage/databases/main/stream.py | 6 ++--- tests/rest/client/test_relations.py | 29 +++++++++++++++++++++- 4 files changed, 60 insertions(+), 14 deletions(-) create mode 100644 changelog.d/13840.bugfix (limited to 'synapse') diff --git a/changelog.d/13840.bugfix b/changelog.d/13840.bugfix new file mode 100644 index 0000000000..0f014439a8 --- /dev/null +++ b/changelog.d/13840.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.53.0 where the experimental implementation of [MSC3715](https://github.com/matrix-org/matrix-spec-proposals/pull/3715) would give incorrect results when paginating forward. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7bd27790eb..898947af95 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -51,6 +51,8 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str + topological_ordering: Optional[int] + stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -91,6 +93,9 @@ class RelationsWorkerStore(SQLBaseStore): # it. The `event_id` must match the `event.event_id`. assert event.event_id == event_id + # Ensure bad limits aren't being passed in. + assert limit >= 0 + where_clause = ["relates_to_id = ?", "room_id = ?"] where_args: List[Union[str, int]] = [event.event_id, room_id] is_redacted = event.internal_metadata.is_redacted() @@ -139,21 +144,34 @@ class RelationsWorkerStore(SQLBaseStore): ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) - last_topo_id = None - last_stream_id = None events = [] - for row in txn: + for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: # Do not include edits for redacted events as they leak event # content. - if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append(_RelatedEvent(row[0], row[2])) - last_topo_id = row[3] - last_stream_id = row[4] + if not is_redacted or relation_type != RelationTypes.REPLACE: + events.append( + _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) + ) - # If there are more events, generate the next pagination key. + # If there are more events, generate the next pagination key from the + # last event returned. next_token = None - if len(events) > limit and last_topo_id and last_stream_id: - next_key = RoomStreamToken(last_topo_id, last_stream_id) + if len(events) > limit: + # Instead of using the last row (which tells us there is more + # data), use the last row to be returned. + events = events[:limit] + + topo = events[-1].topological_ordering + token = events[-1].stream_ordering + if direction == "b": + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + token -= 1 + next_key = RoomStreamToken(topo, token) + if from_token: next_token = from_token.copy_and_replace( StreamKeyType.ROOM, next_key diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 3f9bfaeac5..530f04e149 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1334,15 +1334,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if rows: topo = rows[-1].topological_ordering - toke = rows[-1].stream_ordering + token = rows[-1].stream_ordering if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk # when we are going backwards so we subtract one from the # stream part. - toke -= 1 - next_token = RoomStreamToken(topo, toke) + token -= 1 + next_token = RoomStreamToken(topo, token) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 651f4f415d..d33e34d829 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) + @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) + # Test forward pagination. + prev_token = "" + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEqual(found_event_ids, expected_event_ids) + def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") -- cgit 1.5.1 From c06b2b714262825e1d2510b62c38fdeda339f6dc Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 23 Sep 2022 10:47:16 +0000 Subject: Faster Remote Room Joins: tell remote homeservers that we are unable to authorise them if they query a room which has partial state on our server. (#13823) --- changelog.d/13823.misc | 1 + synapse/api/errors.py | 6 ++++++ synapse/config/experimental.py | 3 ++- synapse/federation/federation_server.py | 11 +++-------- synapse/handlers/event_auth.py | 31 ++++++++++++++++++++++++++---- synapse/handlers/federation.py | 34 +++++++++++++-------------------- synapse/handlers/federation_event.py | 2 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/room_summary.py | 6 ++---- synapse/handlers/typing.py | 2 +- tests/handlers/test_typing.py | 2 +- 11 files changed, 58 insertions(+), 42 deletions(-) create mode 100644 changelog.d/13823.misc (limited to 'synapse') diff --git a/changelog.d/13823.misc b/changelog.d/13823.misc new file mode 100644 index 0000000000..527d79f4b2 --- /dev/null +++ b/changelog.d/13823.misc @@ -0,0 +1 @@ +Faster Remote Room Joins: tell remote homeservers that we are unable to authorise them if they query a room which has partial state on our server. \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e6dea89c6d..1c6b53aa24 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -100,6 +100,12 @@ class Codes(str, Enum): UNREDACTED_CONTENT_DELETED = "FI.MAU.MSC2815_UNREDACTED_CONTENT_DELETED" + # Returned for federation requests where we can't process a request as we + # can't ensure the sending server is in a room which is partial-stated on + # our side. + # Part of MSC3895. + UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE" + class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index bf27f6c101..595eb007a5 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -63,7 +63,8 @@ class ExperimentalConfig(Config): # MSC3706 (server-side support for partial state in /send_join responses) self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) - # experimental support for faster joins over federation (msc2775, msc3706) + # experimental support for faster joins over federation + # (MSC2775, MSC3706, MSC3895) # requires a target server with msc3706_enabled enabled. self.faster_joins_enabled: bool = experimental.get("faster_joins", False) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3bf84cf625..907940e19e 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -530,13 +530,10 @@ class FederationServer(FederationBase): async def on_room_state_request( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, JsonDict]: + await self._event_auth_handler.assert_host_in_room(room_id, origin) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - # we grab the linearizer to protect ourselves from servers which hammer # us. In theory we might already have the response to this query # in the cache so we could return it without waiting for the linearizer @@ -560,13 +557,10 @@ class FederationServer(FederationBase): if not event_id: raise NotImplementedError("Specify an event") + await self._event_auth_handler.assert_host_in_room(room_id, origin) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - resp = await self._state_ids_resp_cache.wrap( (room_id, event_id), self._on_state_ids_request_compute, @@ -955,6 +949,7 @@ class FederationServer(FederationBase): self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: async with self._server_linearizer.queue((origin, room_id)): + await self._event_auth_handler.assert_host_in_room(room_id, origin) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index c3ddc5d182..8249ca1ed2 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -31,7 +31,6 @@ from synapse.events import EventBase from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext from synapse.types import StateMap, get_domain_from_id -from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -156,9 +155,33 @@ class EventAuthHandler: Codes.UNABLE_TO_GRANT_JOIN, ) - async def check_host_in_room(self, room_id: str, host: str) -> bool: - with Measure(self._clock, "check_host_in_room"): - return await self._store.is_host_joined(room_id, host) + async def is_host_in_room(self, room_id: str, host: str) -> bool: + return await self._store.is_host_joined(room_id, host) + + async def assert_host_in_room( + self, room_id: str, host: str, allow_partial_state_rooms: bool = False + ) -> None: + """ + Asserts that the host is in the room, or raises an AuthError. + + If the room is partial-stated, we raise an AuthError with the + UNABLE_DUE_TO_PARTIAL_STATE error code, unless `allow_partial_state_rooms` is true. + + If allow_partial_state_rooms is True and the room is partial-stated, + this function may return an incorrect result as we are not able to fully + track server membership in a room without full state. + """ + if not allow_partial_state_rooms and await self._store.is_partial_state_room( + room_id + ): + raise AuthError( + 403, + "Unable to authorise you right now; room is partial-stated here.", + errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE, + ) + + if not await self.is_host_in_room(room_id, host): + raise AuthError(403, "Host not in room.") async def check_restricted_join_rules( self, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index dd4b9f66d1..583d5ecd77 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -804,7 +804,7 @@ class FederationHandler: ) # now check that we are *still* in the room - is_in_room = await self._event_auth_handler.check_host_in_room( + is_in_room = await self._event_auth_handler.is_host_in_room( room_id, self.server_name ) if not is_in_room: @@ -1150,9 +1150,7 @@ class FederationHandler: async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: - in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") + await self._event_auth_handler.assert_host_in_room(room_id, origin) # Synapse asks for 100 events per backfill request. Do not allow more. limit = min(limit, 100) @@ -1198,21 +1196,17 @@ class FederationHandler: event_id, allow_none=True, allow_rejected=True ) - if event: - in_room = await self._event_auth_handler.check_host_in_room( - event.room_id, origin - ) - if not in_room: - raise AuthError(403, "Host not in room.") - - events = await filter_events_for_server( - self._storage_controllers, origin, [event] - ) - event = events[0] - return event - else: + if not event: return None + await self._event_auth_handler.assert_host_in_room(event.room_id, origin) + + events = await filter_events_for_server( + self._storage_controllers, origin, [event] + ) + event = events[0] + return event + async def on_get_missing_events( self, origin: str, @@ -1221,9 +1215,7 @@ class FederationHandler: latest_events: List[str], limit: int, ) -> List[EventBase]: - in_room = await self._event_auth_handler.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") + await self._event_auth_handler.assert_host_in_room(room_id, origin) # Only allow up to 20 events to be retrieved per request. limit = min(limit, 20) @@ -1257,7 +1249,7 @@ class FederationHandler: "state_key": target_user_id, } - if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname): + if await self._event_auth_handler.is_host_in_room(room_id, self.hs.hostname): room_version_obj = await self.store.get_room_version(room_id) builder = self.event_builder_factory.for_room_version( room_version_obj, event_dict diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index efcdb84057..2d7cde7506 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -238,7 +238,7 @@ class FederationEventHandler: # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = await self._event_auth_handler.check_host_in_room( + is_in_room = await self._event_auth_handler.is_host_in_room( room_id, self._server_name ) if not is_in_room: diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index d2bdb9c8be..afaf3261df 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -70,7 +70,7 @@ class ReceiptsHandler: # If we're not in the room just ditch the event entirely. This is # probably an old server that has come back and thinks we're still in # the room (or we've been rejoined to the room by a state reset). - is_in_room = await self.event_auth_handler.check_host_in_room( + is_in_room = await self.event_auth_handler.is_host_in_room( room_id, self.server_name ) if not is_in_room: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index ebd445adca..8d08625237 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -609,7 +609,7 @@ class RoomSummaryHandler: # If this is a request over federation, check if the host is in the room or # has a user who could join the room. elif origin: - if await self._event_auth_handler.check_host_in_room( + if await self._event_auth_handler.is_host_in_room( room_id, origin ) or await self._store.is_host_invited(room_id, origin): return True @@ -624,9 +624,7 @@ class RoomSummaryHandler: await self._event_auth_handler.get_rooms_that_allow_join(state_ids) ) for space_id in allowed_rooms: - if await self._event_auth_handler.check_host_in_room( - space_id, origin - ): + if await self._event_auth_handler.is_host_in_room(space_id, origin): return True logger.info( diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a4cd8b8f0c..0d8466af11 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -340,7 +340,7 @@ class TypingWriterHandler(FollowerTypingHandler): # If we're not in the room just ditch the event entirely. This is # probably an old server that has come back and thinks we're still in # the room (or we've been rejoined to the room by a state reset). - is_in_room = await self.event_auth_handler.check_host_in_room( + is_in_room = await self.event_auth_handler.is_host_in_room( room_id, self.server_name ) if not is_in_room: diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 8adba29d7f..1a247f12e8 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -129,7 +129,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID - hs.get_event_auth_handler().check_host_in_room = check_host_in_room + hs.get_event_auth_handler().is_host_in_room = check_host_in_room async def get_current_hosts_in_room(room_id: str): return {member.domain for member in self.room_members} -- cgit 1.5.1 From 03c2bfb7f89d637930da52723161ce74d4f89233 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 23 Sep 2022 13:44:03 +0100 Subject: Send device list updates out to servers in partially joined rooms (#13874) Use the provided list of servers in the room from the `/send_join` response, since we will not know which users are in the room. This isn't sufficient to ensure that all remote servers receive the right device list updates, since the `/send_join` response may be inaccurate or we may calculate the membership state of new users in the room incorrectly. Signed-off-by: Sean Quah --- changelog.d/13874.misc | 1 + synapse/handlers/device.py | 6 ++++- synapse/storage/controllers/state.py | 44 +++++++++++++++++++++++++++++++++- synapse/storage/databases/main/room.py | 17 +++++++++++++ 4 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13874.misc (limited to 'synapse') diff --git a/changelog.d/13874.misc b/changelog.d/13874.misc new file mode 100644 index 0000000000..499e488c35 --- /dev/null +++ b/changelog.d/13874.misc @@ -0,0 +1 @@ +Faster room joins: Send device list updates to most servers in rooms with partial state. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 901e2310b7..6566b3bf3d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -688,11 +688,15 @@ class DeviceHandler(DeviceWorkerHandler): # Ignore any users that aren't ours if self.hs.is_mine_id(user_id): hosts = set( - await self._storage_controllers.state.get_current_hosts_in_room( + await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) ) hosts.discard(self.server_name) + # For rooms with partial state, `hosts` is merely an + # approximation. When we transition to a full state room, we + # will have to send out device list updates to any servers we + # missed. # Check if we've already sent this update to some hosts if current_stream_id == stream_id: diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index bbe568bf05..b1aa17047c 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -23,6 +23,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Tuple, ) @@ -524,12 +525,53 @@ class StateStorageController: return state_map.get(key) async def get_current_hosts_in_room(self, room_id: str) -> List[str]: - """Get current hosts in room based on current state.""" + """Get current hosts in room based on current state. + + Blocks until we have full state for the given room. This only happens for rooms + with partial state. + + Returns: + A list of hosts in the room, sorted by longest in the room first. (aka. + sorted by join with the lowest depth first). + """ await self._partial_state_room_tracker.await_full_state(room_id) return await self.stores.main.get_current_hosts_in_room(room_id) + async def get_current_hosts_in_room_or_partial_state_approximation( + self, room_id: str + ) -> Sequence[str]: + """Get approximation of current hosts in room based on current state. + + For rooms with full state, this is equivalent to `get_current_hosts_in_room`, + with the same order of results. + + For rooms with partial state, no blocking occurs. Instead, the list of hosts + in the room at the time of joining is combined with the list of hosts which + joined the room afterwards. The returned list may include hosts that are not + actually in the room and exclude hosts that are in the room, since we may + calculate state incorrectly during the partial state phase. The order of results + is arbitrary for rooms with partial state. + """ + # We have to read this list first to mitigate races with un-partial stating. + # This will be empty for rooms with full state. + hosts_at_join = await self.stores.main.get_partial_state_servers_at_join( + room_id + ) + + hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id) + hosts_from_state_set = set(hosts_from_state) + + # First take the list of hosts based on the current state. + # For rooms with partial state, this will be missing most hosts. + hosts = list(hosts_from_state) + # Then add in the list of hosts in the room at the time we joined. + # This will be an empty list for rooms with full state. + hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set) + + return hosts + async def get_users_in_room_with_profiles( self, room_id: str ) -> Dict[str, ProfileInfo]: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index bef66f1992..5dd116d766 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -25,6 +25,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Tuple, Union, cast, @@ -1133,6 +1134,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_rooms_for_retention_period_in_range_txn, ) + async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]: + """Gets the list of servers in a partial state room at the time we joined it. + + Returns: + The `servers_in_room` list from the `/send_join` response for partial state + rooms. May not be accurate or complete, as it comes from a remote + homeserver. + An empty list for full state rooms. + """ + return await self.db_pool.simple_select_onecol( + "partial_state_rooms_servers", + keyvalues={"room_id": room_id}, + retcol="server_name", + desc="get_partial_state_servers_at_join", + ) + async def get_partial_state_rooms_and_servers( self, ) -> Mapping[str, Collection[str]]: -- cgit 1.5.1 From efd108b45d1706526416bc9a6f89463b5ff4506a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 23 Sep 2022 10:33:28 -0400 Subject: Accept & store thread IDs for receipts (implement MSC3771). (#13782) Updates the `/receipts` endpoint and receipt EDU handler to parse a `thread_id` from the body and insert it in the database. --- changelog.d/13782.feature | 1 + synapse/config/experimental.py | 2 + synapse/handlers/receipts.py | 23 ++++++- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/streams/_base.py | 1 + synapse/rest/client/read_marker.py | 2 + synapse/rest/client/receipts.py | 14 ++++- synapse/rest/client/versions.py | 2 + synapse/storage/database.py | 2 + synapse/storage/databases/main/receipts.py | 87 +++++++++++++++++++------- synapse/types.py | 1 + tests/federation/test_federation_sender.py | 21 ++++++- tests/handlers/test_appservice.py | 1 + tests/replication/slave/storage/test_events.py | 2 +- tests/replication/tcp/streams/test_receipts.py | 15 ++++- tests/storage/test_event_push_actions.py | 1 + tests/storage/test_receipts.py | 36 ++++++++--- 17 files changed, 173 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13782.feature (limited to 'synapse') diff --git a/changelog.d/13782.feature b/changelog.d/13782.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13782.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 595eb007a5..933779c23a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -83,6 +83,8 @@ class ExperimentalConfig(Config): # MSC3786 (Add a default push rule to ignore m.room.server_acl events) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) + # MSC3771: Thread read receipts + self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index afaf3261df..4768a34c07 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -63,6 +63,8 @@ class ReceiptsHandler: self.clock = self.hs.get_clock() self.state = hs.get_state_handler() + self._msc3771_enabled = hs.config.experimental.msc3771_enabled + async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] @@ -91,13 +93,23 @@ class ReceiptsHandler: ) continue + # Check if these receipts apply to a thread. + thread_id = None + data = user_values.get("data", {}) + if self._msc3771_enabled and isinstance(data, dict): + thread_id = data.get("thread_id") + # If the thread ID is invalid, consider it missing. + if not isinstance(thread_id, str): + thread_id = None + receipts.append( ReadReceipt( room_id=room_id, receipt_type=receipt_type, user_id=user_id, event_ids=user_values["event_ids"], - data=user_values.get("data", {}), + thread_id=thread_id, + data=data, ) ) @@ -114,6 +126,7 @@ class ReceiptsHandler: receipt.receipt_type, receipt.user_id, receipt.event_ids, + receipt.thread_id, receipt.data, ) @@ -146,7 +159,12 @@ class ReceiptsHandler: return True async def received_client_receipt( - self, room_id: str, receipt_type: str, user_id: str, event_id: str + self, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + thread_id: Optional[str], ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. @@ -156,6 +174,7 @@ class ReceiptsHandler: receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], + thread_id=thread_id, data={"ts": int(self.clock.time_msec())}, ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index cf9cd6833b..b2522f98ca 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -427,7 +427,8 @@ class FederationSenderHandler: receipt.receipt_type, receipt.user_id, [receipt.event_id], - receipt.data, + thread_id=receipt.thread_id, + data=receipt.data, ) await self.federation_sender.send_read_receipt(receipt_info) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 398bebeaa6..e01155ad59 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -361,6 +361,7 @@ class ReceiptsStream(Stream): receipt_type: str user_id: str event_id: str + thread_id: Optional[str] data: dict NAME = "receipts" diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 5e53096539..852838515c 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -83,6 +83,8 @@ class ReadMarkerRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + # Setting the thread ID is not possible with the /read_markers endpoint. + thread_id=None, ) return 200, {} diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 5b7fad7402..f3ff156abe 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,6 +49,7 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } + self._msc3771_enabled = hs.config.experimental.msc3771_enabled async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str @@ -61,7 +62,17 @@ class ReceiptRestServlet(RestServlet): f"Receipt type must be {', '.join(self._known_receipt_types)}", ) - parse_json_object_from_request(request, allow_empty_body=False) + body = parse_json_object_from_request(request) + + # Pull the thread ID, if one exists. + thread_id = None + if self._msc3771_enabled: + if "thread_id" in body: + thread_id = body.get("thread_id") + if not thread_id or not isinstance(thread_id, str): + raise SynapseError( + 400, "thread_id field must be a non-empty string" + ) await self.presence_handler.bump_presence_active_time(requester.user) @@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + thread_id=thread_id, ) return 200, {} diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index b3917a5abc..c95b0d6f19 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,6 +103,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above + # Support for thread read receipts. + "org.matrix.msc3771": self.config.experimental.msc3771_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 921cd4dc5e..9d116f6925 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -95,6 +95,8 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", "event_push_summary": "event_push_summary_unique_index", + "receipts_linearized": "receipts_linearized_unique_index", + "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ddb8e80b69..52fe0db924 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -540,7 +540,9 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_all_updated_receipts( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool + ]: """Get updates for receipts replication stream. Args: @@ -567,9 +569,13 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_all_updated_receipts_txn( txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + int, + bool, + ]: sql = """ - SELECT stream_id, room_id, receipt_type, user_id, event_id, data + SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC @@ -578,8 +584,8 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) updates = cast( - List[Tuple[int, list]], - [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn], ) limited = False @@ -631,6 +637,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_id: str, + thread_id: Optional[str], data: JsonDict, stream_id: int, ) -> Optional[int]: @@ -657,12 +664,27 @@ class ReceiptsWorkerStore(SQLBaseStore): # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts if stream_ordering is not None: - sql = ( - "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized AS r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + if thread_id is None: + thread_clause = "r.thread_id IS NULL" + thread_args: Tuple[str, ...] = () + else: + thread_clause = "r.thread_id = ?" + thread_args = (thread_id,) + + sql = f""" + SELECT stream_ordering, event_id FROM events + INNER JOIN receipts_linearized AS r USING (event_id, room_id) + WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause} + """ + txn.execute( + sql, + ( + room_id, + receipt_type, + user_id, + ) + + thread_args, ) - txn.execute(sql, (room_id, receipt_type, user_id)) for so, eid in txn: if int(so) >= stream_ordering: @@ -682,21 +704,28 @@ class ReceiptsWorkerStore(SQLBaseStore): self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "stream_id": stream_id, "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, @@ -748,6 +777,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: dict, ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. @@ -780,6 +810,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, linearized_event_id, + thread_id, data, stream_id=stream_id, # Read committed is actually beneficial here because we check for a receipt with @@ -794,7 +825,8 @@ class ReceiptsWorkerStore(SQLBaseStore): now = self._clock.time_msec() logger.debug( - "RR for event %s in %s (%i ms old)", + "Receipt %s for event %s in %s (%i ms old)", + receipt_type, linearized_event_id, room_id, now - event_ts, @@ -807,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, event_ids, + thread_id, data, ) @@ -821,6 +854,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: JsonDict, ) -> None: assert self._can_write_to_receipts @@ -832,19 +866,26 @@ class ReceiptsWorkerStore(SQLBaseStore): # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_graph", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_graph has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, diff --git a/synapse/types.py b/synapse/types.py index ec44601f54..773f0438d5 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -835,6 +835,7 @@ class ReadReceipt: receipt_type: str user_id: str event_ids: List[str] + thread_id: Optional[str] data: JsonDict diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index a5aa500ef8..f1e357764f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): # send the second RR receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["other_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b17af2725b..af24c4984d 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipt_type="m.read", user_id=self.local_user, event_ids=[f"$eventid_{i}"], + thread_id=None, data={}, ) ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 49a21e2e85..efd92793c0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -171,7 +171,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if send_receipt: self.get_success( self.master_store.insert_receipt( - ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {} + ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {} ) ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index eb00117845..ede6d0c118 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} + "!room:blue", + "m.read", + USER_ID, + ["$event:blue"], + thread_id=None, + data={"a": 1}, ) ) self.replicate() @@ -48,6 +53,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) self.assertEqual("$event:blue", row.event_id) + self.assertIsNone(row.thread_id) self.assertEqual({"a": 1}, row.data) # Now let's disconnect and insert some data. @@ -57,7 +63,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + "!room2:blue", + "m.read", + USER_ID, + ["$event2:foo"], + thread_id=None, + data={"a": 2}, ) ) self.replicate() diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index fc43d7edd1..08c74b93e3 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -106,6 +106,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): "m.read", user_id=user_id, event_ids=[event_id], + thread_id=None, data={}, ) ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index c89bfff241..9459ee1705 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -131,13 +131,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -164,7 +169,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -180,7 +185,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( @@ -202,13 +212,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -241,7 +256,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -259,7 +274,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( -- cgit 1.5.1 From db868db594c1a8a0baa3686b60f1c49c0d4be371 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 23 Sep 2022 11:49:39 -0500 Subject: Fix access token leak to logs from proxyagent (#13855) This can happen specifically with an application service `/transactions/10722?access_token=leaked` request Fix https://github.com/matrix-org/synapse/issues/13010 --- Saw an example leak in https://github.com/matrix-org/synapse/issues/13423#issuecomment-1205348482 ``` 2022-08-04 14:47:57,925 - synapse.http.client - 401 - DEBUG - as-sender-signal-1 - Sending request PUT http://localhost:29328/transactions/10722?access_token= 2022-08-04 14:47:57,926 - synapse.http.proxyagent - 223 - DEBUG - as-sender-signal-1 - Requesting b'http://localhost:29328/transactions/10722?access_token=leaked' via ``` --- changelog.d/13855.bugfix | 1 + synapse/http/proxyagent.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13855.bugfix (limited to 'synapse') diff --git a/changelog.d/13855.bugfix b/changelog.d/13855.bugfix new file mode 100644 index 0000000000..5ea8539bd8 --- /dev/null +++ b/changelog.d/13855.bugfix @@ -0,0 +1 @@ +Fix access token leak to logs from proxy agent. diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index b2a50c9105..1f8227896f 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -36,6 +36,7 @@ from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS +from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials from synapse.types import ISynapseReactor @@ -220,7 +221,11 @@ class ProxyAgent(_AgentBase): self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs ) - logger.debug("Requesting %s via %s", uri, endpoint) + logger.debug( + "Requesting %s via %s", + redact_uri(uri.decode("ascii", errors="replace")), + endpoint, + ) if parsed_uri.scheme == b"https": tls_connection_creator = self._policy_for_https.creatorForNetloc( -- cgit 1.5.1 From f49f73c0da5502792c65d3de1ffd352ceb6af562 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 23 Sep 2022 17:55:15 +0100 Subject: Faster room joins: Avoid blocking `/keys/changes` (#13888) Part of the work for #12993. Once #12993 is fully resolved, we expect `/keys/changes` to behave sensibly when joined to a room with partial state. Signed-off-by: Sean Quah --- changelog.d/13888.misc | 1 + synapse/handlers/device.py | 7 +++++-- synapse/storage/controllers/state.py | 7 ++++++- 3 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 changelog.d/13888.misc (limited to 'synapse') diff --git a/changelog.d/13888.misc b/changelog.d/13888.misc new file mode 100644 index 0000000000..4ffd9bcede --- /dev/null +++ b/changelog.d/13888.misc @@ -0,0 +1 @@ +Faster room joins: Avoid waiting for full state when processing `/keys/changes` requests. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6566b3bf3d..bad262731c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -195,7 +195,9 @@ class DeviceWorkerHandler: possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = await self._state_storage.get_current_state_ids(room_id) + current_state_ids = await self._state_storage.get_current_state_ids( + room_id, await_full_state=False + ) # The user may have left the room # TODO: Check if they actually did or if we were just invited. @@ -234,7 +236,8 @@ class DeviceWorkerHandler: # mapping from event_id -> state_dict prev_state_ids = await self._state_storage.get_state_ids_for_events( - event_ids + event_ids, + await_full_state=False, ) # Check if we've joined the room? If so we just blindly add all the users to diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index b1aa17047c..bb60130afe 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -407,6 +407,7 @@ class StateStorageController: self, room_id: str, state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, on_invalidate: Optional[Callable[[], None]] = None, ) -> StateMap[str]: """Get the current state event ids for a room based on the @@ -419,13 +420,17 @@ class StateStorageController: room_id: The room to get the state IDs of. state_filter: The state filter used to fetch state from the database. + await_full_state: if true, will block if we do not yet have complete + state for the room. on_invalidate: Callback for when the `get_current_state_ids` cache for the room gets invalidated. Returns: The current state of the room. """ - if not state_filter or state_filter.must_await_full_state(self._is_mine_id): + if await_full_state and ( + not state_filter or state_filter.must_await_full_state(self._is_mine_id) + ): await self._partial_state_room_tracker.await_full_state(room_id) if state_filter and not state_filter.is_full(): -- cgit 1.5.1 From ac1a31740b6d0dfda4d57a25762aaddfde981caf Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 23 Sep 2022 14:01:29 -0500 Subject: Only try to backfill event if we haven't tried before recently (#13635) Only try to backfill event if we haven't tried before recently (exponential backoff). No need to keep trying the same backfill point that fails over and over. Fix https://github.com/matrix-org/synapse/issues/13622 Fix https://github.com/matrix-org/synapse/issues/8451 Follow-up to https://github.com/matrix-org/synapse/pull/13589 Part of https://github.com/matrix-org/synapse/issues/13356 --- changelog.d/13635.feature | 1 + synapse/handlers/federation.py | 4 +- synapse/storage/databases/main/event_federation.py | 188 ++++++-- tests/storage/test_event_federation.py | 481 ++++++++++++++++++++- 4 files changed, 626 insertions(+), 48 deletions(-) create mode 100644 changelog.d/13635.feature (limited to 'synapse') diff --git a/changelog.d/13635.feature b/changelog.d/13635.feature new file mode 100644 index 0000000000..d86bf7ed80 --- /dev/null +++ b/changelog.d/13635.feature @@ -0,0 +1 @@ +Exponentially backoff from backfilling the same event over and over. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 583d5ecd77..e1a4265a64 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -226,9 +226,7 @@ class FederationHandler: """ backwards_extremities = [ _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY) - for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room( - room_id - ) + for event_id, depth in await self.store.get_backfill_points_in_room(room_id) ] insertion_events_to_be_backfilled: List[_BackfillPoint] = [] diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ef477978ed..3251fca6fb 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime import itertools import logging from queue import Empty, PriorityQueue @@ -43,7 +44,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -72,6 +73,13 @@ pdus_pruned_from_federation_queue = Counter( logger = logging.getLogger(__name__) +BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS: int = int( + datetime.timedelta(days=7).total_seconds() +) +BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS: int = int( + datetime.timedelta(hours=1).total_seconds() +) + # All the info we need while iterating the DAG while backfilling @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -715,96 +723,189 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @trace @tag_args - async def get_oldest_event_ids_with_depth_in_room( - self, room_id: str + async def get_backfill_points_in_room( + self, + room_id: str, ) -> List[Tuple[str, int]]: - """Gets the oldest events(backwards extremities) in the room along with the - aproximate depth. - - We use this function so that we can compare and see if someones current - depth at their current scrollback is within pagination range of the - event extremeties. If the current depth is close to the depth of given - oldest event, we can trigger a backfill. + """ + Gets the oldest events(backwards extremities) in the room along with the + approximate depth. Sorted by depth, highest to lowest (descending). Args: room_id: Room where we want to find the oldest events Returns: - List of (event_id, depth) tuples + List of (event_id, depth) tuples. Sorted by depth, highest to lowest + (descending) """ - def get_oldest_event_ids_with_depth_in_room_txn( + def get_backfill_points_in_room_txn( txn: LoggingTransaction, room_id: str ) -> List[Tuple[str, int]]: - # Assemble a dictionary with event_id -> depth for the oldest events + # Assemble a tuple lookup of event_id -> depth for the oldest events # we know of in the room. Backwards extremeties are the oldest # events we know of in the room but we only know of them because - # some other event referenced them by prev_event and aren't peristed - # in our database yet (meaning we don't know their depth - # specifically). So we need to look for the aproximate depth from + # some other event referenced them by prev_event and aren't + # persisted in our database yet (meaning we don't know their depth + # specifically). So we need to look for the approximate depth from # the events connected to the current backwards extremeties. sql = """ - SELECT b.event_id, MAX(e.depth) FROM events as e + SELECT backward_extrem.event_id, event.depth FROM events AS event /** * Get the edge connections from the event_edges table * so we can see whether this event's prev_events points * to a backward extremity in the next join. */ - INNER JOIN event_edges as g - ON g.event_id = e.event_id + INNER JOIN event_edges AS edge + ON edge.event_id = event.event_id /** * We find the "oldest" events in the room by looking for * events connected to backwards extremeties (oldest events * in the room that we know of so far). */ - INNER JOIN event_backward_extremities as b - ON g.prev_event_id = b.event_id - WHERE b.room_id = ? AND g.is_state is ? - GROUP BY b.event_id + INNER JOIN event_backward_extremities AS backward_extrem + ON edge.prev_event_id = backward_extrem.event_id + /** + * We use this info to make sure we don't retry to use a backfill point + * if we've already attempted to backfill from it recently. + */ + LEFT JOIN event_failed_pull_attempts AS failed_backfill_attempt_info + ON + failed_backfill_attempt_info.room_id = backward_extrem.room_id + AND failed_backfill_attempt_info.event_id = backward_extrem.event_id + WHERE + backward_extrem.room_id = ? + /* We only care about non-state edges because we used to use + * `event_edges` for two different sorts of "edges" (the current + * event DAG, but also a link to the previous state, for state + * events). These legacy state event edges can be distinguished by + * `is_state` and are removed from the codebase and schema but + * because the schema change is in a background update, it's not + * necessarily safe to assume that it will have been completed. + */ + AND edge.is_state is ? /* False */ + /** + * Exponential back-off (up to the upper bound) so we don't retry the + * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + * + * We use `1 << n` as a power of 2 equivalent for compatibility + * with older SQLites. The left shift equivalent only works with + * powers of 2 because left shift is a binary operation (base-2). + * Otherwise, we would use `power(2, n)` or the power operator, `2^n`. + */ + AND ( + failed_backfill_attempt_info.event_id IS NULL + OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) + ) + /** + * Sort from highest to the lowest depth. Then tie-break on + * alphabetical order of the event_ids so we get a consistent + * ordering which is nice when asserting things in tests. + */ + ORDER BY event.depth DESC, backward_extrem.event_id DESC """ - txn.execute(sql, (room_id, False)) + if isinstance(self.database_engine, PostgresEngine): + least_function = "least" + elif isinstance(self.database_engine, Sqlite3Engine): + least_function = "min" + else: + raise RuntimeError("Unknown database engine") + + txn.execute( + sql % (least_function,), + ( + room_id, + False, + self._clock.time_msec(), + 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, + 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + ), + ) return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( - "get_oldest_event_ids_with_depth_in_room", - get_oldest_event_ids_with_depth_in_room_txn, + "get_backfill_points_in_room", + get_backfill_points_in_room_txn, room_id, ) @trace async def get_insertion_event_backward_extremities_in_room( - self, room_id: str + self, + room_id: str, ) -> List[Tuple[str, int]]: - """Get the insertion events we know about that we haven't backfilled yet. - - We use this function so that we can compare and see if someones current - depth at their current scrollback is within pagination range of the - insertion event. If the current depth is close to the depth of given - insertion event, we can trigger a backfill. + """ + Get the insertion events we know about that we haven't backfilled yet + along with the approximate depth. Sorted by depth, highest to lowest + (descending). Args: room_id: Room where we want to find the oldest events Returns: - List of (event_id, depth) tuples + List of (event_id, depth) tuples. Sorted by depth, highest to lowest + (descending) """ def get_insertion_event_backward_extremities_in_room_txn( txn: LoggingTransaction, room_id: str ) -> List[Tuple[str, int]]: sql = """ - SELECT b.event_id, MAX(e.depth) FROM insertion_events as i + SELECT + insertion_event_extremity.event_id, event.depth /* We only want insertion events that are also marked as backwards extremities */ - INNER JOIN insertion_event_extremities as b USING (event_id) + FROM insertion_event_extremities AS insertion_event_extremity /* Get the depth of the insertion event from the events table */ - INNER JOIN events AS e USING (event_id) - WHERE b.room_id = ? - GROUP BY b.event_id + INNER JOIN events AS event USING (event_id) + /** + * We use this info to make sure we don't retry to use a backfill point + * if we've already attempted to backfill from it recently. + */ + LEFT JOIN event_failed_pull_attempts AS failed_backfill_attempt_info + ON + failed_backfill_attempt_info.room_id = insertion_event_extremity.room_id + AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id + WHERE + insertion_event_extremity.room_id = ? + /** + * Exponential back-off (up to the upper bound) so we don't retry the + * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc + * + * We use `1 << n` as a power of 2 equivalent for compatibility + * with older SQLites. The left shift equivalent only works with + * powers of 2 because left shift is a binary operation (base-2). + * Otherwise, we would use `power(2, n)` or the power operator, `2^n`. + */ + AND ( + failed_backfill_attempt_info.event_id IS NULL + OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) + ) + /** + * Sort from highest to the lowest depth. Then tie-break on + * alphabetical order of the event_ids so we get a consistent + * ordering which is nice when asserting things in tests. + */ + ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC """ - txn.execute(sql, (room_id,)) + if isinstance(self.database_engine, PostgresEngine): + least_function = "least" + elif isinstance(self.database_engine, Sqlite3Engine): + least_function = "min" + else: + raise RuntimeError("Unknown database engine") + + txn.execute( + sql % (least_function,), + ( + room_id, + self._clock.time_msec(), + 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, + 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + ), + ) return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( @@ -1539,7 +1640,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas self, room_id: str, ) -> Optional[Tuple[str, str]]: - """Get the next event ID in the staging area for the given room.""" + """ + Get the next event ID in the staging area for the given room. + + Returns: + Tuple of the `origin` and `event_id` + """ def _get_next_staged_event_id_for_room_txn( txn: LoggingTransaction, diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a6679e1312..85739c464e 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -12,25 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +import datetime +from typing import Dict, List, Tuple, Union import attr from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersion, ) from synapse.events import _EventInternalMetadata -from synapse.util import json_encoder +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction +from synapse.types import JsonDict +from synapse.util import Clock, json_encoder import tests.unittest import tests.utils +@attr.s(auto_attribs=True, frozen=True, slots=True) +class _BackfillSetupInfo: + room_id: str + depth_map: Dict[str, int] + + class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main def test_get_prev_events_for_room(self): @@ -571,11 +584,471 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) self.assertEqual(count, 1) - _, event_id = self.get_success( + next_staged_event_info = self.get_success( self.store.get_next_staged_event_id_for_room(room_id) ) + assert next_staged_event_info + _, event_id = next_staged_event_info self.assertEqual(event_id, "$fake_event_id_500") + def _setup_room_for_backfill_tests(self) -> _BackfillSetupInfo: + """ + Sets up a room with various events and backward extremities to test + backfill functions against. + + Returns: + _BackfillSetupInfo including the `room_id` to test against and + `depth_map` of events in the room + """ + room_id = "!backfill-room-test:some-host" + + # The silly graph we use to test grabbing backward extremities, + # where the top is the oldest events. + # 1 (oldest) + # | + # 2 ⹁ + # | \ + # | [b1, b2, b3] + # | | + # | A + # | / + # 3 { + # | \ + # | [b4, b5, b6] + # | | + # | B + # | / + # 4 ´ + # | + # 5 (newest) + + event_graph: Dict[str, List[str]] = { + "1": [], + "2": ["1"], + "3": ["2", "A"], + "4": ["3", "B"], + "5": ["4"], + "A": ["b1", "b2", "b3"], + "b1": ["2"], + "b2": ["2"], + "b3": ["2"], + "B": ["b4", "b5", "b6"], + "b4": ["3"], + "b5": ["3"], + "b6": ["3"], + } + + depth_map: Dict[str, int] = { + "1": 1, + "2": 2, + "b1": 3, + "b2": 3, + "b3": 3, + "A": 4, + "3": 5, + "b4": 6, + "b5": 6, + "b6": 6, + "B": 7, + "4": 8, + "5": 9, + } + + # The events we have persisted on our server. + # The rest are events in the room but not backfilled tet. + our_server_events = {"5", "4", "B", "3", "A"} + + complete_event_dict_map: Dict[str, JsonDict] = {} + stream_ordering = 0 + for (event_id, prev_event_ids) in event_graph.items(): + depth = depth_map[event_id] + + complete_event_dict_map[event_id] = { + "event_id": event_id, + "type": "test_regular_type", + "room_id": room_id, + "sender": "@sender", + "prev_event_ids": prev_event_ids, + "auth_event_ids": [], + "origin_server_ts": stream_ordering, + "depth": depth, + "stream_ordering": stream_ordering, + "content": {"body": "event" + event_id}, + } + + stream_ordering += 1 + + def populate_db(txn: LoggingTransaction): + # Insert the room to satisfy the foreign key constraint of + # `event_failed_pull_attempts` + self.store.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + }, + ) + + # Insert our server events + for event_id in our_server_events: + event_dict = complete_event_dict_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_dict.get("event_id"), + "type": event_dict.get("type"), + "room_id": event_dict.get("room_id"), + "depth": event_dict.get("depth"), + "topological_ordering": event_dict.get("depth"), + "stream_ordering": event_dict.get("stream_ordering"), + "processed": True, + "outlier": False, + }, + ) + + # Insert the event edges + for event_id in our_server_events: + for prev_event_id in event_graph[event_id]: + self.store.db_pool.simple_insert_txn( + txn, + table="event_edges", + values={ + "event_id": event_id, + "prev_event_id": prev_event_id, + "room_id": room_id, + }, + ) + + # Insert the backward extremities + prev_events_of_our_events = { + prev_event_id + for our_server_event in our_server_events + for prev_event_id in complete_event_dict_map[our_server_event][ + "prev_event_ids" + ] + } + backward_extremities = prev_events_of_our_events - our_server_events + for backward_extremity in backward_extremities: + self.store.db_pool.simple_insert_txn( + txn, + table="event_backward_extremities", + values={ + "event_id": backward_extremity, + "room_id": room_id, + }, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "_setup_room_for_backfill_tests_populate_db", + populate_db, + ) + ) + + return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) + + def test_get_backfill_points_in_room(self): + """ + Test to make sure we get some backfill points + """ + setup_info = self._setup_room_for_backfill_tests() + room_id = setup_info.room_id + + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual( + backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] + ) + + def test_get_backfill_points_in_room_excludes_events_we_have_attempted( + self, + ): + """ + Test to make sure that events we have attempted to backfill (and within + backoff timeout duration) do not show up as an event to backfill again. + """ + setup_info = self._setup_room_for_backfill_tests() + room_id = setup_info.room_id + + # Record some attempts to backfill these events which will make + # `get_backfill_points_in_room` exclude them because we + # haven't passed the backoff interval. + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b5", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b4", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b2", "fake cause") + ) + + # No time has passed since we attempted to backfill ^ + + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + # Only the backfill points that we didn't record earlier exist here. + self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"]) + + def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration( + self, + ): + """ + Test to make sure after we fake attempt to backfill event "b3" many times, + we can see retry and see the "b3" again after the backoff timeout duration + has exceeded. + """ + setup_info = self._setup_room_for_backfill_tests() + room_id = setup_info.room_id + + # Record some attempts to backfill these events which will make + # `get_backfill_points_in_room` exclude them because we + # haven't passed the backoff interval. + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + + # Now advance time by 2 hours and we should only be able to see "b3" + # because we have waited long enough for the single attempt (2^1 hours) + # but we still shouldn't see "b1" because we haven't waited long enough + # for this many attempts. We didn't do anything to "b2" so it should be + # visible regardless. + self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) + + # Make sure that "b1" is not in the list because we've + # already attempted many times + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"]) + + # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and + # see if we can now backfill it + self.reactor.advance(datetime.timedelta(hours=20).total_seconds()) + + # Try again after we advanced enough time and we should see "b3" again + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual( + backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] + ) + + def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo: + """ + Sets up a room with various insertion event backward extremities to test + backfill functions against. + + Returns: + _BackfillSetupInfo including the `room_id` to test against and + `depth_map` of events in the room + """ + room_id = "!backfill-room-test:some-host" + + depth_map: Dict[str, int] = { + "1": 1, + "2": 2, + "insertion_eventA": 3, + "3": 4, + "insertion_eventB": 5, + "4": 6, + "5": 7, + } + + def populate_db(txn: LoggingTransaction): + # Insert the room to satisfy the foreign key constraint of + # `event_failed_pull_attempts` + self.store.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + }, + ) + + # Insert our server events + stream_ordering = 0 + for event_id, depth in depth_map.items(): + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "type": EventTypes.MSC2716_INSERTION + if event_id.startswith("insertion_event") + else "test_regular_type", + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "stream_ordering": stream_ordering, + "processed": True, + "outlier": False, + }, + ) + + if event_id.startswith("insertion_event"): + self.store.db_pool.simple_insert_txn( + txn, + table="insertion_event_extremities", + values={ + "event_id": event_id, + "room_id": room_id, + }, + ) + + stream_ordering += 1 + + self.get_success( + self.store.db_pool.runInteraction( + "_setup_room_for_insertion_backfill_tests_populate_db", + populate_db, + ) + ) + + return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) + + def test_get_insertion_event_backward_extremities_in_room(self): + """ + Test to make sure insertion event backward extremities are returned. + """ + setup_info = self._setup_room_for_insertion_backfill_tests() + room_id = setup_info.room_id + + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual( + backfill_event_ids, ["insertion_eventB", "insertion_eventA"] + ) + + def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( + self, + ): + """ + Test to make sure that insertion events we have attempted to backfill + (and within backoff timeout duration) do not show up as an event to + backfill again. + """ + setup_info = self._setup_room_for_insertion_backfill_tests() + room_id = setup_info.room_id + + # Record some attempts to backfill these events which will make + # `get_insertion_event_backward_extremities_in_room` exclude them + # because we haven't passed the backoff interval. + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + + # No time has passed since we attempted to backfill ^ + + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + # Only the backfill points that we didn't record earlier exist here. + self.assertListEqual(backfill_event_ids, ["insertion_eventB"]) + + def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( + self, + ): + """ + Test to make sure after we fake attempt to backfill event + "insertion_eventA" many times, we can see retry and see the + "insertion_eventA" again after the backoff timeout duration has + exceeded. + """ + setup_info = self._setup_room_for_insertion_backfill_tests() + room_id = setup_info.room_id + + # Record some attempts to backfill these events which will make + # `get_backfill_points_in_room` exclude them because we + # haven't passed the backoff interval. + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventB", "fake cause" + ) + ) + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + + # Now advance time by 2 hours and we should only be able to see + # "insertion_eventB" because we have waited long enough for the single + # attempt (2^1 hours) but we still shouldn't see "insertion_eventA" + # because we haven't waited long enough for this many attempts. + self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) + + # Make sure that "insertion_eventA" is not in the list because we've + # already attempted many times + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual(backfill_event_ids, ["insertion_eventB"]) + + # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and + # see if we can now backfill it + self.reactor.advance(datetime.timedelta(hours=20).total_seconds()) + + # Try at "insertion_eventA" again after we advanced enough time and we + # should see "insertion_eventA" again + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual( + backfill_event_ids, ["insertion_eventB", "insertion_eventA"] + ) + @attr.s class FakeEvent: -- cgit 1.5.1 From dcdd50e458e7f6c77e1ca28afb300d9f0ab490b3 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 26 Sep 2022 13:30:00 +0100 Subject: Fix mypy errors with latest canonicaljson (#13905) * Lockfile: update canonicaljson 1.6.0 -> 1.6.3 * Fix mypy errors with latest canonicaljson The change to `_encode_json_bytes` definition wasn't sufficient: ``` synapse/http/server.py:751: error: Incompatible types in assignment (expression has type "Callable[[Arg(object, 'json_object')], bytes]", variable has type "Callable[[Arg(object, 'data')], bytes]") [assignment] ``` Which I think is mypy warning us that the two functions accept different sets of kwargs. Fair enough! * Changelog --- changelog.d/13905.misc | 1 + poetry.lock | 9 +++++---- synapse/http/server.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13905.misc (limited to 'synapse') diff --git a/changelog.d/13905.misc b/changelog.d/13905.misc new file mode 100644 index 0000000000..efe3bed5f1 --- /dev/null +++ b/changelog.d/13905.misc @@ -0,0 +1 @@ +Fix mypy errors with canonicaljson 1.6.3. diff --git a/poetry.lock b/poetry.lock index 291f3c51e6..0f6d1cfa69 100644 --- a/poetry.lock +++ b/poetry.lock @@ -95,14 +95,15 @@ webencodings = "*" [[package]] name = "canonicaljson" -version = "1.6.0" +version = "1.6.3" description = "Canonical JSON" category = "main" optional = false -python-versions = "~=3.7" +python-versions = ">=3.7" [package.dependencies] simplejson = ">=3.14.0" +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.8\""} [package.extras] frozendict = ["frozendict (>=1.0)"] @@ -1682,8 +1683,8 @@ bleach = [ {file = "bleach-4.1.0.tar.gz", hash = "sha256:0900d8b37eba61a802ee40ac0061f8c2b5dee29c1927dd1d233e075ebf5a71da"}, ] canonicaljson = [ - {file = "canonicaljson-1.6.0-py3-none-any.whl", hash = "sha256:7230c2a2a3db07874f622af84effe41a655e07bf23734830e18a454e65d5b998"}, - {file = "canonicaljson-1.6.0.tar.gz", hash = "sha256:8739d5fd91aca7281d425660ae65af7663808c8177778965f67e90b16a2b2427"}, + {file = "canonicaljson-1.6.3-py3-none-any.whl", hash = "sha256:6ba3cf1702fa3d209b3e915a4e9a3e4ef194f1e8fca189c1f0b7a2a7686a27e6"}, + {file = "canonicaljson-1.6.3.tar.gz", hash = "sha256:ca59760bc274a899a0da75809d6909ae43e5123381fd6ef040a44d1952c0b448"}, ] certifi = [ {file = "certifi-2021.10.8-py2.py3-none-any.whl", hash = "sha256:d62a0163eb4c2344ac042ab2bdf75399a71a2d8c7d47eac2e2ee91b9d6339569"}, diff --git a/synapse/http/server.py b/synapse/http/server.py index 6068a94b40..bcbfac2c9f 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -705,7 +705,7 @@ class _ByteProducer: self._request = None -def _encode_json_bytes(json_object: Any) -> bytes: +def _encode_json_bytes(json_object: object) -> bytes: """ Encode an object into JSON. Returns an iterator of bytes. """ @@ -746,7 +746,7 @@ def respond_with_json( return None if canonical_json: - encoder = encode_canonical_json + encoder: Callable[[object], bytes] = encode_canonical_json else: encoder = _encode_json_bytes -- cgit 1.5.1 From 6b4593a80fa2fd9ec8e1ec82fad74f3b7fbb9ba3 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 26 Sep 2022 16:26:35 +0100 Subject: Simplify cache invalidation after event persist txn (#13796) This moves all the invalidations into a single place and de-duplicates the code involved in invalidating caches for a given event by using the base class method. --- changelog.d/13796.misc | 1 + synapse/storage/_base.py | 3 + synapse/storage/databases/main/cache.py | 34 +++++--- synapse/storage/databases/main/events.py | 133 +++++++------------------------ 4 files changed, 52 insertions(+), 119 deletions(-) create mode 100644 changelog.d/13796.misc (limited to 'synapse') diff --git a/changelog.d/13796.misc b/changelog.d/13796.misc new file mode 100644 index 0000000000..9ed1662394 --- /dev/null +++ b/changelog.d/13796.misc @@ -0,0 +1 @@ +Use shared methods for cache invalidation when persisting events, remove duplicate codepaths. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 303a5d5298..313e8aca7d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -91,6 +91,9 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache( "get_user_in_room_with_profile", (room_id, user_id) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (user_id,) + ) # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2c421151c1..db6ce83a2b 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -223,15 +223,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # process triggering the invalidation is responsible for clearing any external # cached objects. self._invalidate_local_get_event_cache(event_id) - self.have_seen_event.invalidate((room_id, event_id)) - self.get_latest_event_ids_in_room.invalidate((room_id,)) - - self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) + self._attempt_to_invalidate_cache("have_seen_event", (room_id, event_id)) + self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) + self._attempt_to_invalidate_cache( + "get_unread_event_push_actions_by_room_for_user", (room_id,) + ) # The `_get_membership_from_event_id` is immutable, except for the # case where we look up an event *before* persisting it. - self._get_membership_from_event_id.invalidate((event_id,)) + self._attempt_to_invalidate_cache("_get_membership_from_event_id", (event_id,)) if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) @@ -240,19 +241,26 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._invalidate_local_get_event_cache(redacts) # Caches which might leak edits must be invalidated for the event being # redacted. - self.get_relations_for_event.invalidate((redacts,)) - self.get_applicable_edit.invalidate((redacts,)) + self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) + self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_local_user.invalidate((state_key,)) + self._attempt_to_invalidate_cache( + "get_invited_rooms_for_local_user", (state_key,) + ) if relates_to: - self.get_relations_for_event.invalidate((relates_to,)) - self.get_aggregation_groups_for_event.invalidate((relates_to,)) - self.get_applicable_edit.invalidate((relates_to,)) - self.get_thread_summary.invalidate((relates_to,)) - self.get_thread_participated.invalidate((relates_to,)) + self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache( + "get_aggregation_groups_for_event", (relates_to,) + ) + self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) + self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) + self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) + self._attempt_to_invalidate_cache( + "get_mutual_event_relations_for_rel_type", (relates_to,) + ) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1b54a2eb57..2e156a4a11 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -410,6 +410,31 @@ class PersistEventsStore: assert min_stream_order assert max_stream_order + # Once the txn completes, invalidate all of the relevant caches. Note that we do this + # up here because it captures all the events_and_contexts before any are removed. + for event, _ in events_and_contexts: + self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) + if event.redacts: + self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) + + relates_to = None + relation = relation_from_event(event) + if relation: + relates_to = relation.parent_id + + assert event.internal_metadata.stream_ordering is not None + txn.call_after( + self.store._invalidate_caches_for_event, + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.type, + getattr(event, "state_key", None), + event.redacts, + relates_to, + backfilled=False, + ) + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremities, @@ -459,6 +484,7 @@ class PersistEventsStore: # We call this last as it assumes we've inserted the events into # room_memberships, where applicable. + # NB: This function invalidates all state related caches self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) def _persist_event_auth_chain_txn( @@ -1172,13 +1198,6 @@ class PersistEventsStore: ) # Invalidate the various caches - - for member in members_changed: - txn.call_after( - self.store.get_rooms_for_user_with_stream_ordering.invalidate, - (member,), - ) - self.store._invalidate_state_caches_and_stream( txn, room_id, members_changed ) @@ -1222,9 +1241,6 @@ class PersistEventsStore: self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) - txn.call_after( - self.store.get_latest_event_ids_in_room.invalidate, (room_id,) - ) self.db_pool.simple_insert_many_txn( txn, @@ -1294,8 +1310,6 @@ class PersistEventsStore: """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: - # Remove the any existing cache entries for the event_ids - self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) # Then update the `stream_ordering` position to mark the latest # event as the front of the room. This should not be done for # backfilled events because backfilled events have negative @@ -1697,16 +1711,7 @@ class PersistEventsStore: txn.async_call_after(prefill) def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: - """Invalidate the caches for the redacted event. - - Note that these caches are also cleared as part of event replication in - _invalidate_caches_for_event. - """ assert event.redacts is not None - self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) - txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) - txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) - self.db_pool.simple_upsert_txn( txn, table="redactions", @@ -1807,34 +1812,6 @@ class PersistEventsStore: for event in events: assert event.internal_metadata.stream_ordering is not None - txn.call_after( - self.store._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.store.get_invited_rooms_for_local_user.invalidate, - (event.state_key,), - ) - txn.call_after( - self.store.get_local_users_in_room.invalidate, - (event.room_id,), - ) - txn.call_after( - self.store.get_number_joined_users_in_room.invalidate, - (event.room_id,), - ) - txn.call_after( - self.store.get_user_in_room_with_profile.invalidate, - (event.room_id, event.state_key), - ) - - # The `_get_membership_from_event_id` is immutable, except for the - # case where we look up an event *before* persisting it. - txn.call_after( - self.store._get_membership_from_event_id.invalidate, - (event.event_id,), - ) # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. @@ -1883,35 +1860,6 @@ class PersistEventsStore: }, ) - txn.call_after( - self.store.get_relations_for_event.invalidate, (relation.parent_id,) - ) - txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate, - (relation.parent_id,), - ) - txn.call_after( - self.store.get_mutual_event_relations_for_rel_type.invalidate, - (relation.parent_id,), - ) - - if relation.rel_type == RelationTypes.REPLACE: - txn.call_after( - self.store.get_applicable_edit.invalidate, (relation.parent_id,) - ) - - if relation.rel_type == RelationTypes.THREAD: - txn.call_after( - self.store.get_thread_summary.invalidate, (relation.parent_id,) - ) - # It should be safe to only invalidate the cache if the user has not - # previously participated in the thread, but that's difficult (and - # potentially error-prone) so it is always invalidated. - txn.call_after( - self.store.get_thread_participated.invalidate, - (relation.parent_id, event.sender), - ) - def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -2213,28 +2161,6 @@ class PersistEventsStore: ), ) - room_to_event_ids: Dict[str, List[str]] = {} - for e in non_outlier_events: - room_to_event_ids.setdefault(e.room_id, []).append(e.event_id) - - for room_id, event_ids in room_to_event_ids.items(): - rows = self.db_pool.simple_select_many_txn( - txn, - table="event_push_actions_staging", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=("user_id",), - ) - - user_ids = {row["user_id"] for row in rows} - - for user_id in user_ids: - txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate, - (room_id, user_id), - ) - # Now we delete the staging area for *all* events that were being # persisted. txn.execute_batch( @@ -2249,11 +2175,6 @@ class PersistEventsStore: def _remove_push_actions_for_event_id_txn( self, txn: LoggingTransaction, room_id: str, event_id: str ) -> None: - # Sad that we have to blow away the cache for the whole room here - txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate, - (room_id,), - ) txn.execute( "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", (room_id, event_id), -- cgit 1.5.1 From 41461fd4d63e55d1812f0688ca58a88e7200a1d7 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Mon, 26 Sep 2022 17:33:32 +0200 Subject: typing: check origin server of typing event against room's servers (#13830) This is also using the partial state approximation if needed so we do not block here during a fast join. Signed-off-by: Mathieu Velten Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/13830.bugfix | 1 + synapse/handlers/typing.py | 7 +++++-- tests/handlers/test_typing.py | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13830.bugfix (limited to 'synapse') diff --git a/changelog.d/13830.bugfix b/changelog.d/13830.bugfix new file mode 100644 index 0000000000..e6215806cd --- /dev/null +++ b/changelog.d/13830.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where typing events would be accepted from remote servers not present in a room. Also fix a bug where incoming typing events would cause other incoming events to get stuck during a fast join. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0d8466af11..f953691669 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -362,11 +362,14 @@ class TypingWriterHandler(FollowerTypingHandler): ) return - domains = await self._storage_controllers.state.get_current_hosts_in_room( + # Let's check that the origin server is in the room before accepting the typing + # event. We don't want to block waiting on a partial state so take an + # approximation if needed. + domains = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) - if self.server_name in domains: + if user.domain in domains: logger.info("Got typing update from %s: %r", user_id, content) now = self.clock.time_msec() self._member_typing_until[member] = now + FEDERATION_TIMEOUT diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 1a247f12e8..9c821b3042 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -138,6 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): get_current_hosts_in_room ) + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( + get_current_hosts_in_room + ) + async def get_users_in_room(room_id: str): return {str(u) for u in self.room_members} -- cgit 1.5.1 From 0a38c7ec6d46b6e51bfa53ff44e51637d3c63f5c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 26 Sep 2022 18:28:32 +0100 Subject: Snapshot schema 72 (#13873) Including another batch of fixes to the schema dump script --- changelog.d/13873.misc | 1 + scripts-dev/make_full_schema.sh | 60 +- synapse/storage/database.py | 8 + synapse/storage/engines/_base.py | 23 +- synapse/storage/engines/postgres.py | 12 +- synapse/storage/engines/sqlite.py | 21 +- synapse/storage/prepare_database.py | 8 +- .../common/full_schemas/72/full.sql.postgres | 8 + .../schema/common/full_schemas/72/full.sql.sqlite | 6 + .../schema/main/full_schemas/72/full.sql.postgres | 1344 ++++++++++++++++++++ .../schema/main/full_schemas/72/full.sql.sqlite | 646 ++++++++++ .../schema/state/full_schemas/72/full.sql.postgres | 30 + .../schema/state/full_schemas/72/full.sql.sqlite | 20 + 13 files changed, 2165 insertions(+), 22 deletions(-) create mode 100644 changelog.d/13873.misc create mode 100644 synapse/storage/schema/common/full_schemas/72/full.sql.postgres create mode 100644 synapse/storage/schema/common/full_schemas/72/full.sql.sqlite create mode 100644 synapse/storage/schema/main/full_schemas/72/full.sql.postgres create mode 100644 synapse/storage/schema/main/full_schemas/72/full.sql.sqlite create mode 100644 synapse/storage/schema/state/full_schemas/72/full.sql.postgres create mode 100644 synapse/storage/schema/state/full_schemas/72/full.sql.sqlite (limited to 'synapse') diff --git a/changelog.d/13873.misc b/changelog.d/13873.misc new file mode 100644 index 0000000000..f4342482f0 --- /dev/null +++ b/changelog.d/13873.misc @@ -0,0 +1 @@ +Create a new snapshot of the database schema. diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh index d8cd06ee4f..e2bc1640bb 100755 --- a/scripts-dev/make_full_schema.sh +++ b/scripts-dev/make_full_schema.sh @@ -26,6 +26,9 @@ usage() { echo " Defaults to 9999." echo "-h" echo " Display this help text." + echo "" + echo " NB: make sure to run this against the *oldest* supported version of postgres," + echo " or else pg_dump might output non-backwards-compatible syntax." } SCHEMA_NUMBER="9999" @@ -240,25 +243,54 @@ DROP TABLE user_directory_search_stat; echo "Dumping SQLite3 schema..." -mkdir -p "$OUTPUT_DIR/"{common,main,state}"/full_schema/$SCHEMA_NUMBER" -sqlite3 "$SQLITE_COMMON_DB" ".schema --indent" > "$OUTPUT_DIR/common/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" -sqlite3 "$SQLITE_COMMON_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/common/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" -sqlite3 "$SQLITE_MAIN_DB" ".schema --indent" > "$OUTPUT_DIR/main/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" -sqlite3 "$SQLITE_MAIN_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/main/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" -sqlite3 "$SQLITE_STATE_DB" ".schema --indent" > "$OUTPUT_DIR/state/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" -sqlite3 "$SQLITE_STATE_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/state/full_schema/$SCHEMA_NUMBER/full.sql.sqlite" +mkdir -p "$OUTPUT_DIR/"{common,main,state}"/full_schemas/$SCHEMA_NUMBER" +sqlite3 "$SQLITE_COMMON_DB" ".schema" > "$OUTPUT_DIR/common/full_schemas/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_COMMON_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/common/full_schemas/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_MAIN_DB" ".schema" > "$OUTPUT_DIR/main/full_schemas/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_MAIN_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/main/full_schemas/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_STATE_DB" ".schema" > "$OUTPUT_DIR/state/full_schemas/$SCHEMA_NUMBER/full.sql.sqlite" +sqlite3 "$SQLITE_STATE_DB" ".dump --data-only --nosys" >> "$OUTPUT_DIR/state/full_schemas/$SCHEMA_NUMBER/full.sql.sqlite" cleanup_pg_schema() { - sed -e '/^$/d' -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' + # Cleanup as follows: + # - Remove empty lines. pg_dump likes to output a lot of these. + # - Remove comment-only lines. pg_dump also likes to output a lot of these to visually + # separate tables etc. + # - Remove "public." prefix --- the schema name. + # - Remove "SET" commands. Last time I ran this, the output commands were + # SET statement_timeout = 0; + # SET lock_timeout = 0; + # SET idle_in_transaction_session_timeout = 0; + # SET client_encoding = 'UTF8'; + # SET standard_conforming_strings = on; + # SET check_function_bodies = false; + # SET xmloption = content; + # SET client_min_messages = warning; + # SET row_security = off; + # SET default_table_access_method = heap; + # - Very carefully remove specific SELECT statements. We CANNOT blanket remove all + # SELECT statements because some of those have side-effects which we do want in the + # schema. Last time I ran this, the only SELECTS were + # SELECT pg_catalog.set_config('search_path', '', false); + # and + # SELECT pg_catalog.setval(text, bigint, bool); + # We do want to remove the former, but the latter is important. If the last argument + # is `true` or omitted, this marks the given integer as having been consumed and + # will NOT appear as the nextval. + sed -e '/^$/d' \ + -e '/^--/d' \ + -e 's/public\.//g' \ + -e '/^SET /d' \ + -e '/^SELECT pg_catalog.set_config/d' } echo "Dumping Postgres schema..." -pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_COMMON_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/common/full_schema/$SCHEMA_NUMBER/full.sql.postgres" -pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_COMMON_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/common/full_schema/$SCHEMA_NUMBER/full.sql.postgres" -pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_MAIN_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/main/full_schema/$SCHEMA_NUMBER/full.sql.postgres" -pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_MAIN_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/main/full_schema/$SCHEMA_NUMBER/full.sql.postgres" -pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/state/full_schema/$SCHEMA_NUMBER/full.sql.postgres" -pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/state/full_schema/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_COMMON_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/common/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_COMMON_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/common/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_MAIN_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/main/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_MAIN_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/main/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/state/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" +pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/state/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" echo "Done! Files dumped to: $OUTPUT_DIR" diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 9d116f6925..6cc88aad32 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -393,6 +393,14 @@ class LoggingTransaction: def executemany(self, sql: str, *args: Any) -> None: self._do_execute(self.txn.executemany, sql, *args) + def executescript(self, sql: str) -> None: + if isinstance(self.database_engine, Sqlite3Engine): + self._do_execute(self.txn.executescript, sql) # type: ignore[attr-defined] + else: + raise NotImplementedError( + f"executescript only exists for sqlite driver, not {type(self.database_engine)}" + ) + def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" return " ".join(line.strip() for line in sql.splitlines() if line.strip()) diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 0d16a419a4..70e594a68f 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -32,9 +32,10 @@ class IncorrectDatabaseSetup(RuntimeError): ConnectionType = TypeVar("ConnectionType", bound=Connection) +CursorType = TypeVar("CursorType", bound=Cursor) -class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): +class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCMeta): def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]): self.module = module @@ -64,7 +65,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): ... @abc.abstractmethod - def check_new_database(self, txn: Cursor) -> None: + def check_new_database(self, txn: CursorType) -> None: """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ @@ -124,3 +125,21 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default. """ ... + + @staticmethod + @abc.abstractmethod + def executescript(cursor: CursorType, script: str) -> None: + """Execute a chunk of SQL containing multiple semicolon-delimited statements. + + This is not provided by DBAPI2, and so needs engine-specific support. + """ + ... + + @classmethod + def execute_script_file(cls, cursor: CursorType, filepath: str) -> None: + """Execute a file containing multiple semicolon-delimited SQL statements. + + This is not provided by DBAPI2, and so needs engine-specific support. + """ + with open(filepath, "rt") as f: + cls.executescript(cursor, f.read()) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 7f7d006ac2..d8c0f64d9a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -31,7 +31,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class PostgresEngine(BaseDatabaseEngine[psycopg2.extensions.connection]): +class PostgresEngine( + BaseDatabaseEngine[psycopg2.extensions.connection, psycopg2.extensions.cursor] +): def __init__(self, database_config: Mapping[str, Any]): super().__init__(psycopg2, database_config) psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) @@ -212,3 +214,11 @@ class PostgresEngine(BaseDatabaseEngine[psycopg2.extensions.connection]): else: isolation_level = self.isolation_level_map[isolation_level] return conn.set_isolation_level(isolation_level) + + @staticmethod + def executescript(cursor: psycopg2.extensions.cursor, script: str) -> None: + """Execute a chunk of SQL containing multiple semicolon-delimited statements. + + Psycopg2 seems happy to do this in DBAPI2's `execute()` function. + """ + cursor.execute(script) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 095ae0a096..faa574dbfd 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from synapse.storage.database import LoggingDatabaseConnection -class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): +class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): def __init__(self, database_config: Mapping[str, Any]): super().__init__(sqlite3, database_config) @@ -120,6 +120,25 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): # All transactions are SERIALIZABLE by default in sqlite pass + @staticmethod + def executescript(cursor: sqlite3.Cursor, script: str) -> None: + """Execute a chunk of SQL containing multiple semicolon-delimited statements. + + Python's built-in SQLite driver does not allow you to do this with DBAPI2's + `execute`: + + > execute() will only execute a single SQL statement. If you try to execute more + > than one statement with it, it will raise a Warning. Use executescript() if + > you want to execute multiple SQL statements with one call. + + Though the docs for `executescript` warn: + + > If there is a pending transaction, an implicit COMMIT statement is executed + > first. No other implicit transaction control is performed; any transaction + > control must be added to sql_script. + """ + cursor.executescript(script) + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 09a2b58f4c..3acdb39da7 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -266,7 +266,7 @@ def _setup_new_database( ".sql." + specific ): logger.debug("Applying schema %s", entry.absolute_path) - executescript(cur, entry.absolute_path) + database_engine.execute_script_file(cur, entry.absolute_path) cur.execute( "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", @@ -517,7 +517,7 @@ def _upgrade_existing_database( UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path ) logger.info("Applying schema %s", relative_path) - executescript(cur, absolute_path) + database_engine.execute_script_file(cur, absolute_path) elif ext == specific_engine_extension and root_name.endswith(".sql"): # A .sql file specific to our engine; just read and execute it if is_worker: @@ -525,7 +525,7 @@ def _upgrade_existing_database( UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path ) logger.info("Applying engine-specific schema %s", relative_path) - executescript(cur, absolute_path) + database_engine.execute_script_file(cur, absolute_path) elif ext in specific_engine_extensions and root_name.endswith(".sql"): # A .sql file for a different engine; skip it. continue @@ -666,7 +666,7 @@ def _get_or_create_schema_state( ) -> Optional[_SchemaState]: # Bluntly try creating the schema_version tables. sql_path = os.path.join(schema_path, "common", "schema_version.sql") - executescript(txn, sql_path) + database_engine.execute_script_file(txn, sql_path) txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() diff --git a/synapse/storage/schema/common/full_schemas/72/full.sql.postgres b/synapse/storage/schema/common/full_schemas/72/full.sql.postgres new file mode 100644 index 0000000000..f0e546f052 --- /dev/null +++ b/synapse/storage/schema/common/full_schemas/72/full.sql.postgres @@ -0,0 +1,8 @@ +CREATE TABLE background_updates ( + update_name text NOT NULL, + progress_json text NOT NULL, + depends_on text, + ordering integer DEFAULT 0 NOT NULL +); +ALTER TABLE ONLY background_updates + ADD CONSTRAINT background_updates_uniqueness UNIQUE (update_name); diff --git a/synapse/storage/schema/common/full_schemas/72/full.sql.sqlite b/synapse/storage/schema/common/full_schemas/72/full.sql.sqlite new file mode 100644 index 0000000000..d5a2c04aa9 --- /dev/null +++ b/synapse/storage/schema/common/full_schemas/72/full.sql.sqlite @@ -0,0 +1,6 @@ +CREATE TABLE background_updates ( + update_name text NOT NULL, + progress_json text NOT NULL, + depends_on text, ordering INT NOT NULL DEFAULT 0, + CONSTRAINT background_updates_uniqueness UNIQUE (update_name) +); diff --git a/synapse/storage/schema/main/full_schemas/72/full.sql.postgres b/synapse/storage/schema/main/full_schemas/72/full.sql.postgres new file mode 100644 index 0000000000..d421fd9ab9 --- /dev/null +++ b/synapse/storage/schema/main/full_schemas/72/full.sql.postgres @@ -0,0 +1,1344 @@ +CREATE FUNCTION check_partial_state_events() RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + IF EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ) THEN + RAISE EXCEPTION 'Incorrect room_id in partial_state_events'; + END IF; + RETURN NEW; + END; + $$; +CREATE TABLE access_tokens ( + id bigint NOT NULL, + user_id text NOT NULL, + device_id text, + token text NOT NULL, + valid_until_ms bigint, + puppets_user_id text, + last_validated bigint, + refresh_token_id bigint, + used boolean +); +CREATE TABLE account_data ( + user_id text NOT NULL, + account_data_type text NOT NULL, + stream_id bigint NOT NULL, + content text NOT NULL, + instance_name text +); +CREATE SEQUENCE account_data_sequence + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE account_validity ( + user_id text NOT NULL, + expiration_ts_ms bigint NOT NULL, + email_sent boolean NOT NULL, + renewal_token text, + token_used_ts_ms bigint +); +CREATE TABLE application_services_state ( + as_id text NOT NULL, + state character varying(5), + read_receipt_stream_id bigint, + presence_stream_id bigint, + to_device_stream_id bigint, + device_list_stream_id bigint +); +CREATE SEQUENCE application_services_txn_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE application_services_txns ( + as_id text NOT NULL, + txn_id bigint NOT NULL, + event_ids text NOT NULL +); +CREATE TABLE appservice_room_list ( + appservice_id text NOT NULL, + network_id text NOT NULL, + room_id text NOT NULL +); +CREATE TABLE appservice_stream_position ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_ordering bigint, + CONSTRAINT appservice_stream_position_lock_check CHECK ((lock = 'X'::bpchar)) +); +CREATE TABLE batch_events ( + event_id text NOT NULL, + room_id text NOT NULL, + batch_id text NOT NULL +); +CREATE TABLE blocked_rooms ( + room_id text NOT NULL, + user_id text NOT NULL +); +CREATE TABLE cache_invalidation_stream_by_instance ( + stream_id bigint NOT NULL, + instance_name text NOT NULL, + cache_func text NOT NULL, + keys text[], + invalidation_ts bigint +); +CREATE SEQUENCE cache_invalidation_stream_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE current_state_delta_stream ( + stream_id bigint NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + event_id text, + prev_event_id text, + instance_name text +); +CREATE TABLE current_state_events ( + event_id text NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + membership text +); +CREATE TABLE dehydrated_devices ( + user_id text NOT NULL, + device_id text NOT NULL, + device_data text NOT NULL +); +CREATE TABLE deleted_pushers ( + stream_id bigint NOT NULL, + app_id text NOT NULL, + pushkey text NOT NULL, + user_id text NOT NULL +); +CREATE TABLE destination_rooms ( + destination text NOT NULL, + room_id text NOT NULL, + stream_ordering bigint NOT NULL +); +CREATE TABLE destinations ( + destination text NOT NULL, + retry_last_ts bigint, + retry_interval bigint, + failure_ts bigint, + last_successful_stream_ordering bigint +); +CREATE TABLE device_auth_providers ( + user_id text NOT NULL, + device_id text NOT NULL, + auth_provider_id text NOT NULL, + auth_provider_session_id text NOT NULL +); +CREATE TABLE device_federation_inbox ( + origin text NOT NULL, + message_id text NOT NULL, + received_ts bigint NOT NULL, + instance_name text +); +CREATE TABLE device_federation_outbox ( + destination text NOT NULL, + stream_id bigint NOT NULL, + queued_ts bigint NOT NULL, + messages_json text NOT NULL, + instance_name text +); +CREATE TABLE device_inbox ( + user_id text NOT NULL, + device_id text NOT NULL, + stream_id bigint NOT NULL, + message_json text NOT NULL, + instance_name text +); +CREATE SEQUENCE device_inbox_sequence + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE device_lists_changes_in_room ( + user_id text NOT NULL, + device_id text NOT NULL, + room_id text NOT NULL, + stream_id bigint NOT NULL, + converted_to_destinations boolean NOT NULL, + opentracing_context text +); +CREATE TABLE device_lists_outbound_last_success ( + destination text NOT NULL, + user_id text NOT NULL, + stream_id bigint NOT NULL +); +CREATE TABLE device_lists_outbound_pokes ( + destination text NOT NULL, + stream_id bigint NOT NULL, + user_id text NOT NULL, + device_id text NOT NULL, + sent boolean NOT NULL, + ts bigint NOT NULL, + opentracing_context text +); +CREATE TABLE device_lists_remote_cache ( + user_id text NOT NULL, + device_id text NOT NULL, + content text NOT NULL +); +CREATE TABLE device_lists_remote_extremeties ( + user_id text NOT NULL, + stream_id text NOT NULL +); +CREATE TABLE device_lists_remote_resync ( + user_id text NOT NULL, + added_ts bigint NOT NULL +); +CREATE TABLE device_lists_stream ( + stream_id bigint NOT NULL, + user_id text NOT NULL, + device_id text NOT NULL +); +CREATE TABLE devices ( + user_id text NOT NULL, + device_id text NOT NULL, + display_name text, + last_seen bigint, + ip text, + user_agent text, + hidden boolean DEFAULT false +); +CREATE TABLE e2e_cross_signing_keys ( + user_id text NOT NULL, + keytype text NOT NULL, + keydata text NOT NULL, + stream_id bigint NOT NULL +); +CREATE TABLE e2e_cross_signing_signatures ( + user_id text NOT NULL, + key_id text NOT NULL, + target_user_id text NOT NULL, + target_device_id text NOT NULL, + signature text NOT NULL +); +CREATE TABLE e2e_device_keys_json ( + user_id text NOT NULL, + device_id text NOT NULL, + ts_added_ms bigint NOT NULL, + key_json text NOT NULL +); +CREATE TABLE e2e_fallback_keys_json ( + user_id text NOT NULL, + device_id text NOT NULL, + algorithm text NOT NULL, + key_id text NOT NULL, + key_json text NOT NULL, + used boolean DEFAULT false NOT NULL +); +CREATE TABLE e2e_one_time_keys_json ( + user_id text NOT NULL, + device_id text NOT NULL, + algorithm text NOT NULL, + key_id text NOT NULL, + ts_added_ms bigint NOT NULL, + key_json text NOT NULL +); +CREATE TABLE e2e_room_keys ( + user_id text NOT NULL, + room_id text NOT NULL, + session_id text NOT NULL, + version bigint NOT NULL, + first_message_index integer, + forwarded_count integer, + is_verified boolean, + session_data text NOT NULL +); +CREATE TABLE e2e_room_keys_versions ( + user_id text NOT NULL, + version bigint NOT NULL, + algorithm text NOT NULL, + auth_data text NOT NULL, + deleted smallint DEFAULT 0 NOT NULL, + etag bigint +); +CREATE TABLE erased_users ( + user_id text NOT NULL +); +CREATE TABLE event_auth ( + event_id text NOT NULL, + auth_id text NOT NULL, + room_id text NOT NULL +); +CREATE SEQUENCE event_auth_chain_id + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE event_auth_chain_links ( + origin_chain_id bigint NOT NULL, + origin_sequence_number bigint NOT NULL, + target_chain_id bigint NOT NULL, + target_sequence_number bigint NOT NULL +); +CREATE TABLE event_auth_chain_to_calculate ( + event_id text NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL +); +CREATE TABLE event_auth_chains ( + event_id text NOT NULL, + chain_id bigint NOT NULL, + sequence_number bigint NOT NULL +); +CREATE TABLE event_backward_extremities ( + event_id text NOT NULL, + room_id text NOT NULL +); +CREATE TABLE event_edges ( + event_id text NOT NULL, + prev_event_id text NOT NULL, + room_id text, + is_state boolean DEFAULT false NOT NULL +); +CREATE TABLE event_expiry ( + event_id text NOT NULL, + expiry_ts bigint NOT NULL +); +CREATE TABLE event_forward_extremities ( + event_id text NOT NULL, + room_id text NOT NULL +); +CREATE TABLE event_json ( + event_id text NOT NULL, + room_id text NOT NULL, + internal_metadata text NOT NULL, + json text NOT NULL, + format_version integer +); +CREATE TABLE event_labels ( + event_id text NOT NULL, + label text NOT NULL, + room_id text NOT NULL, + topological_ordering bigint NOT NULL +); +CREATE TABLE event_push_actions ( + room_id text NOT NULL, + event_id text NOT NULL, + user_id text NOT NULL, + profile_tag character varying(32), + actions text NOT NULL, + topological_ordering bigint, + stream_ordering bigint, + notif smallint, + highlight smallint, + unread smallint, + thread_id text +); +CREATE TABLE event_push_actions_staging ( + event_id text NOT NULL, + user_id text NOT NULL, + actions text NOT NULL, + notif smallint NOT NULL, + highlight smallint NOT NULL, + unread smallint, + thread_id text +); +CREATE TABLE event_push_summary ( + user_id text NOT NULL, + room_id text NOT NULL, + notif_count bigint NOT NULL, + stream_ordering bigint NOT NULL, + unread_count bigint, + last_receipt_stream_ordering bigint, + thread_id text +); +CREATE TABLE event_push_summary_last_receipt_stream_id ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_id bigint NOT NULL, + CONSTRAINT event_push_summary_last_receipt_stream_id_lock_check CHECK ((lock = 'X'::bpchar)) +); +CREATE TABLE event_push_summary_stream_ordering ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_ordering bigint NOT NULL, + CONSTRAINT event_push_summary_stream_ordering_lock_check CHECK ((lock = 'X'::bpchar)) +); +CREATE TABLE event_relations ( + event_id text NOT NULL, + relates_to_id text NOT NULL, + relation_type text NOT NULL, + aggregation_key text +); +CREATE TABLE event_reports ( + id bigint NOT NULL, + received_ts bigint NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL, + user_id text NOT NULL, + reason text, + content text +); +CREATE TABLE event_search ( + event_id text, + room_id text, + sender text, + key text, + vector tsvector, + origin_server_ts bigint, + stream_ordering bigint +); +CREATE TABLE event_to_state_groups ( + event_id text NOT NULL, + state_group bigint NOT NULL +); +CREATE TABLE event_txn_id ( + event_id text NOT NULL, + room_id text NOT NULL, + user_id text NOT NULL, + token_id bigint NOT NULL, + txn_id text NOT NULL, + inserted_ts bigint NOT NULL +); +CREATE TABLE events ( + topological_ordering bigint NOT NULL, + event_id text NOT NULL, + type text NOT NULL, + room_id text NOT NULL, + content text, + unrecognized_keys text, + processed boolean NOT NULL, + outlier boolean NOT NULL, + depth bigint DEFAULT 0 NOT NULL, + origin_server_ts bigint, + received_ts bigint, + sender text, + contains_url boolean, + instance_name text, + stream_ordering bigint, + state_key text, + rejection_reason text +); +CREATE SEQUENCE events_backfill_stream_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE SEQUENCE events_stream_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE ex_outlier_stream ( + event_stream_ordering bigint NOT NULL, + event_id text NOT NULL, + state_group bigint NOT NULL, + instance_name text +); +CREATE TABLE federation_inbound_events_staging ( + origin text NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL, + received_ts bigint NOT NULL, + event_json text NOT NULL, + internal_metadata text NOT NULL +); +CREATE TABLE federation_stream_position ( + type text NOT NULL, + stream_id bigint NOT NULL, + instance_name text DEFAULT 'master'::text NOT NULL +); +CREATE TABLE ignored_users ( + ignorer_user_id text NOT NULL, + ignored_user_id text NOT NULL +); +CREATE TABLE insertion_event_edges ( + event_id text NOT NULL, + room_id text NOT NULL, + insertion_prev_event_id text NOT NULL +); +CREATE TABLE insertion_event_extremities ( + event_id text NOT NULL, + room_id text NOT NULL +); +CREATE TABLE insertion_events ( + event_id text NOT NULL, + room_id text NOT NULL, + next_batch_id text NOT NULL +); +CREATE TABLE instance_map ( + instance_id integer NOT NULL, + instance_name text NOT NULL +); +CREATE SEQUENCE instance_map_instance_id_seq + AS integer + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +ALTER SEQUENCE instance_map_instance_id_seq OWNED BY instance_map.instance_id; +CREATE TABLE local_current_membership ( + room_id text NOT NULL, + user_id text NOT NULL, + event_id text NOT NULL, + membership text NOT NULL +); +CREATE TABLE local_media_repository ( + media_id text, + media_type text, + media_length integer, + created_ts bigint, + upload_name text, + user_id text, + quarantined_by text, + url_cache text, + last_access_ts bigint, + safe_from_quarantine boolean DEFAULT false NOT NULL +); +CREATE TABLE local_media_repository_thumbnails ( + media_id text, + thumbnail_width integer, + thumbnail_height integer, + thumbnail_type text, + thumbnail_method text, + thumbnail_length integer +); +CREATE TABLE local_media_repository_url_cache ( + url text, + response_code integer, + etag text, + expires_ts bigint, + og text, + media_id text, + download_ts bigint +); +CREATE TABLE monthly_active_users ( + user_id text NOT NULL, + "timestamp" bigint NOT NULL +); +CREATE TABLE open_id_tokens ( + token text NOT NULL, + ts_valid_until_ms bigint NOT NULL, + user_id text NOT NULL +); +CREATE TABLE partial_state_events ( + room_id text NOT NULL, + event_id text NOT NULL +); +CREATE TABLE partial_state_rooms ( + room_id text NOT NULL +); +CREATE TABLE partial_state_rooms_servers ( + room_id text NOT NULL, + server_name text NOT NULL +); +CREATE TABLE presence ( + user_id text NOT NULL, + state character varying(20), + status_msg text, + mtime bigint +); +CREATE TABLE presence_stream ( + stream_id bigint, + user_id text, + state text, + last_active_ts bigint, + last_federation_update_ts bigint, + last_user_sync_ts bigint, + status_msg text, + currently_active boolean, + instance_name text +); +CREATE SEQUENCE presence_stream_sequence + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE profiles ( + user_id text NOT NULL, + displayname text, + avatar_url text +); +CREATE TABLE push_rules ( + id bigint NOT NULL, + user_name text NOT NULL, + rule_id text NOT NULL, + priority_class smallint NOT NULL, + priority integer DEFAULT 0 NOT NULL, + conditions text NOT NULL, + actions text NOT NULL +); +CREATE TABLE push_rules_enable ( + id bigint NOT NULL, + user_name text NOT NULL, + rule_id text NOT NULL, + enabled smallint +); +CREATE TABLE push_rules_stream ( + stream_id bigint NOT NULL, + event_stream_ordering bigint NOT NULL, + user_id text NOT NULL, + rule_id text NOT NULL, + op text NOT NULL, + priority_class smallint, + priority integer, + conditions text, + actions text +); +CREATE TABLE pusher_throttle ( + pusher bigint NOT NULL, + room_id text NOT NULL, + last_sent_ts bigint, + throttle_ms bigint +); +CREATE TABLE pushers ( + id bigint NOT NULL, + user_name text NOT NULL, + access_token bigint, + profile_tag text NOT NULL, + kind text NOT NULL, + app_id text NOT NULL, + app_display_name text NOT NULL, + device_display_name text NOT NULL, + pushkey text NOT NULL, + ts bigint NOT NULL, + lang text, + data text, + last_stream_ordering bigint, + last_success bigint, + failing_since bigint +); +CREATE TABLE ratelimit_override ( + user_id text NOT NULL, + messages_per_second bigint, + burst_count bigint +); +CREATE TABLE receipts_graph ( + room_id text NOT NULL, + receipt_type text NOT NULL, + user_id text NOT NULL, + event_ids text NOT NULL, + data text NOT NULL, + thread_id text +); +CREATE TABLE receipts_linearized ( + stream_id bigint NOT NULL, + room_id text NOT NULL, + receipt_type text NOT NULL, + user_id text NOT NULL, + event_id text NOT NULL, + data text NOT NULL, + instance_name text, + event_stream_ordering bigint, + thread_id text +); +CREATE SEQUENCE receipts_sequence + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE received_transactions ( + transaction_id text, + origin text, + ts bigint, + response_code integer, + response_json bytea, + has_been_referenced smallint DEFAULT 0 +); +CREATE TABLE redactions ( + event_id text NOT NULL, + redacts text NOT NULL, + have_censored boolean DEFAULT false NOT NULL, + received_ts bigint +); +CREATE TABLE refresh_tokens ( + id bigint NOT NULL, + user_id text NOT NULL, + device_id text NOT NULL, + token text NOT NULL, + next_token_id bigint, + expiry_ts bigint, + ultimate_session_expiry_ts bigint +); +CREATE TABLE registration_tokens ( + token text NOT NULL, + uses_allowed integer, + pending integer NOT NULL, + completed integer NOT NULL, + expiry_time bigint +); +CREATE TABLE rejections ( + event_id text NOT NULL, + reason text NOT NULL, + last_check text NOT NULL +); +CREATE TABLE remote_media_cache ( + media_origin text, + media_id text, + media_type text, + created_ts bigint, + upload_name text, + media_length integer, + filesystem_id text, + last_access_ts bigint, + quarantined_by text +); +CREATE TABLE remote_media_cache_thumbnails ( + media_origin text, + media_id text, + thumbnail_width integer, + thumbnail_height integer, + thumbnail_method text, + thumbnail_type text, + thumbnail_length integer, + filesystem_id text +); +CREATE TABLE room_account_data ( + user_id text NOT NULL, + room_id text NOT NULL, + account_data_type text NOT NULL, + stream_id bigint NOT NULL, + content text NOT NULL, + instance_name text +); +CREATE TABLE room_alias_servers ( + room_alias text NOT NULL, + server text NOT NULL +); +CREATE TABLE room_aliases ( + room_alias text NOT NULL, + room_id text NOT NULL, + creator text +); +CREATE TABLE room_depth ( + room_id text NOT NULL, + min_depth bigint +); +CREATE TABLE room_memberships ( + event_id text NOT NULL, + user_id text NOT NULL, + sender text NOT NULL, + room_id text NOT NULL, + membership text NOT NULL, + forgotten integer DEFAULT 0, + display_name text, + avatar_url text +); +CREATE TABLE room_retention ( + room_id text NOT NULL, + event_id text NOT NULL, + min_lifetime bigint, + max_lifetime bigint +); +CREATE TABLE room_stats_current ( + room_id text NOT NULL, + current_state_events integer NOT NULL, + joined_members integer NOT NULL, + invited_members integer NOT NULL, + left_members integer NOT NULL, + banned_members integer NOT NULL, + local_users_in_room integer NOT NULL, + completed_delta_stream_id bigint NOT NULL, + knocked_members integer +); +CREATE TABLE room_stats_earliest_token ( + room_id text NOT NULL, + token bigint NOT NULL +); +CREATE TABLE room_stats_state ( + room_id text NOT NULL, + name text, + canonical_alias text, + join_rules text, + history_visibility text, + encryption text, + avatar text, + guest_access text, + is_federatable boolean, + topic text, + room_type text +); +CREATE TABLE room_tags ( + user_id text NOT NULL, + room_id text NOT NULL, + tag text NOT NULL, + content text NOT NULL +); +CREATE TABLE room_tags_revisions ( + user_id text NOT NULL, + room_id text NOT NULL, + stream_id bigint NOT NULL, + instance_name text +); +CREATE TABLE rooms ( + room_id text NOT NULL, + is_public boolean, + creator text, + room_version text, + has_auth_chain_index boolean +); +CREATE TABLE server_keys_json ( + server_name text NOT NULL, + key_id text NOT NULL, + from_server text NOT NULL, + ts_added_ms bigint NOT NULL, + ts_valid_until_ms bigint NOT NULL, + key_json bytea NOT NULL +); +CREATE TABLE server_signature_keys ( + server_name text, + key_id text, + from_server text, + ts_added_ms bigint, + verify_key bytea, + ts_valid_until_ms bigint +); +CREATE TABLE sessions ( + session_type text NOT NULL, + session_id text NOT NULL, + value text NOT NULL, + expiry_time_ms bigint NOT NULL +); +CREATE TABLE state_events ( + event_id text NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + prev_state text +); +CREATE TABLE stats_incremental_position ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_id bigint NOT NULL, + CONSTRAINT stats_incremental_position_lock_check CHECK ((lock = 'X'::bpchar)) +); +CREATE TABLE stream_ordering_to_exterm ( + stream_ordering bigint NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL +); +CREATE TABLE stream_positions ( + stream_name text NOT NULL, + instance_name text NOT NULL, + stream_id bigint NOT NULL +); +CREATE TABLE threepid_guest_access_tokens ( + medium text, + address text, + guest_access_token text, + first_inviter text +); +CREATE TABLE threepid_validation_session ( + session_id text NOT NULL, + medium text NOT NULL, + address text NOT NULL, + client_secret text NOT NULL, + last_send_attempt bigint NOT NULL, + validated_at bigint +); +CREATE TABLE threepid_validation_token ( + token text NOT NULL, + session_id text NOT NULL, + next_link text, + expires bigint NOT NULL +); +CREATE TABLE ui_auth_sessions ( + session_id text NOT NULL, + creation_time bigint NOT NULL, + serverdict text NOT NULL, + clientdict text NOT NULL, + uri text NOT NULL, + method text NOT NULL, + description text NOT NULL +); +CREATE TABLE ui_auth_sessions_credentials ( + session_id text NOT NULL, + stage_type text NOT NULL, + result text NOT NULL +); +CREATE TABLE ui_auth_sessions_ips ( + session_id text NOT NULL, + ip text NOT NULL, + user_agent text NOT NULL +); +CREATE TABLE user_daily_visits ( + user_id text NOT NULL, + device_id text, + "timestamp" bigint NOT NULL, + user_agent text +); +CREATE TABLE user_directory ( + user_id text NOT NULL, + room_id text, + display_name text, + avatar_url text +); +CREATE TABLE user_directory_search ( + user_id text NOT NULL, + vector tsvector +); +CREATE TABLE user_directory_stream_pos ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_id bigint, + CONSTRAINT user_directory_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) +); +CREATE TABLE user_external_ids ( + auth_provider text NOT NULL, + external_id text NOT NULL, + user_id text NOT NULL +); +CREATE TABLE user_filters ( + user_id text NOT NULL, + filter_id bigint NOT NULL, + filter_json bytea NOT NULL +); +CREATE SEQUENCE user_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE user_ips ( + user_id text NOT NULL, + access_token text NOT NULL, + device_id text, + ip text NOT NULL, + user_agent text NOT NULL, + last_seen bigint NOT NULL +); +CREATE TABLE user_signature_stream ( + stream_id bigint NOT NULL, + from_user_id text NOT NULL, + user_ids text NOT NULL +); +CREATE TABLE user_stats_current ( + user_id text NOT NULL, + joined_rooms bigint NOT NULL, + completed_delta_stream_id bigint NOT NULL +); +CREATE TABLE user_threepid_id_server ( + user_id text NOT NULL, + medium text NOT NULL, + address text NOT NULL, + id_server text NOT NULL +); +CREATE TABLE user_threepids ( + user_id text NOT NULL, + medium text NOT NULL, + address text NOT NULL, + validated_at bigint NOT NULL, + added_at bigint NOT NULL +); +CREATE TABLE users ( + name text, + password_hash text, + creation_ts bigint, + admin smallint DEFAULT 0 NOT NULL, + upgrade_ts bigint, + is_guest smallint DEFAULT 0 NOT NULL, + appservice_id text, + consent_version text, + consent_server_notice_sent text, + user_type text, + deactivated smallint DEFAULT 0 NOT NULL, + shadow_banned boolean, + consent_ts bigint +); +CREATE TABLE users_in_public_rooms ( + user_id text NOT NULL, + room_id text NOT NULL +); +CREATE TABLE users_pending_deactivation ( + user_id text NOT NULL +); +CREATE TABLE users_to_send_full_presence_to ( + user_id text NOT NULL, + presence_stream_id bigint +); +CREATE TABLE users_who_share_private_rooms ( + user_id text NOT NULL, + other_user_id text NOT NULL, + room_id text NOT NULL +); +CREATE TABLE worker_locks ( + lock_name text NOT NULL, + lock_key text NOT NULL, + instance_name text NOT NULL, + token text NOT NULL, + last_renewed_ts bigint NOT NULL +); +ALTER TABLE ONLY instance_map ALTER COLUMN instance_id SET DEFAULT nextval('instance_map_instance_id_seq'::regclass); +ALTER TABLE ONLY access_tokens + ADD CONSTRAINT access_tokens_pkey PRIMARY KEY (id); +ALTER TABLE ONLY access_tokens + ADD CONSTRAINT access_tokens_token_key UNIQUE (token); +ALTER TABLE ONLY account_data + ADD CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type); +ALTER TABLE ONLY account_validity + ADD CONSTRAINT account_validity_pkey PRIMARY KEY (user_id); +ALTER TABLE ONLY application_services_state + ADD CONSTRAINT application_services_state_pkey PRIMARY KEY (as_id); +ALTER TABLE ONLY application_services_txns + ADD CONSTRAINT application_services_txns_as_id_txn_id_key UNIQUE (as_id, txn_id); +ALTER TABLE ONLY appservice_stream_position + ADD CONSTRAINT appservice_stream_position_lock_key UNIQUE (lock); +ALTER TABLE ONLY current_state_events + ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY current_state_events + ADD CONSTRAINT current_state_events_room_id_type_state_key_key UNIQUE (room_id, type, state_key); +ALTER TABLE ONLY dehydrated_devices + ADD CONSTRAINT dehydrated_devices_pkey PRIMARY KEY (user_id); +ALTER TABLE ONLY destination_rooms + ADD CONSTRAINT destination_rooms_pkey PRIMARY KEY (destination, room_id); +ALTER TABLE ONLY destinations + ADD CONSTRAINT destinations_pkey PRIMARY KEY (destination); +ALTER TABLE ONLY devices + ADD CONSTRAINT device_uniqueness UNIQUE (user_id, device_id); +ALTER TABLE ONLY e2e_device_keys_json + ADD CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id); +ALTER TABLE ONLY e2e_fallback_keys_json + ADD CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm); +ALTER TABLE ONLY e2e_one_time_keys_json + ADD CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id); +ALTER TABLE ONLY event_auth_chain_to_calculate + ADD CONSTRAINT event_auth_chain_to_calculate_pkey PRIMARY KEY (event_id); +ALTER TABLE ONLY event_auth_chains + ADD CONSTRAINT event_auth_chains_pkey PRIMARY KEY (event_id); +ALTER TABLE ONLY event_backward_extremities + ADD CONSTRAINT event_backward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); +ALTER TABLE ONLY event_expiry + ADD CONSTRAINT event_expiry_pkey PRIMARY KEY (event_id); +ALTER TABLE ONLY event_forward_extremities + ADD CONSTRAINT event_forward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); +ALTER TABLE ONLY event_push_actions + ADD CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag); +ALTER TABLE ONLY event_json + ADD CONSTRAINT event_json_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY event_labels + ADD CONSTRAINT event_labels_pkey PRIMARY KEY (event_id, label); +ALTER TABLE ONLY event_push_summary_last_receipt_stream_id + ADD CONSTRAINT event_push_summary_last_receipt_stream_id_lock_key UNIQUE (lock); +ALTER TABLE ONLY event_push_summary_stream_ordering + ADD CONSTRAINT event_push_summary_stream_ordering_lock_key UNIQUE (lock); +ALTER TABLE ONLY event_reports + ADD CONSTRAINT event_reports_pkey PRIMARY KEY (id); +ALTER TABLE ONLY event_to_state_groups + ADD CONSTRAINT event_to_state_groups_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY events + ADD CONSTRAINT events_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY ex_outlier_stream + ADD CONSTRAINT ex_outlier_stream_pkey PRIMARY KEY (event_stream_ordering); +ALTER TABLE ONLY instance_map + ADD CONSTRAINT instance_map_pkey PRIMARY KEY (instance_id); +ALTER TABLE ONLY local_media_repository + ADD CONSTRAINT local_media_repository_media_id_key UNIQUE (media_id); +ALTER TABLE ONLY user_threepids + ADD CONSTRAINT medium_address UNIQUE (medium, address); +ALTER TABLE ONLY open_id_tokens + ADD CONSTRAINT open_id_tokens_pkey PRIMARY KEY (token); +ALTER TABLE ONLY partial_state_events + ADD CONSTRAINT partial_state_events_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY partial_state_rooms + ADD CONSTRAINT partial_state_rooms_pkey PRIMARY KEY (room_id); +ALTER TABLE ONLY partial_state_rooms_servers + ADD CONSTRAINT partial_state_rooms_servers_room_id_server_name_key UNIQUE (room_id, server_name); +ALTER TABLE ONLY presence + ADD CONSTRAINT presence_user_id_key UNIQUE (user_id); +ALTER TABLE ONLY profiles + ADD CONSTRAINT profiles_user_id_key UNIQUE (user_id); +ALTER TABLE ONLY push_rules_enable + ADD CONSTRAINT push_rules_enable_pkey PRIMARY KEY (id); +ALTER TABLE ONLY push_rules_enable + ADD CONSTRAINT push_rules_enable_user_name_rule_id_key UNIQUE (user_name, rule_id); +ALTER TABLE ONLY push_rules + ADD CONSTRAINT push_rules_pkey PRIMARY KEY (id); +ALTER TABLE ONLY push_rules + ADD CONSTRAINT push_rules_user_name_rule_id_key UNIQUE (user_name, rule_id); +ALTER TABLE ONLY pusher_throttle + ADD CONSTRAINT pusher_throttle_pkey PRIMARY KEY (pusher, room_id); +ALTER TABLE ONLY pushers + ADD CONSTRAINT pushers2_app_id_pushkey_user_name_key UNIQUE (app_id, pushkey, user_name); +ALTER TABLE ONLY pushers + ADD CONSTRAINT pushers2_pkey PRIMARY KEY (id); +ALTER TABLE ONLY receipts_graph + ADD CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id); +ALTER TABLE ONLY receipts_graph + ADD CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id); +ALTER TABLE ONLY receipts_linearized + ADD CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id); +ALTER TABLE ONLY receipts_linearized + ADD CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id); +ALTER TABLE ONLY received_transactions + ADD CONSTRAINT received_transactions_transaction_id_origin_key UNIQUE (transaction_id, origin); +ALTER TABLE ONLY redactions + ADD CONSTRAINT redactions_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY refresh_tokens + ADD CONSTRAINT refresh_tokens_pkey PRIMARY KEY (id); +ALTER TABLE ONLY refresh_tokens + ADD CONSTRAINT refresh_tokens_token_key UNIQUE (token); +ALTER TABLE ONLY registration_tokens + ADD CONSTRAINT registration_tokens_token_key UNIQUE (token); +ALTER TABLE ONLY rejections + ADD CONSTRAINT rejections_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY remote_media_cache + ADD CONSTRAINT remote_media_cache_media_origin_media_id_key UNIQUE (media_origin, media_id); +ALTER TABLE ONLY room_account_data + ADD CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type); +ALTER TABLE ONLY room_aliases + ADD CONSTRAINT room_aliases_room_alias_key UNIQUE (room_alias); +ALTER TABLE ONLY room_depth + ADD CONSTRAINT room_depth_room_id_key UNIQUE (room_id); +ALTER TABLE ONLY room_memberships + ADD CONSTRAINT room_memberships_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY room_retention + ADD CONSTRAINT room_retention_pkey PRIMARY KEY (room_id, event_id); +ALTER TABLE ONLY room_stats_current + ADD CONSTRAINT room_stats_current_pkey PRIMARY KEY (room_id); +ALTER TABLE ONLY room_tags_revisions + ADD CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id); +ALTER TABLE ONLY room_tags + ADD CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag); +ALTER TABLE ONLY rooms + ADD CONSTRAINT rooms_pkey PRIMARY KEY (room_id); +ALTER TABLE ONLY server_keys_json + ADD CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server); +ALTER TABLE ONLY server_signature_keys + ADD CONSTRAINT server_signature_keys_server_name_key_id_key UNIQUE (server_name, key_id); +ALTER TABLE ONLY sessions + ADD CONSTRAINT sessions_session_type_session_id_key UNIQUE (session_type, session_id); +ALTER TABLE ONLY state_events + ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); +ALTER TABLE ONLY stats_incremental_position + ADD CONSTRAINT stats_incremental_position_lock_key UNIQUE (lock); +ALTER TABLE ONLY threepid_validation_session + ADD CONSTRAINT threepid_validation_session_pkey PRIMARY KEY (session_id); +ALTER TABLE ONLY threepid_validation_token + ADD CONSTRAINT threepid_validation_token_pkey PRIMARY KEY (token); +ALTER TABLE ONLY ui_auth_sessions_credentials + ADD CONSTRAINT ui_auth_sessions_credentials_session_id_stage_type_key UNIQUE (session_id, stage_type); +ALTER TABLE ONLY ui_auth_sessions_ips + ADD CONSTRAINT ui_auth_sessions_ips_session_id_ip_user_agent_key UNIQUE (session_id, ip, user_agent); +ALTER TABLE ONLY ui_auth_sessions + ADD CONSTRAINT ui_auth_sessions_session_id_key UNIQUE (session_id); +ALTER TABLE ONLY user_directory_stream_pos + ADD CONSTRAINT user_directory_stream_pos_lock_key UNIQUE (lock); +ALTER TABLE ONLY user_external_ids + ADD CONSTRAINT user_external_ids_auth_provider_external_id_key UNIQUE (auth_provider, external_id); +ALTER TABLE ONLY user_stats_current + ADD CONSTRAINT user_stats_current_pkey PRIMARY KEY (user_id); +ALTER TABLE ONLY users + ADD CONSTRAINT users_name_key UNIQUE (name); +ALTER TABLE ONLY users_to_send_full_presence_to + ADD CONSTRAINT users_to_send_full_presence_to_pkey PRIMARY KEY (user_id); +CREATE INDEX access_tokens_device_id ON access_tokens USING btree (user_id, device_id); +CREATE INDEX account_data_stream_id ON account_data USING btree (user_id, stream_id); +CREATE INDEX application_services_txns_id ON application_services_txns USING btree (as_id); +CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list USING btree (appservice_id, network_id, room_id); +CREATE INDEX batch_events_batch_id ON batch_events USING btree (batch_id); +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms USING btree (room_id); +CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance USING btree (stream_id); +CREATE INDEX cache_invalidation_stream_by_instance_instance_index ON cache_invalidation_stream_by_instance USING btree (instance_name, stream_id); +CREATE UNIQUE INDEX chunk_events_event_id ON batch_events USING btree (event_id); +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream USING btree (stream_id); +CREATE INDEX current_state_events_member_index ON current_state_events USING btree (state_key) WHERE (type = 'm.room.member'::text); +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers USING btree (stream_id); +CREATE INDEX destination_rooms_room_id ON destination_rooms USING btree (room_id); +CREATE INDEX device_auth_providers_devices ON device_auth_providers USING btree (user_id, device_id); +CREATE INDEX device_auth_providers_sessions ON device_auth_providers USING btree (auth_provider_id, auth_provider_session_id); +CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox USING btree (origin, message_id); +CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox USING btree (destination, stream_id); +CREATE INDEX device_federation_outbox_id ON device_federation_outbox USING btree (stream_id); +CREATE INDEX device_inbox_stream_id_user_id ON device_inbox USING btree (stream_id, user_id); +CREATE INDEX device_inbox_user_stream_id ON device_inbox USING btree (user_id, device_id, stream_id); +CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room USING btree (stream_id, room_id); +CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room USING btree (stream_id) WHERE (NOT converted_to_destinations); +CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx ON device_lists_outbound_last_success USING btree (destination, user_id); +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes USING btree (destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes USING btree (stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes USING btree (destination, user_id); +CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache USING btree (user_id, device_id); +CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties USING btree (user_id); +CREATE UNIQUE INDEX device_lists_remote_resync_idx ON device_lists_remote_resync USING btree (user_id); +CREATE INDEX device_lists_remote_resync_ts_idx ON device_lists_remote_resync USING btree (added_ts); +CREATE INDEX device_lists_stream_id ON device_lists_stream USING btree (stream_id, user_id); +CREATE INDEX device_lists_stream_user_id ON device_lists_stream USING btree (user_id, device_id); +CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys USING btree (user_id, keytype, stream_id); +CREATE UNIQUE INDEX e2e_cross_signing_keys_stream_idx ON e2e_cross_signing_keys USING btree (stream_id); +CREATE INDEX e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures USING btree (user_id, target_user_id, target_device_id); +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions USING btree (user_id, version); +CREATE UNIQUE INDEX e2e_room_keys_with_version_idx ON e2e_room_keys USING btree (user_id, version, room_id, session_id); +CREATE UNIQUE INDEX erased_users_user ON erased_users USING btree (user_id); +CREATE INDEX ev_b_extrem_id ON event_backward_extremities USING btree (event_id); +CREATE INDEX ev_b_extrem_room ON event_backward_extremities USING btree (room_id); +CREATE INDEX ev_edges_prev_id ON event_edges USING btree (prev_event_id); +CREATE INDEX ev_extrem_id ON event_forward_extremities USING btree (event_id); +CREATE INDEX ev_extrem_room ON event_forward_extremities USING btree (room_id); +CREATE INDEX evauth_edges_id ON event_auth USING btree (event_id); +CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links USING btree (origin_chain_id, target_chain_id); +CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate USING btree (room_id); +CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains USING btree (chain_id, sequence_number); +CREATE INDEX event_contains_url_index ON events USING btree (room_id, topological_ordering, stream_ordering) WHERE ((contains_url = true) AND (outlier = false)); +CREATE UNIQUE INDEX event_edges_event_id_prev_event_id_idx ON event_edges USING btree (event_id, prev_event_id); +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry USING btree (expiry_ts); +CREATE INDEX event_labels_room_id_label_idx ON event_labels USING btree (room_id, label, topological_ordering); +CREATE INDEX event_push_actions_highlights_index ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering) WHERE (highlight = 1); +CREATE INDEX event_push_actions_rm_tokens ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering); +CREATE INDEX event_push_actions_room_id_user_id ON event_push_actions USING btree (room_id, user_id); +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging USING btree (event_id); +CREATE INDEX event_push_actions_stream_highlight_index ON event_push_actions USING btree (highlight, stream_ordering) WHERE (highlight = 0); +CREATE INDEX event_push_actions_stream_ordering ON event_push_actions USING btree (stream_ordering, user_id); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions USING btree (user_id, stream_ordering); +CREATE UNIQUE INDEX event_push_summary_unique_index ON event_push_summary USING btree (user_id, room_id); +CREATE UNIQUE INDEX event_push_summary_unique_index2 ON event_push_summary USING btree (user_id, room_id, thread_id); +CREATE UNIQUE INDEX event_relations_id ON event_relations USING btree (event_id); +CREATE INDEX event_relations_relates ON event_relations USING btree (relates_to_id, relation_type, aggregation_key); +CREATE INDEX event_search_ev_ridx ON event_search USING btree (room_id); +CREATE UNIQUE INDEX event_search_event_id_idx ON event_search USING btree (event_id); +CREATE INDEX event_search_fts_idx ON event_search USING gin (vector); +CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups USING btree (state_group); +CREATE UNIQUE INDEX event_txn_id_event_id ON event_txn_id USING btree (event_id); +CREATE INDEX event_txn_id_ts ON event_txn_id USING btree (inserted_ts); +CREATE UNIQUE INDEX event_txn_id_txn_id ON event_txn_id USING btree (room_id, user_id, token_id, txn_id); +CREATE INDEX events_order_room ON events USING btree (room_id, topological_ordering, stream_ordering); +CREATE INDEX events_room_stream ON events USING btree (room_id, stream_ordering); +CREATE UNIQUE INDEX events_stream_ordering ON events USING btree (stream_ordering); +CREATE INDEX events_ts ON events USING btree (origin_server_ts, stream_ordering); +CREATE UNIQUE INDEX federation_inbound_events_staging_instance_event ON federation_inbound_events_staging USING btree (origin, event_id); +CREATE INDEX federation_inbound_events_staging_room ON federation_inbound_events_staging USING btree (room_id, received_ts); +CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position USING btree (type, instance_name); +CREATE INDEX ignored_users_ignored_user_id ON ignored_users USING btree (ignored_user_id); +CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users USING btree (ignorer_user_id, ignored_user_id); +CREATE INDEX insertion_event_edges_event_id ON insertion_event_edges USING btree (event_id); +CREATE INDEX insertion_event_edges_insertion_prev_event_id ON insertion_event_edges USING btree (insertion_prev_event_id); +CREATE INDEX insertion_event_edges_insertion_room_id ON insertion_event_edges USING btree (room_id); +CREATE UNIQUE INDEX insertion_event_extremities_event_id ON insertion_event_extremities USING btree (event_id); +CREATE INDEX insertion_event_extremities_room_id ON insertion_event_extremities USING btree (room_id); +CREATE UNIQUE INDEX insertion_events_event_id ON insertion_events USING btree (event_id); +CREATE INDEX insertion_events_next_batch_id ON insertion_events USING btree (next_batch_id); +CREATE UNIQUE INDEX instance_map_idx ON instance_map USING btree (instance_name); +CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership USING btree (user_id, room_id); +CREATE INDEX local_current_membership_room_idx ON local_current_membership USING btree (room_id); +CREATE UNIQUE INDEX local_media_repository_thumbn_media_id_width_height_method_key ON local_media_repository_thumbnails USING btree (media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method); +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails USING btree (media_id); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache USING btree (url, download_ts); +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache USING btree (expires_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache USING btree (media_id); +CREATE INDEX local_media_repository_url_idx ON local_media_repository USING btree (created_ts) WHERE (url_cache IS NOT NULL); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users USING btree ("timestamp"); +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users USING btree (user_id); +CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens USING btree (ts_valid_until_ms); +CREATE INDEX partial_state_events_room_id_idx ON partial_state_events USING btree (room_id); +CREATE INDEX presence_stream_id ON presence_stream USING btree (stream_id, user_id); +CREATE INDEX presence_stream_state_not_offline_idx ON presence_stream USING btree (state) WHERE (state <> 'offline'::text); +CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id); +CREATE INDEX public_room_index ON rooms USING btree (is_public); +CREATE INDEX push_rules_enable_user_name ON push_rules_enable USING btree (user_name); +CREATE INDEX push_rules_stream_id ON push_rules_stream USING btree (stream_id); +CREATE INDEX push_rules_stream_user_stream_id ON push_rules_stream USING btree (user_id, stream_id); +CREATE INDEX push_rules_user_name ON push_rules USING btree (user_name); +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override USING btree (user_id); +CREATE UNIQUE INDEX receipts_graph_unique_index ON receipts_graph USING btree (room_id, receipt_type, user_id) WHERE (thread_id IS NULL); +CREATE INDEX receipts_linearized_id ON receipts_linearized USING btree (stream_id); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized USING btree (room_id, stream_id); +CREATE UNIQUE INDEX receipts_linearized_unique_index ON receipts_linearized USING btree (room_id, receipt_type, user_id) WHERE (thread_id IS NULL); +CREATE INDEX receipts_linearized_user ON receipts_linearized USING btree (user_id); +CREATE INDEX received_transactions_ts ON received_transactions USING btree (ts); +CREATE INDEX redactions_have_censored_ts ON redactions USING btree (received_ts) WHERE (NOT have_censored); +CREATE INDEX redactions_redacts ON redactions USING btree (redacts); +CREATE INDEX refresh_tokens_next_token_id ON refresh_tokens USING btree (next_token_id) WHERE (next_token_id IS NOT NULL); +CREATE UNIQUE INDEX remote_media_repository_thumbn_media_origin_id_width_height_met ON remote_media_cache_thumbnails USING btree (media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method); +CREATE INDEX room_account_data_stream_id ON room_account_data USING btree (user_id, stream_id); +CREATE INDEX room_alias_servers_alias ON room_alias_servers USING btree (room_alias); +CREATE INDEX room_aliases_id ON room_aliases USING btree (room_id); +CREATE INDEX room_memberships_room_id ON room_memberships USING btree (room_id); +CREATE INDEX room_memberships_user_id ON room_memberships USING btree (user_id); +CREATE INDEX room_memberships_user_room_forgotten ON room_memberships USING btree (user_id, room_id) WHERE (forgotten = 1); +CREATE INDEX room_retention_max_lifetime_idx ON room_retention USING btree (max_lifetime); +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token USING btree (room_id); +CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state USING btree (room_id); +CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); +CREATE INDEX stream_ordering_to_exterm_rm_idx ON stream_ordering_to_exterm USING btree (room_id, stream_ordering); +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions USING btree (stream_name, instance_name); +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens USING btree (medium, address); +CREATE INDEX threepid_validation_token_session_id ON threepid_validation_token USING btree (session_id); +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits USING btree ("timestamp"); +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits USING btree (user_id, "timestamp"); +CREATE INDEX user_directory_room_idx ON user_directory USING btree (room_id); +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin (vector); +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search USING btree (user_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory USING btree (user_id); +CREATE INDEX user_external_ids_user_id_idx ON user_external_ids USING btree (user_id); +CREATE UNIQUE INDEX user_filters_unique ON user_filters USING btree (user_id, filter_id); +CREATE INDEX user_ips_device_id ON user_ips USING btree (user_id, device_id, last_seen); +CREATE INDEX user_ips_last_seen ON user_ips USING btree (user_id, last_seen); +CREATE INDEX user_ips_last_seen_only ON user_ips USING btree (last_seen); +CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips USING btree (user_id, access_token, ip); +CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream USING btree (stream_id); +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server USING btree (user_id, medium, address, id_server); +CREATE INDEX user_threepids_medium_address ON user_threepids USING btree (medium, address); +CREATE INDEX user_threepids_user_id ON user_threepids USING btree (user_id); +CREATE INDEX users_creation_ts ON users USING btree (creation_ts); +CREATE INDEX users_have_local_media ON local_media_repository USING btree (user_id, created_ts); +CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms USING btree (room_id); +CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms USING btree (user_id, room_id); +CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms USING btree (other_user_id); +CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms USING btree (room_id); +CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id); +CREATE UNIQUE INDEX worker_locks_key ON worker_locks USING btree (lock_name, lock_key); +CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events FOR EACH ROW EXECUTE PROCEDURE check_partial_state_events(); +ALTER TABLE ONLY access_tokens + ADD CONSTRAINT access_tokens_refresh_token_id_fkey FOREIGN KEY (refresh_token_id) REFERENCES refresh_tokens(id) ON DELETE CASCADE; +ALTER TABLE ONLY destination_rooms + ADD CONSTRAINT destination_rooms_destination_fkey FOREIGN KEY (destination) REFERENCES destinations(destination); +ALTER TABLE ONLY destination_rooms + ADD CONSTRAINT destination_rooms_room_id_fkey FOREIGN KEY (room_id) REFERENCES rooms(room_id); +ALTER TABLE ONLY event_edges + ADD CONSTRAINT event_edges_event_id_fkey FOREIGN KEY (event_id) REFERENCES events(event_id); +ALTER TABLE ONLY event_txn_id + ADD CONSTRAINT event_txn_id_event_id_fkey FOREIGN KEY (event_id) REFERENCES events(event_id) ON DELETE CASCADE; +ALTER TABLE ONLY event_txn_id + ADD CONSTRAINT event_txn_id_token_id_fkey FOREIGN KEY (token_id) REFERENCES access_tokens(id) ON DELETE CASCADE; +ALTER TABLE ONLY partial_state_events + ADD CONSTRAINT partial_state_events_event_id_fkey FOREIGN KEY (event_id) REFERENCES events(event_id); +ALTER TABLE ONLY partial_state_events + ADD CONSTRAINT partial_state_events_room_id_fkey FOREIGN KEY (room_id) REFERENCES partial_state_rooms(room_id); +ALTER TABLE ONLY partial_state_rooms + ADD CONSTRAINT partial_state_rooms_room_id_fkey FOREIGN KEY (room_id) REFERENCES rooms(room_id); +ALTER TABLE ONLY partial_state_rooms_servers + ADD CONSTRAINT partial_state_rooms_servers_room_id_fkey FOREIGN KEY (room_id) REFERENCES partial_state_rooms(room_id); +ALTER TABLE ONLY refresh_tokens + ADD CONSTRAINT refresh_tokens_next_token_id_fkey FOREIGN KEY (next_token_id) REFERENCES refresh_tokens(id) ON DELETE CASCADE; +ALTER TABLE ONLY ui_auth_sessions_credentials + ADD CONSTRAINT ui_auth_sessions_credentials_session_id_fkey FOREIGN KEY (session_id) REFERENCES ui_auth_sessions(session_id); +ALTER TABLE ONLY ui_auth_sessions_ips + ADD CONSTRAINT ui_auth_sessions_ips_session_id_fkey FOREIGN KEY (session_id) REFERENCES ui_auth_sessions(session_id); +ALTER TABLE ONLY users_to_send_full_presence_to + ADD CONSTRAINT users_to_send_full_presence_to_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(name); +INSERT INTO appservice_stream_position VALUES ('X', 0); +INSERT INTO event_push_summary_last_receipt_stream_id VALUES ('X', 0); +INSERT INTO event_push_summary_stream_ordering VALUES ('X', 0); +INSERT INTO federation_stream_position VALUES ('federation', -1, 'master'); +INSERT INTO federation_stream_position VALUES ('events', -1, 'master'); +INSERT INTO stats_incremental_position VALUES ('X', 1); +INSERT INTO user_directory_stream_pos VALUES ('X', 1); +SELECT pg_catalog.setval('account_data_sequence', 1, true); +SELECT pg_catalog.setval('application_services_txn_id_seq', 1, false); +SELECT pg_catalog.setval('cache_invalidation_stream_seq', 1, true); +SELECT pg_catalog.setval('device_inbox_sequence', 1, true); +SELECT pg_catalog.setval('event_auth_chain_id', 1, false); +SELECT pg_catalog.setval('events_backfill_stream_seq', 1, true); +SELECT pg_catalog.setval('events_stream_seq', 1, true); +SELECT pg_catalog.setval('instance_map_instance_id_seq', 1, false); +SELECT pg_catalog.setval('presence_stream_sequence', 1, true); +SELECT pg_catalog.setval('receipts_sequence', 1, true); +SELECT pg_catalog.setval('user_id_seq', 1, false); diff --git a/synapse/storage/schema/main/full_schemas/72/full.sql.sqlite b/synapse/storage/schema/main/full_schemas/72/full.sql.sqlite new file mode 100644 index 0000000000..d403baf1fb --- /dev/null +++ b/synapse/storage/schema/main/full_schemas/72/full.sql.sqlite @@ -0,0 +1,646 @@ +CREATE TABLE application_services_txns( as_id TEXT NOT NULL, txn_id INTEGER NOT NULL, event_ids TEXT NOT NULL, UNIQUE(as_id, txn_id) ); +CREATE INDEX application_services_txns_id ON application_services_txns ( as_id ); +CREATE TABLE presence( user_id TEXT NOT NULL, state VARCHAR(20), status_msg TEXT, mtime BIGINT, UNIQUE (user_id) ); +CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, deactivated SMALLINT DEFAULT 0 NOT NULL, shadow_banned BOOLEAN, consent_ts bigint, UNIQUE(name) ); +CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL ); +CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) ); +CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) ); +CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER , failure_ts BIGINT, last_successful_stream_ordering BIGINT); +CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, instance_name TEXT, state_key TEXT DEFAULT NULL, rejection_reason TEXT DEFAULT NULL, UNIQUE (event_id) ); +CREATE INDEX events_order_room ON events ( room_id, topological_ordering, stream_ordering ); +CREATE TABLE event_json( event_id TEXT NOT NULL, room_id TEXT NOT NULL, internal_metadata TEXT NOT NULL, json TEXT NOT NULL, format_version INTEGER, UNIQUE (event_id) ); +CREATE TABLE state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, prev_state TEXT, UNIQUE (event_id) ); +CREATE TABLE current_state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, membership TEXT, UNIQUE (event_id), UNIQUE (room_id, type, state_key) ); +CREATE TABLE room_memberships( event_id TEXT NOT NULL, user_id TEXT NOT NULL, sender TEXT NOT NULL, room_id TEXT NOT NULL, membership TEXT NOT NULL, forgotten INTEGER DEFAULT 0, display_name TEXT, avatar_url TEXT, UNIQUE (event_id) ); +CREATE INDEX room_memberships_room_id ON room_memberships (room_id); +CREATE INDEX room_memberships_user_id ON room_memberships (user_id); +CREATE TABLE rooms( room_id TEXT PRIMARY KEY NOT NULL, is_public BOOL, creator TEXT , room_version TEXT, has_auth_chain_index BOOLEAN); +CREATE TABLE server_signature_keys( server_name TEXT, key_id TEXT, from_server TEXT, ts_added_ms BIGINT, verify_key bytea, ts_valid_until_ms BIGINT, UNIQUE (server_name, key_id) ); +CREATE TABLE rejections( event_id TEXT NOT NULL, reason TEXT NOT NULL, last_check TEXT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE push_rules ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, priority_class SMALLINT NOT NULL, priority INTEGER NOT NULL DEFAULT 0, conditions TEXT NOT NULL, actions TEXT NOT NULL, UNIQUE(user_name, rule_id) ); +CREATE INDEX push_rules_user_name on push_rules (user_name); +CREATE TABLE push_rules_enable ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, enabled SMALLINT, UNIQUE(user_name, rule_id) ); +CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); +CREATE TABLE event_forward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); +CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); +CREATE TABLE event_backward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); +CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); +CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); +CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0, UNIQUE (media_id) ); +CREATE TABLE remote_media_cache ( media_origin TEXT, media_id TEXT, media_type TEXT, created_ts BIGINT, upload_name TEXT, media_length INTEGER, filesystem_id TEXT, last_access_ts BIGINT, quarantined_by TEXT, UNIQUE (media_origin, media_id) ); +CREATE TABLE redactions ( event_id TEXT NOT NULL, redacts TEXT NOT NULL, have_censored BOOL NOT NULL DEFAULT false, received_ts BIGINT, UNIQUE (event_id) ); +CREATE INDEX redactions_redacts ON redactions (redacts); +CREATE TABLE room_aliases( room_alias TEXT NOT NULL, room_id TEXT NOT NULL, creator TEXT, UNIQUE (room_alias) ); +CREATE INDEX room_aliases_id ON room_aliases(room_id); +CREATE TABLE room_alias_servers( room_alias TEXT NOT NULL, server TEXT NOT NULL ); +CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); +CREATE TABLE IF NOT EXISTS "server_keys_json" ( server_name TEXT NOT NULL, key_id TEXT NOT NULL, from_server TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, ts_valid_until_ms BIGINT NOT NULL, key_json bytea NOT NULL, CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) ); +CREATE TABLE e2e_device_keys_json ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) ); +CREATE TABLE e2e_one_time_keys_json ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, algorithm TEXT NOT NULL, key_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) ); +CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) ); +CREATE INDEX user_threepids_user_id ON user_threepids(user_id); +CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) +/* event_search(event_id,room_id,sender,"key",value) */; +CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) ); +CREATE TABLE room_tags_revisions ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, stream_id BIGINT NOT NULL, instance_name TEXT, CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) ); +CREATE TABLE account_data( user_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, instance_name TEXT, CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) ); +CREATE TABLE room_account_data( user_id TEXT NOT NULL, room_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, instance_name TEXT, CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) ); +CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); +CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); +CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering); +CREATE TABLE event_push_actions( room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, profile_tag VARCHAR(32), actions TEXT NOT NULL, topological_ordering BIGINT, stream_ordering BIGINT, notif SMALLINT, highlight SMALLINT, unread SMALLINT, thread_id TEXT, CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) ); +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); +CREATE INDEX events_room_stream on events(room_id, stream_ordering); +CREATE INDEX public_room_index on rooms(is_public); +CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); +CREATE TABLE presence_stream( stream_id BIGINT, user_id TEXT, state TEXT, last_active_ts BIGINT, last_federation_update_ts BIGINT, last_user_sync_ts BIGINT, status_msg TEXT, currently_active BOOLEAN , instance_name TEXT); +CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); +CREATE INDEX presence_stream_user_id ON presence_stream(user_id); +CREATE TABLE push_rules_stream( stream_id BIGINT NOT NULL, event_stream_ordering BIGINT NOT NULL, user_id TEXT NOT NULL, rule_id TEXT NOT NULL, op TEXT NOT NULL, priority_class SMALLINT, priority INTEGER, conditions TEXT, actions TEXT ); +CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); +CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); +CREATE TABLE ex_outlier_stream( event_stream_ordering BIGINT PRIMARY KEY NOT NULL, event_id TEXT NOT NULL, state_group BIGINT NOT NULL , instance_name TEXT); +CREATE TABLE threepid_guest_access_tokens( medium TEXT, address TEXT, guest_access_token TEXT, first_inviter TEXT ); +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); +CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id ); +CREATE TABLE open_id_tokens ( token TEXT NOT NULL PRIMARY KEY, ts_valid_until_ms bigint NOT NULL, user_id TEXT NOT NULL, UNIQUE (token) ); +CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); +CREATE TABLE pusher_throttle( pusher BIGINT NOT NULL, room_id TEXT NOT NULL, last_sent_ts BIGINT, throttle_ms BIGINT, PRIMARY KEY (pusher, room_id) ); +CREATE TABLE event_reports( id BIGINT NOT NULL PRIMARY KEY, received_ts BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, reason TEXT, content TEXT ); +CREATE TABLE appservice_stream_position( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT, CHECK (Lock='X') ); +CREATE TABLE device_inbox ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, stream_id BIGINT NOT NULL, message_json TEXT NOT NULL , instance_name TEXT); +CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); +CREATE INDEX received_transactions_ts ON received_transactions(ts); +CREATE TABLE device_federation_outbox ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, queued_ts BIGINT NOT NULL, messages_json TEXT NOT NULL , instance_name TEXT); +CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox(destination, stream_id); +CREATE TABLE device_federation_inbox ( origin TEXT NOT NULL, message_id TEXT NOT NULL, received_ts BIGINT NOT NULL , instance_name TEXT); +CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox(origin, message_id); +CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); +CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); +CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); +CREATE TABLE IF NOT EXISTS "event_auth"( event_id TEXT NOT NULL, auth_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE INDEX evauth_edges_id ON event_auth(event_id); +CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); +CREATE TABLE appservice_room_list( appservice_id TEXT NOT NULL, network_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( appservice_id, network_id, room_id ); +CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); +CREATE TABLE federation_stream_position( type TEXT NOT NULL, stream_id INTEGER NOT NULL , instance_name TEXT NOT NULL DEFAULT 'master'); +CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, content TEXT NOT NULL ); +CREATE TABLE device_lists_remote_extremeties ( user_id TEXT NOT NULL, stream_id TEXT NOT NULL ); +CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL ); +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); +CREATE TABLE device_lists_outbound_pokes ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL, sent BOOLEAN NOT NULL, ts BIGINT NOT NULL , opentracing_context TEXT); +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); +CREATE TABLE event_push_summary ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, notif_count BIGINT NOT NULL, stream_ordering BIGINT NOT NULL , unread_count BIGINT, last_receipt_stream_ordering BIGINT, thread_id TEXT); +CREATE TABLE event_push_summary_stream_ordering ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT NOT NULL, CHECK (Lock='X') ); +CREATE TABLE IF NOT EXISTS "pushers" ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, access_token BIGINT DEFAULT NULL, profile_tag TEXT NOT NULL, kind TEXT NOT NULL, app_id TEXT NOT NULL, app_display_name TEXT NOT NULL, device_display_name TEXT NOT NULL, pushkey TEXT NOT NULL, ts BIGINT NOT NULL, lang TEXT, data TEXT, last_stream_ordering INTEGER, last_success BIGINT, failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ); +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); +CREATE TABLE ratelimit_override ( user_id TEXT NOT NULL, messages_per_second BIGINT, burst_count BIGINT ); +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); +CREATE TABLE current_state_delta_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT, prev_event_id TEXT , instance_name TEXT); +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); +CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); +CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value ) +/* user_directory_search(user_id,value) */; +CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); +CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT ); +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); +CREATE TABLE IF NOT EXISTS "deleted_pushers" ( stream_id BIGINT NOT NULL, app_id TEXT NOT NULL, pushkey TEXT NOT NULL, user_id TEXT NOT NULL ); +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); +CREATE TABLE IF NOT EXISTS "user_directory" ( user_id TEXT NOT NULL, room_id TEXT, display_name TEXT, avatar_url TEXT ); +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); +CREATE TABLE event_push_actions_staging ( event_id TEXT NOT NULL, user_id TEXT NOT NULL, actions TEXT NOT NULL, notif SMALLINT NOT NULL, highlight SMALLINT NOT NULL , unread SMALLINT, thread_id TEXT); +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); +CREATE TABLE users_pending_deactivation ( user_id TEXT NOT NULL ); +CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL , user_agent TEXT); +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); +CREATE TABLE erased_users ( user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); +CREATE TABLE monthly_active_users ( user_id TEXT NOT NULL, timestamp BIGINT NOT NULL ); +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); +CREATE TABLE IF NOT EXISTS "e2e_room_keys_versions" ( user_id TEXT NOT NULL, version BIGINT NOT NULL, algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL , etag BIGINT); +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); +CREATE TABLE IF NOT EXISTS "e2e_room_keys" ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, version BIGINT NOT NULL, first_message_index INT, forwarded_count INT, is_verified BOOLEAN, session_data TEXT NOT NULL ); +CREATE TABLE users_who_share_private_rooms ( user_id TEXT NOT NULL, other_user_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); +CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); +CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); +CREATE TABLE user_threepid_id_server ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, id_server TEXT NOT NULL ); +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( user_id, medium, address, id_server ); +CREATE TABLE users_in_public_rooms ( user_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); +CREATE TABLE account_validity ( user_id TEXT PRIMARY KEY, expiration_ts_ms BIGINT NOT NULL, email_sent BOOLEAN NOT NULL, renewal_token TEXT , token_used_ts_ms BIGINT); +CREATE TABLE event_relations ( event_id TEXT NOT NULL, relates_to_id TEXT NOT NULL, relation_type TEXT NOT NULL, aggregation_key TEXT ); +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); +CREATE TABLE room_stats_earliest_token ( room_id TEXT NOT NULL, token BIGINT NOT NULL ); +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); +CREATE INDEX user_ips_device_id ON user_ips (user_id, device_id, last_seen); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering); +CREATE INDEX device_inbox_stream_id_user_id ON device_inbox (stream_id, user_id); +CREATE INDEX device_lists_stream_user_id ON device_lists_stream (user_id, device_id); +CREATE INDEX user_ips_last_seen ON user_ips (user_id, last_seen); +CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); +CREATE INDEX users_creation_ts ON users (creation_ts); +CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); +CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); +CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); +CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); +CREATE TABLE threepid_validation_session ( + session_id TEXT PRIMARY KEY, + medium TEXT NOT NULL, + address TEXT NOT NULL, + client_secret TEXT NOT NULL, + last_send_attempt BIGINT NOT NULL, + validated_at BIGINT +); +CREATE TABLE threepid_validation_token ( + token TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + next_link TEXT, + expires BIGINT NOT NULL +); +CREATE INDEX threepid_validation_token_session_id ON threepid_validation_token(session_id); +CREATE TABLE event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); +CREATE TABLE event_labels ( + event_id TEXT, + label TEXT, + room_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + PRIMARY KEY(event_id, label) +); +CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering); +CREATE UNIQUE INDEX e2e_room_keys_with_version_idx ON e2e_room_keys(user_id, version, room_id, session_id); +CREATE TABLE IF NOT EXISTS "devices" ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + last_seen BIGINT, + ip TEXT, + user_agent TEXT, + hidden BOOLEAN DEFAULT 0, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); +CREATE TABLE room_retention( + room_id TEXT, + event_id TEXT, + min_lifetime BIGINT, + max_lifetime BIGINT, + + PRIMARY KEY(room_id, event_id) +); +CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime); +CREATE TABLE e2e_cross_signing_keys ( + user_id TEXT NOT NULL, + -- the type of cross-signing key (master, user_signing, or self_signing) + keytype TEXT NOT NULL, + -- the full key information, as a json-encoded dict + keydata TEXT NOT NULL, + -- for keeping the keys in order, so that we can fetch the latest one + stream_id BIGINT NOT NULL +); +CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, stream_id); +CREATE TABLE e2e_cross_signing_signatures ( + -- user who did the signing + user_id TEXT NOT NULL, + -- key used to sign + key_id TEXT NOT NULL, + -- user who was signed + target_user_id TEXT NOT NULL, + -- device/key that was signed + target_device_id TEXT NOT NULL, + -- the actual signature + signature TEXT NOT NULL +); +CREATE TABLE user_signature_stream ( + -- uses the same stream ID as device list stream + stream_id BIGINT NOT NULL, + -- user who did the signing + from_user_id TEXT NOT NULL, + -- list of users who were signed, as a JSON array + user_ids TEXT NOT NULL +); +CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id); +CREATE INDEX e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); +CREATE TABLE stats_incremental_position ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); +CREATE TABLE room_stats_current ( + room_id TEXT NOT NULL PRIMARY KEY, + + -- These are absolute counts + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + + local_users_in_room INT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +, knocked_members INT); +CREATE TABLE user_stats_current ( + user_id TEXT NOT NULL PRIMARY KEY, + + joined_rooms BIGINT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); +CREATE TABLE room_stats_state ( + room_id TEXT NOT NULL, + name TEXT, + canonical_alias TEXT, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + avatar TEXT, + guest_access TEXT, + is_federatable BOOLEAN, + topic TEXT +, room_type TEXT); +CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id); +CREATE TABLE IF NOT EXISTS "user_filters" ( user_id TEXT NOT NULL, filter_id BIGINT NOT NULL, filter_json BYTEA NOT NULL ); +CREATE UNIQUE INDEX user_filters_unique ON "user_filters" (user_id, filter_id); +CREATE TABLE user_external_ids ( + auth_provider TEXT NOT NULL, + external_id TEXT NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (auth_provider, external_id) +); +CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms(room_id); +CREATE TABLE device_lists_remote_resync ( + user_id TEXT NOT NULL, + added_ts BIGINT NOT NULL +); +CREATE UNIQUE INDEX device_lists_remote_resync_idx ON device_lists_remote_resync (user_id); +CREATE INDEX device_lists_remote_resync_ts_idx ON device_lists_remote_resync (added_ts); +CREATE TABLE local_current_membership ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + membership TEXT NOT NULL + ); +CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id); +CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id); +CREATE TABLE ui_auth_sessions( + session_id TEXT NOT NULL, -- The session ID passed to the client. + creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds). + serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse. + clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client. + uri TEXT NOT NULL, -- The URI the UI authentication session is using. + method TEXT NOT NULL, -- The HTTP method the UI authentication session is using. + -- The clientdict, uri, and method make up an tuple that must be immutable + -- throughout the lifetime of the UI Auth session. + description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur. + UNIQUE (session_id) +); +CREATE TABLE ui_auth_sessions_credentials( + session_id TEXT NOT NULL, -- The corresponding UI Auth session. + stage_type TEXT NOT NULL, -- The stage type. + result TEXT NOT NULL, -- The result of the stage verification, stored as JSON. + UNIQUE (session_id, stage_type), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); +CREATE TABLE IF NOT EXISTS "device_lists_outbound_last_success" ( destination TEXT NOT NULL, user_id TEXT NOT NULL, stream_id BIGINT NOT NULL ); +CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx ON "device_lists_outbound_last_success" (destination, user_id); +CREATE TABLE IF NOT EXISTS "local_media_repository_thumbnails" ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) ); +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id); +CREATE TABLE IF NOT EXISTS "remote_media_cache_thumbnails" ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) ); +CREATE TABLE ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); +CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name); +CREATE TABLE dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL -- JSON-encoded client-defined data +); +CREATE TABLE e2e_fallback_keys_json ( + user_id TEXT NOT NULL, -- The user this fallback key is for. + device_id TEXT NOT NULL, -- The device this fallback key is for. + algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + key_json TEXT NOT NULL, -- The key as a JSON blob. + used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not. + CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) +); +CREATE TABLE destination_rooms ( + -- the destination in question. + destination TEXT NOT NULL REFERENCES destinations (destination), + -- the ID of the room in question + room_id TEXT NOT NULL REFERENCES rooms (room_id), + -- the stream_ordering of the event + stream_ordering BIGINT NOT NULL, + PRIMARY KEY (destination, room_id) + -- We don't declare a foreign key on stream_ordering here because that'd mean + -- we'd need to either maintain an index (expensive) or do a table scan of + -- destination_rooms whenever we delete an event (also potentially expensive). + -- In addition to that, a foreign key on stream_ordering would be redundant + -- as this row doesn't need to refer to a specific event; if the event gets + -- deleted then it doesn't affect the validity of the stream_ordering here. +); +CREATE INDEX destination_rooms_room_id + ON destination_rooms (room_id); +CREATE TABLE stream_positions ( + stream_name TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_id BIGINT NOT NULL +); +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name); +CREATE TABLE IF NOT EXISTS "access_tokens" ( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + valid_until_ms BIGINT, + puppets_user_id TEXT, + last_validated BIGINT, refresh_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE, used BOOLEAN, + UNIQUE(token) +); +CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id); +CREATE TABLE IF NOT EXISTS "event_txn_id" ( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + token_id BIGINT NOT NULL, + txn_id TEXT NOT NULL, + inserted_ts BIGINT NOT NULL, + FOREIGN KEY (event_id) + REFERENCES events (event_id) ON DELETE CASCADE, + FOREIGN KEY (token_id) + REFERENCES access_tokens (id) ON DELETE CASCADE +); +CREATE UNIQUE INDEX event_txn_id_event_id ON event_txn_id(event_id); +CREATE UNIQUE INDEX event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id); +CREATE INDEX event_txn_id_ts ON event_txn_id(inserted_ts); +CREATE TABLE ignored_users( ignorer_user_id TEXT NOT NULL, ignored_user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users (ignorer_user_id, ignored_user_id); +CREATE INDEX ignored_users_ignored_user_id ON ignored_users (ignored_user_id); +CREATE TABLE event_auth_chains ( + event_id TEXT PRIMARY KEY, + chain_id BIGINT NOT NULL, + sequence_number BIGINT NOT NULL +); +CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number); +CREATE TABLE event_auth_chain_links ( + origin_chain_id BIGINT NOT NULL, + origin_sequence_number BIGINT NOT NULL, + + target_chain_id BIGINT NOT NULL, + target_sequence_number BIGINT NOT NULL +); +CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id); +CREATE TABLE event_auth_chain_to_calculate ( + event_id TEXT PRIMARY KEY, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL +); +CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id); +CREATE TABLE users_to_send_full_presence_to( + -- The user ID to send full presence to. + user_id TEXT PRIMARY KEY, + -- A presence stream ID token - the current presence stream token when the row was last upserted. + -- If a user calls /sync and this token is part of the update they're to receive, we also include + -- full user presence in the response. + -- This allows multiple devices for a user to receive full presence whenever they next call /sync. + presence_stream_id BIGINT, + FOREIGN KEY (user_id) + REFERENCES users (name) +); +CREATE TABLE refresh_tokens ( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + token TEXT NOT NULL, + -- When consumed, a new refresh token is generated, which is tracked by + -- this foreign key + next_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE, expiry_ts BIGINT DEFAULT NULL, ultimate_session_expiry_ts BIGINT DEFAULT NULL, + UNIQUE(token) +); +CREATE TABLE worker_locks ( + lock_name TEXT NOT NULL, + lock_key TEXT NOT NULL, + -- We write the instance name to ease manual debugging, we don't ever read + -- from it. + -- Note: instance names aren't guarenteed to be unique. + instance_name TEXT NOT NULL, + -- A random string generated each time an instance takes out a lock. Used by + -- the instance to tell whether the lock is still held by it (e.g. in the + -- case where the process stalls for a long time the lock may time out and + -- be taken out by another instance, at which point the original instance + -- can tell it no longer holds the lock as the tokens no longer match). + token TEXT NOT NULL, + last_renewed_ts BIGINT NOT NULL +); +CREATE UNIQUE INDEX worker_locks_key ON worker_locks (lock_name, lock_key); +CREATE TABLE federation_inbound_events_staging ( + origin TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + received_ts BIGINT NOT NULL, + event_json TEXT NOT NULL, + internal_metadata TEXT NOT NULL +); +CREATE INDEX federation_inbound_events_staging_room ON federation_inbound_events_staging(room_id, received_ts); +CREATE UNIQUE INDEX federation_inbound_events_staging_instance_event ON federation_inbound_events_staging(origin, event_id); +CREATE TABLE insertion_event_edges( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + insertion_prev_event_id TEXT NOT NULL +); +CREATE INDEX insertion_event_edges_insertion_room_id ON insertion_event_edges(room_id); +CREATE INDEX insertion_event_edges_insertion_prev_event_id ON insertion_event_edges(insertion_prev_event_id); +CREATE TABLE insertion_event_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL +); +CREATE UNIQUE INDEX insertion_event_extremities_event_id ON insertion_event_extremities(event_id); +CREATE INDEX insertion_event_extremities_room_id ON insertion_event_extremities(room_id); +CREATE TABLE registration_tokens( + token TEXT NOT NULL, -- The token that can be used for authentication. + uses_allowed INT, -- The total number of times this token can be used. NULL if no limit. + pending INT NOT NULL, -- The number of in progress registrations using this token. + completed INT NOT NULL, -- The number of times this token has been used to complete a registration. + expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire. + UNIQUE (token) +); +CREATE TABLE sessions( + session_type TEXT NOT NULL, -- The unique key for this type of session. + session_id TEXT NOT NULL, -- The session ID passed to the client. + value TEXT NOT NULL, -- A JSON dictionary to persist. + expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds). + UNIQUE (session_type, session_id) +); +CREATE TABLE insertion_events( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + next_batch_id TEXT NOT NULL +); +CREATE UNIQUE INDEX insertion_events_event_id ON insertion_events(event_id); +CREATE INDEX insertion_events_next_batch_id ON insertion_events(next_batch_id); +CREATE TABLE batch_events( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + batch_id TEXT NOT NULL +); +CREATE UNIQUE INDEX batch_events_event_id ON batch_events(event_id); +CREATE INDEX batch_events_batch_id ON batch_events(batch_id); +CREATE INDEX insertion_event_edges_event_id ON insertion_event_edges(event_id); +CREATE TABLE device_auth_providers ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + auth_provider_id TEXT NOT NULL, + auth_provider_session_id TEXT NOT NULL +); +CREATE INDEX device_auth_providers_devices + ON device_auth_providers (user_id, device_id); +CREATE INDEX device_auth_providers_sessions + ON device_auth_providers (auth_provider_id, auth_provider_session_id); +CREATE INDEX refresh_tokens_next_token_id + ON refresh_tokens(next_token_id) + WHERE next_token_id IS NOT NULL; +CREATE TABLE partial_state_rooms ( + room_id TEXT PRIMARY KEY, + FOREIGN KEY(room_id) REFERENCES rooms(room_id) +); +CREATE TABLE partial_state_rooms_servers ( + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + server_name TEXT NOT NULL, + UNIQUE(room_id, server_name) +); +CREATE TABLE partial_state_events ( + -- the room_id is denormalised for efficient indexing (the canonical source is `events`) + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + event_id TEXT NOT NULL REFERENCES events(event_id), + UNIQUE(event_id) +); +CREATE INDEX partial_state_events_room_id_idx + ON partial_state_events (room_id); +CREATE TRIGGER partial_state_events_bad_room_id + BEFORE INSERT ON partial_state_events + FOR EACH ROW + BEGIN + SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events') + WHERE EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ); + END; +CREATE TABLE device_lists_changes_in_room ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + room_id TEXT NOT NULL, + + -- This initially matches `device_lists_stream.stream_id`. Note that we + -- delete older values from `device_lists_stream`, so we can't use a foreign + -- constraint here. + -- + -- The table will contain rows with the same `stream_id` but different + -- `room_id`, as for each device update we store a row per room the user is + -- joined to. Therefore `(stream_id, room_id)` gives a unique index. + stream_id BIGINT NOT NULL, + + -- We have a background process which goes through this table and converts + -- entries into rows in `device_lists_outbound_pokes`. Once we have processed + -- a row, we mark it as such by setting `converted_to_destinations=TRUE`. + converted_to_destinations BOOLEAN NOT NULL, + opentracing_context TEXT +); +CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id); +CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations; +CREATE TABLE IF NOT EXISTS "event_edges" ( + event_id TEXT NOT NULL, + prev_event_id TEXT NOT NULL, + room_id TEXT NULL, + is_state BOOL NOT NULL DEFAULT 0, + FOREIGN KEY(event_id) REFERENCES events(event_id) +); +CREATE UNIQUE INDEX event_edges_event_id_prev_event_id_idx + ON event_edges (event_id, prev_event_id); +CREATE INDEX ev_edges_prev_id ON event_edges (prev_event_id); +CREATE TABLE event_push_summary_last_receipt_stream_id ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); +CREATE TABLE IF NOT EXISTS "application_services_state" ( + as_id TEXT PRIMARY KEY NOT NULL, + state VARCHAR(5), + read_receipt_stream_id BIGINT, + presence_stream_id BIGINT, + to_device_stream_id BIGINT, + device_list_stream_id BIGINT +); +CREATE TABLE IF NOT EXISTS "receipts_linearized" ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + thread_id TEXT, + event_stream_ordering BIGINT, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id), + CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); +CREATE TABLE IF NOT EXISTS "receipts_graph" ( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + thread_id TEXT, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id), + CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); +CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); +CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); +CREATE INDEX redactions_have_censored_ts ON redactions (received_ts) WHERE NOT have_censored; +CREATE INDEX room_memberships_user_room_forgotten ON room_memberships (user_id, room_id) WHERE forgotten = 1; +CREATE INDEX users_have_local_media ON local_media_repository (user_id, created_ts) ; +CREATE UNIQUE INDEX e2e_cross_signing_keys_stream_idx ON e2e_cross_signing_keys (stream_id) ; +CREATE INDEX user_external_ids_user_id_idx ON user_external_ids (user_id) ; +CREATE INDEX presence_stream_state_not_offline_idx ON presence_stream (state) WHERE state != 'offline'; +CREATE UNIQUE INDEX event_push_summary_unique_index ON event_push_summary (user_id, room_id) ; +CREATE UNIQUE INDEX event_push_summary_unique_index2 ON event_push_summary (user_id, room_id, thread_id) ; +CREATE UNIQUE INDEX receipts_graph_unique_index ON receipts_graph (room_id, receipt_type, user_id) WHERE thread_id IS NULL; +CREATE UNIQUE INDEX receipts_linearized_unique_index ON receipts_linearized (room_id, receipt_type, user_id) WHERE thread_id IS NULL; +CREATE INDEX event_push_actions_stream_highlight_index ON event_push_actions (highlight, stream_ordering) WHERE highlight=0; +CREATE INDEX current_state_events_member_index ON current_state_events (state_key) WHERE type='m.room.member'; +CREATE INDEX event_contains_url_index ON events (room_id, topological_ordering, stream_ordering) WHERE contains_url = true AND outlier = false; +CREATE INDEX event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering) WHERE highlight=1; +CREATE INDEX local_media_repository_url_idx ON local_media_repository (created_ts) WHERE url_cache IS NOT NULL; +INSERT INTO appservice_stream_position VALUES('X',0); +INSERT INTO federation_stream_position VALUES('federation',-1,'master'); +INSERT INTO federation_stream_position VALUES('events',-1,'master'); +INSERT INTO event_push_summary_stream_ordering VALUES('X',0); +INSERT INTO user_directory_stream_pos VALUES('X',1); +INSERT INTO stats_incremental_position VALUES('X',1); +INSERT INTO event_push_summary_last_receipt_stream_id VALUES('X',0); diff --git a/synapse/storage/schema/state/full_schemas/72/full.sql.postgres b/synapse/storage/schema/state/full_schemas/72/full.sql.postgres new file mode 100644 index 0000000000..263ade761e --- /dev/null +++ b/synapse/storage/schema/state/full_schemas/72/full.sql.postgres @@ -0,0 +1,30 @@ +CREATE TABLE state_group_edges ( + state_group bigint NOT NULL, + prev_state_group bigint NOT NULL +); +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; +CREATE TABLE state_groups ( + id bigint NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL +); +CREATE TABLE state_groups_state ( + state_group bigint NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + event_id text NOT NULL +); +ALTER TABLE ONLY state_groups_state ALTER COLUMN state_group SET (n_distinct=-0.02); +ALTER TABLE ONLY state_groups + ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); +CREATE UNIQUE INDEX state_group_edges_unique_idx ON state_group_edges USING btree (state_group, prev_state_group); +CREATE INDEX state_groups_room_id_idx ON state_groups USING btree (room_id); +CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); +SELECT pg_catalog.setval('state_group_id_seq', 1, false); diff --git a/synapse/storage/schema/state/full_schemas/72/full.sql.sqlite b/synapse/storage/schema/state/full_schemas/72/full.sql.sqlite new file mode 100644 index 0000000000..dda060b638 --- /dev/null +++ b/synapse/storage/schema/state/full_schemas/72/full.sql.sqlite @@ -0,0 +1,20 @@ +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); +CREATE INDEX state_groups_room_id_idx ON state_groups (room_id) ; +CREATE UNIQUE INDEX state_group_edges_unique_idx ON state_group_edges (state_group, prev_state_group) ; -- cgit 1.5.1 From 2fae1a3f7862bf38cd0b52dfd3ea3ae76794d2b7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 26 Sep 2022 14:28:12 -0400 Subject: Improve tests for get_unread_push_actions_for_user_in_range_*. (#13893) * Adds a docstring. * Reduces a small amount of duplicated code. * Improves tests. --- changelog.d/13893.feature | 1 + .../storage/databases/main/event_push_actions.py | 38 ++++++---- tests/storage/test_event_push_actions.py | 88 ++++++++++++++++++---- 3 files changed, 97 insertions(+), 30 deletions(-) create mode 100644 changelog.d/13893.feature (limited to 'synapse') diff --git a/changelog.d/13893.feature b/changelog.d/13893.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13893.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 6b8668d2dc..f4cdc2e399 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -559,7 +559,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_receipts_by_room_txn( self, txn: LoggingTransaction, user_id: str - ) -> List[Tuple[str, int]]: + ) -> Dict[str, int]: + """ + Generate a map of room ID to the latest stream ordering that has been + read by the given user. + + Args: + txn: + user_id: The user to fetch receipts for. + + Returns: + A map of room ID to stream ordering for all rooms the user has a receipt in. + """ receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, "receipt_type", @@ -580,7 +591,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas args.extend((user_id,)) txn.execute(sql, args) - return cast(List[Tuple[str, int]], txn.fetchall()) + return { + room_id: latest_stream_ordering + for room_id, latest_stream_ordering in txn.fetchall() + } async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -605,12 +619,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = dict( - await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ), + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, ) def get_push_actions_txn( @@ -679,12 +691,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = dict( - await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ), + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, ) def get_push_actions_txn( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 08c74b93e3..473c965e19 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin @@ -22,8 +24,6 @@ from synapse.util import Clock from tests.unittest import HomeserverTestCase -USER_ID = "@user:example.com" - class EventPushActionsStoreTestCase(HomeserverTestCase): servlets = [ @@ -38,21 +38,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): assert persist_events_store is not None self.persist_events_store = persist_events_store - def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None: - self.get_success( - self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 - ) - ) + def _create_users_and_room(self) -> Tuple[str, str, str, str, str]: + """ + Creates two users and a shared room. - def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None: - self.get_success( - self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 - ) - ) - - def test_count_aggregation(self) -> None: + Returns: + Tuple of (user 1 ID, user 1 token, user 2 ID, user 2 token, room ID). + """ # Create a user to receive notifications and send receipts. user_id = self.register_user("user1235", "pass") token = self.login("user1235", "pass") @@ -65,6 +57,70 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): room_id = self.helper.create_room_as(user_id, tok=token) self.helper.join(room_id, other_id, tok=other_token) + return user_id, token, other_id, other_token, room_id + + def test_get_unread_push_actions_for_user_in_range(self) -> None: + """Test getting unread push actions for HTTP and email pushers.""" + user_id, token, _, other_token, room_id = self._create_users_and_room() + + # Create two events, one of which is a highlight. + self.helper.send_event( + room_id, + type="m.room.message", + content={"msgtype": "m.text", "body": "msg"}, + tok=other_token, + ) + event_id = self.helper.send_event( + room_id, + type="m.room.message", + content={"msgtype": "m.text", "body": user_id}, + tok=other_token, + )["event_id"] + + # Fetch unread actions for HTTP pushers. + http_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_http( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual(2, len(http_actions)) + + # Fetch unread actions for email pushers. + email_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_email( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual(2, len(email_actions)) + + # Send a receipt, which should clear any actions. + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=None, + data={}, + ) + ) + http_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_http( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual([], http_actions) + email_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_email( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual([], email_actions) + + def test_count_aggregation(self) -> None: + # Create a user to receive notifications and send receipts. + user_id, token, _, other_token, room_id = self._create_users_and_room() + last_event_id: str def _assert_counts(noitf_count: int, highlight_count: int) -> None: -- cgit 1.5.1 From d6b85a2a7dea2737e69d67842c2246975ec64bce Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 26 Sep 2022 23:07:02 +0100 Subject: Complement image: propagate SIGTERM to all workers (#13914) This should mean that logs from worker processes are flushed before shutdown. When a test completes, Complement stops the docker container, which means that synapse will receive a SIGTERM. Currently, the `complement_fork_starter` exits immediately (without notifying the worker processes), which means that the workers never get a chance to flush their logs before the whole container is vaped. We can fix this by propagating the SIGTERM to the children. --- changelog.d/13914.misc | 1 + synapse/app/complement_fork_starter.py | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13914.misc (limited to 'synapse') diff --git a/changelog.d/13914.misc b/changelog.d/13914.misc new file mode 100644 index 0000000000..c29bc25d38 --- /dev/null +++ b/changelog.d/13914.misc @@ -0,0 +1 @@ +Complement image: propagate SIGTERM to all workers. diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index 89eb07df27..b22f315453 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -51,11 +51,18 @@ import argparse import importlib import itertools import multiprocessing +import os +import signal import sys -from typing import Any, Callable, List +from types import FrameType +from typing import Any, Callable, List, Optional from twisted.internet.main import installReactor +# a list of the original signal handlers, before we installed our custom ones. +# We restore these in our child processes. +_original_signal_handlers: dict[int, Any] = {} + class ProxiedReactor: """ @@ -105,6 +112,11 @@ def _worker_entrypoint( sys.argv = args + # reset the custom signal handlers that we installed, so that the children start + # from a clean slate. + for sig, handler in _original_signal_handlers.items(): + signal.signal(sig, handler) + from twisted.internet.epollreactor import EPollReactor proxy_reactor._install_real_reactor(EPollReactor()) @@ -167,13 +179,29 @@ def main() -> None: update_proc.join() print("===== PREPARED DATABASE =====", file=sys.stderr) + processes: List[multiprocessing.Process] = [] + + # Install signal handlers to propagate signals to all our children, so that they + # shut down cleanly. This also inhibits our own exit, but that's good: we want to + # wait until the children have exited. + def handle_signal(signum: int, frame: Optional[FrameType]) -> None: + print( + f"complement_fork_starter: Caught signal {signum}. Stopping children.", + file=sys.stderr, + ) + for p in processes: + if p.pid: + os.kill(p.pid, signum) + + for sig in (signal.SIGINT, signal.SIGTERM): + _original_signal_handlers[sig] = signal.signal(sig, handle_signal) + # At this point, we've imported all the main entrypoints for all the workers. # Now we basically just fork() out to create the workers we need. # Because we're using fork(), all the workers get a clone of this launcher's # memory space and don't need to repeat the work of loading the code! # Instead of using fork() directly, we use the multiprocessing library, # which uses fork() on Unix platforms. - processes = [] for (func, worker_args) in zip(worker_functions, args_by_worker): process = multiprocessing.Process( target=_worker_entrypoint, args=(func, proxy_reactor, worker_args) -- cgit 1.5.1 From 85e161631a2ca7d495b619456221311ec1c93096 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 27 Sep 2022 11:17:23 +0100 Subject: Faster room joins: Fix spurious error when joining a room (#13872) During a `lazy_load_members` `/sync`, we look through auth events in rooms with partial state to find prior membership events. When such a membership is not found, an error is logged. Since the first join event for a user never has a prior membership event to cite, the error would always be logged when one appeared in the room timeline. Avoid logging errors for such events. Introduced in #13477. Signed-off-by: Sean Quah --- changelog.d/13872.bugfix | 1 + synapse/handlers/sync.py | 22 +++++++++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 changelog.d/13872.bugfix (limited to 'synapse') diff --git a/changelog.d/13872.bugfix b/changelog.d/13872.bugfix new file mode 100644 index 0000000000..67d3d9e643 --- /dev/null +++ b/changelog.d/13872.bugfix @@ -0,0 +1 @@ +Faster room joins: Fix a bug introduced in 1.66.0 where an error would be logged when syncing after joining a room. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5293fa4d0e..e75fc6b947 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1191,7 +1191,9 @@ class SyncHandler: room_id: The partial state room to find the remaining memberships for. members_to_fetch: The memberships to find. events_with_membership_auth: A mapping from user IDs to events whose auth - events are known to contain their membership. + events would contain their prior membership, if one exists. + Note that join events will not cite a prior membership if a user has + never been in a room before. found_state_ids: A dict from (type, state_key) -> state_event_id, containing memberships that have been previously found. Entries in `members_to_fetch` that have a membership in `found_state_ids` are @@ -1201,6 +1203,10 @@ class SyncHandler: A dict from ("m.room.member", state_key) -> state_event_id, containing the memberships missing from `found_state_ids`. + When `events_with_membership_auth` contains a join event for a given user + which does not cite a prior membership, no membership is returned for that + user. + Raises: KeyError: if `events_with_membership_auth` does not have an entry for a missing membership. Memberships in `found_state_ids` do not need an @@ -1218,8 +1224,18 @@ class SyncHandler: if (EventTypes.Member, member) in found_state_ids: continue - missing_members.add(member) event_with_membership_auth = events_with_membership_auth[member] + is_join = ( + event_with_membership_auth.is_state() + and event_with_membership_auth.type == EventTypes.Member + and event_with_membership_auth.state_key == member + and event_with_membership_auth.content.get("membership") + == Membership.JOIN + ) + if not is_join: + # The event must include the desired membership as an auth event, unless + # it's the first join event for a given user. + missing_members.add(member) auth_event_ids.update(event_with_membership_auth.auth_event_ids()) auth_events = await self.store.get_events(auth_event_ids) @@ -1243,7 +1259,7 @@ class SyncHandler: auth_event.type == EventTypes.Member and auth_event.state_key == member ): - missing_members.remove(member) + missing_members.discard(member) additional_state_ids[ (EventTypes.Member, member) ] = auth_event.event_id -- cgit 1.5.1 From e8318a433356413648bd180dcfc69c29ca319fc6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 27 Sep 2022 13:01:08 +0100 Subject: Handle the case of remote users leaving a partial join room for device lists (#13885) --- changelog.d/13885.misc | 1 + synapse/app/admin_cmd.py | 2 +- synapse/storage/controllers/persist_events.py | 71 --------------------------- synapse/storage/databases/main/__init__.py | 2 +- synapse/storage/databases/main/devices.py | 64 ++++++++++++++++++------ synapse/storage/databases/main/events.py | 6 +++ synapse/storage/databases/main/roommember.py | 46 +++++++++-------- 7 files changed, 85 insertions(+), 107 deletions(-) create mode 100644 changelog.d/13885.misc (limited to 'synapse') diff --git a/changelog.d/13885.misc b/changelog.d/13885.misc new file mode 100644 index 0000000000..bc76b862df --- /dev/null +++ b/changelog.d/13885.misc @@ -0,0 +1 @@ +Correctly handle a race with device lists when a remote user leaves during a partial join. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 8a583d3ec6..3c8c00ea5b 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -53,9 +53,9 @@ logger = logging.getLogger("synapse.app.admin_cmd") class AdminCmdSlavedStore( SlavedFilteringStore, - SlavedDeviceStore, SlavedPushRuleStore, SlavedEventStore, + SlavedDeviceStore, TagsWorkerStore, DeviceInboxWorkerStore, AccountDataWorkerStore, diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 501dbbc990..709cb792ed 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -598,11 +598,6 @@ class EventsPersistenceStorageController: # room state_delta_for_room: Dict[str, DeltaState] = {} - # Set of remote users which were in rooms the server has left or who may - # have left rooms the server is in. We should check if we still share any - # rooms and if not we mark their device lists as stale. - potentially_left_users: Set[str] = set() - if not backfilled: with Measure(self._clock, "_calculate_state_and_extrem"): # Work out the new "current state" for each room. @@ -716,8 +711,6 @@ class EventsPersistenceStorageController: room_id, ev_ctx_rm, delta, - current_state, - potentially_left_users, ) if not is_still_joined: logger.info("Server no longer in room %s", room_id) @@ -725,20 +718,6 @@ class EventsPersistenceStorageController: current_state = {} delta.no_longer_in_room = True - # Add all remote users that might have left rooms. - potentially_left_users.update( - user_id - for event_type, user_id in delta.to_delete - if event_type == EventTypes.Member - and not self.is_mine_id(user_id) - ) - potentially_left_users.update( - user_id - for event_type, user_id in delta.to_insert.keys() - if event_type == EventTypes.Member - and not self.is_mine_id(user_id) - ) - state_delta_for_room[room_id] = delta await self.persist_events_store._persist_events_and_state_updates( @@ -749,8 +728,6 @@ class EventsPersistenceStorageController: inhibit_local_membership_updates=backfilled, ) - await self._handle_potentially_left_users(potentially_left_users) - return replaced_events async def _calculate_new_extremities( @@ -1126,8 +1103,6 @@ class EventsPersistenceStorageController: room_id: str, ev_ctx_rm: List[Tuple[EventBase, EventContext]], delta: DeltaState, - current_state: Optional[StateMap[str]], - potentially_left_users: Set[str], ) -> bool: """Check if the server will still be joined after the given events have been persised. @@ -1137,11 +1112,6 @@ class EventsPersistenceStorageController: ev_ctx_rm delta: The delta of current state between what is in the database and what the new current state will be. - current_state: The new current state if it already been calculated, - otherwise None. - potentially_left_users: If the server has left the room, then joined - remote users will be added to this set to indicate that the - server may no longer be sharing a room with them. """ if not any( @@ -1195,45 +1165,4 @@ class EventsPersistenceStorageController: ): return True - # The server will leave the room, so we go and find out which remote - # users will still be joined when we leave. - if current_state is None: - current_state = await self.main_store.get_partial_current_state_ids(room_id) - current_state = dict(current_state) - for key in delta.to_delete: - current_state.pop(key, None) - - current_state.update(delta.to_insert) - - remote_event_ids = [ - event_id - for ( - typ, - state_key, - ), event_id in current_state.items() - if typ == EventTypes.Member and not self.is_mine_id(state_key) - ] - members = await self.main_store.get_membership_from_event_ids(remote_event_ids) - potentially_left_users.update( - member.user_id - for member in members.values() - if member and member.membership == Membership.JOIN - ) - return False - - async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None: - """Given a set of remote users check if the server still shares a room with - them. If not then mark those users' device cache as stale. - """ - - if not user_ids: - return - - joined_users = await self.main_store.get_users_server_still_shares_room_with( - user_ids - ) - left_users = user_ids - joined_users - - for user_id in left_users: - await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 4dccbb732a..0843f10340 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -83,6 +83,7 @@ logger = logging.getLogger(__name__) class DataStore( EventsBackgroundUpdatesStore, + DeviceStore, RoomMemberStore, RoomStore, RoomBatchStore, @@ -114,7 +115,6 @@ class DataStore( StreamWorkerStore, OpenIdStore, ClientIpWorkerStore, - DeviceStore, DeviceInboxStore, UserDirectoryStore, UserErasureStore, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 5d700ca6c3..1151fb0cc3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -47,6 +47,7 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder @@ -70,7 +71,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" -class DeviceWorkerStore(EndToEndKeyWorkerStore): +class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def __init__( self, database: DatabasePool, @@ -985,24 +986,59 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): desc="mark_remote_user_device_cache_as_valid", ) + async def handle_potentially_left_users(self, user_ids: Set[str]) -> None: + """Given a set of remote users check if the server still shares a room with + them. If not then mark those users' device cache as stale. + """ + + if not user_ids: + return + + await self.db_pool.runInteraction( + "_handle_potentially_left_users", + self.handle_potentially_left_users_txn, + user_ids, + ) + + def handle_potentially_left_users_txn( + self, + txn: LoggingTransaction, + user_ids: Set[str], + ) -> None: + """Given a set of remote users check if the server still shares a room with + them. If not then mark those users' device cache as stale. + """ + + if not user_ids: + return + + joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids) + left_users = user_ids - joined_users + + for user_id in left_users: + self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id) + async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user.""" - def _mark_remote_user_device_list_as_unsubscribed_txn( - txn: LoggingTransaction, - ) -> None: - self.db_pool.simple_delete_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - ) - self._invalidate_cache_and_stream( - txn, self.get_device_list_last_stream_id_for_remote, (user_id,) - ) - await self.db_pool.runInteraction( "mark_remote_user_device_list_as_unsubscribed", - _mark_remote_user_device_list_as_unsubscribed_txn, + self.mark_remote_user_device_list_as_unsubscribed_txn, + user_id, + ) + + def mark_remote_user_device_list_as_unsubscribed_txn( + self, + txn: LoggingTransaction, + user_id: str, + ) -> None: + self.db_pool.simple_delete_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_device_list_last_stream_id_for_remote, (user_id,) ) async def get_dehydrated_device( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 2e156a4a11..b59eb7478b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1202,6 +1202,12 @@ class PersistEventsStore: txn, room_id, members_changed ) + # Check if any of the remote membership changes requires us to + # unsubscribe from their device lists. + self.store.handle_potentially_left_users_txn( + txn, {m for m in members_changed if not self.hs.is_mine_id(m)} + ) + def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None: """Update the room version in the database based off current state events. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index a8d224602a..8ada3cdac3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -662,31 +662,37 @@ class RoomMemberWorkerStore(EventsWorkerStore): if not user_ids: return set() - def _get_users_server_still_shares_room_with_txn( - txn: LoggingTransaction, - ) -> Set[str]: - sql = """ - SELECT state_key FROM current_state_events - WHERE - type = 'm.room.member' - AND membership = 'join' - AND %s - GROUP BY state_key - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "state_key", user_ids - ) + return await self.db_pool.runInteraction( + "get_users_server_still_shares_room_with", + self.get_users_server_still_shares_room_with_txn, + user_ids, + ) - txn.execute(sql % (clause,), args) + def get_users_server_still_shares_room_with_txn( + self, + txn: LoggingTransaction, + user_ids: Collection[str], + ) -> Set[str]: + if not user_ids: + return set() - return {row[0] for row in txn} + sql = """ + SELECT state_key FROM current_state_events + WHERE + type = 'm.room.member' + AND membership = 'join' + AND %s + GROUP BY state_key + """ - return await self.db_pool.runInteraction( - "get_users_server_still_shares_room_with", - _get_users_server_still_shares_room_with_txn, + clause, args = make_in_list_sql_clause( + self.database_engine, "state_key", user_ids ) + txn.execute(sql % (clause,), args) + + return {row[0] for row in txn} + @cancellable async def get_rooms_for_user( self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None -- cgit 1.5.1 From 50c92f3a692a745d2b42f9731af4da493fa27715 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 27 Sep 2022 15:38:14 +0200 Subject: Carry IdP Session IDs through user-mapping sessions. (#13839) Since #11482, we're saving sessions IDs from upstream IdPs, but we've been losing them when the user goes through a user mapping session on account registration. --- changelog.d/13839.misc | 1 + synapse/handlers/sso.py | 9 +++++++++ 2 files changed, 10 insertions(+) create mode 100644 changelog.d/13839.misc (limited to 'synapse') diff --git a/changelog.d/13839.misc b/changelog.d/13839.misc new file mode 100644 index 0000000000..549872c90f --- /dev/null +++ b/changelog.d/13839.misc @@ -0,0 +1 @@ +Carry IdP Session IDs through user-mapping sessions. diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 6bc1cbd787..e035677b8a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -147,6 +147,9 @@ class UsernameMappingSession: # A unique identifier for this SSO provider, e.g. "oidc" or "saml". auth_provider_id: str + # An optional session ID from the IdP. + auth_provider_session_id: Optional[str] + # user ID on the IdP server remote_user_id: str @@ -464,6 +467,7 @@ class SsoHandler: client_redirect_url, next_step_url, extra_login_attributes, + auth_provider_session_id, ) user_id = await self._register_mapped_user( @@ -585,6 +589,7 @@ class SsoHandler: client_redirect_url: str, next_step_url: bytes, extra_login_attributes: Optional[JsonDict], + auth_provider_session_id: Optional[str], ) -> NoReturn: """Creates a UsernameMappingSession and redirects the browser @@ -607,6 +612,8 @@ class SsoHandler: extra_login_attributes: An optional dictionary of extra attributes to be provided to the client in the login response. + auth_provider_session_id: An optional session ID from the IdP. + Raises: RedirectException """ @@ -615,6 +622,7 @@ class SsoHandler: now = self._clock.time_msec() session = UsernameMappingSession( auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, remote_user_id=remote_user_id, display_name=attributes.display_name, emails=attributes.emails, @@ -968,6 +976,7 @@ class SsoHandler: session.client_redirect_url, session.extra_login_attributes, new_user=True, + auth_provider_session_id=session.auth_provider_session_id, ) def _expire_old_sessions(self) -> None: -- cgit 1.5.1 From 299b00d968ee23ba4e4806dd7c4fa97c7fcfb6f5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 27 Sep 2022 15:17:41 +0100 Subject: Prioritize outbound to-device over device list updates (#13922) Otherwise device list changes for large accounts can temporarily delay to-device messages. --- changelog.d/13922.bugfix | 1 + synapse/federation/sender/per_destination_queue.py | 29 ++++++++++++---------- 2 files changed, 17 insertions(+), 13 deletions(-) create mode 100644 changelog.d/13922.bugfix (limited to 'synapse') diff --git a/changelog.d/13922.bugfix b/changelog.d/13922.bugfix new file mode 100644 index 0000000000..7269d28dee --- /dev/null +++ b/changelog.d/13922.bugfix @@ -0,0 +1 @@ +Fix long-standing bug where device updates could cause delays sending out to-device messages over federation. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 41d8b937af..084c45a95c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -646,29 +646,32 @@ class _TransactionQueueManager: # We start by fetching device related EDUs, i.e device updates and to # device messages. We have to keep 2 free slots for presence and rr_edus. - limit = MAX_EDUS_PER_TRANSACTION - 2 - - device_update_edus, dev_list_id = await self.queue._get_device_update_edus( - limit - ) - - if device_update_edus: - self._device_list_id = dev_list_id - else: - self.queue._last_device_list_stream_id = dev_list_id - - limit -= len(device_update_edus) + device_edu_limit = MAX_EDUS_PER_TRANSACTION - 2 + # We prioritize to-device messages so that existing encryption channels + # work. We also keep a few slots spare (by reducing the limit) so that + # we can still trickle out some device list updates. ( to_device_edus, device_stream_id, - ) = await self.queue._get_to_device_message_edus(limit) + ) = await self.queue._get_to_device_message_edus(device_edu_limit - 10) if to_device_edus: self._device_stream_id = device_stream_id else: self.queue._last_device_stream_id = device_stream_id + device_edu_limit -= len(to_device_edus) + + device_update_edus, dev_list_id = await self.queue._get_device_update_edus( + device_edu_limit + ) + + if device_update_edus: + self._device_list_id = dev_list_id + else: + self.queue._last_device_list_stream_id = dev_list_id + pending_edus = device_update_edus + to_device_edus # Now add the read receipt EDU. -- cgit 1.5.1 From 87fe9db4675e510ea9c0234429b4773341c4e86d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 27 Sep 2022 10:47:34 -0400 Subject: Support the stable dir parameter for /relations. (#13920) Since MSC3715 has passed FCP, the stable parameter can be used. This currently falls back to the unstable parameter if the stable parameter is not provided (and MSC3715 support is enabled in the configuration). --- changelog.d/13920.feature | 1 + synapse/rest/client/relations.py | 24 +++++++++++++++--------- tests/rest/client/test_relations.py | 6 ++---- 3 files changed, 18 insertions(+), 13 deletions(-) create mode 100644 changelog.d/13920.feature (limited to 'synapse') diff --git a/changelog.d/13920.feature b/changelog.d/13920.feature new file mode 100644 index 0000000000..aee702bcd2 --- /dev/null +++ b/changelog.d/13920.feature @@ -0,0 +1 @@ +Support a `dir` parameter on the `/relations` endpoint per [MSC3715](https://github.com/matrix-org/matrix-doc/pull/3715). diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index ce97080013..205c556f64 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -56,15 +56,21 @@ class RelationPaginationServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=5) - if self._msc3715_enabled: - direction = parse_string( - request, - "org.matrix.msc3715.dir", - default="b", - allowed_values=["f", "b"], - ) - else: - direction = "b" + # Fetch the direction parameter, if provided. + # + # TODO Use PaginationConfig.from_request when the unstable parameter is + # no longer needed. + direction = parse_string(request, "dir", allowed_values=["f", "b"]) + if direction is None: + if self._msc3715_enabled: + direction = parse_string( + request, + "org.matrix.msc3715.dir", + default="b", + allowed_values=["f", "b"], + ) + else: + direction = "b" from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index d33e34d829..fef3b72d76 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -728,7 +728,6 @@ class RelationsTestCase(BaseRelationsTestCase): class RelationPaginationTestCase(BaseRelationsTestCase): - @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -771,7 +770,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/_matrix/client/v1/rooms/{self.room}/relations" - f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", + f"/{self.parent_id}?limit=1&dir=f", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -788,7 +787,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) - @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -838,7 +836,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?dir=f&limit=3{from_token}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) -- cgit 1.5.1 From f5aaa55e2702af3cac1e195bf5d703970c24ff29 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 27 Sep 2022 17:26:35 +0100 Subject: Add new columns tracking when we partial-joined (#13892) --- changelog.d/13892.feature | 1 + synapse/handlers/federation.py | 14 +++++- synapse/storage/databases/main/room.py | 52 +++++++++++++++++++++- .../main/delta/73/04partial_join_details.sql | 23 ++++++++++ 4 files changed, 87 insertions(+), 3 deletions(-) create mode 100644 changelog.d/13892.feature create mode 100644 synapse/storage/schema/main/delta/73/04partial_join_details.sql (limited to 'synapse') diff --git a/changelog.d/13892.feature b/changelog.d/13892.feature new file mode 100644 index 0000000000..df3f576536 --- /dev/null +++ b/changelog.d/13892.feature @@ -0,0 +1 @@ +Faster remote room joins: record _when_ we first partial-join to a room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e1a4265a64..74580f60df 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -581,7 +581,11 @@ class FederationHandler: # Mark the room as having partial state. # The background process is responsible for unmarking this flag, # even if the join fails. - await self.store.store_partial_state_room(room_id, ret.servers_in_room) + await self.store.store_partial_state_room( + room_id=room_id, + servers=ret.servers_in_room, + device_lists_stream_id=self.store.get_device_stream_token(), + ) try: max_stream_id = ( @@ -606,6 +610,14 @@ class FederationHandler: room_id, ) raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0) + else: + # Record the join event id for future use (when we finish the full + # join). We have to do this after persisting the event to keep foreign + # key constraints intact. + if ret.partial_state: + await self.store.write_partial_state_rooms_join_event_id( + room_id, event.event_id + ) finally: # Always kick off the background process that asynchronously fetches # state for the room. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 5dd116d766..064c332fb7 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1777,28 +1777,46 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self, room_id: str, servers: Collection[str], + device_lists_stream_id: int, ) -> None: - """Mark the given room as containing events with partial state + """Mark the given room as containing events with partial state. + + We also store additional data that describes _when_ we first partial-joined this + room, which helps us to keep other homeservers in sync when we finally fully + join this room. + + We do not include a `join_event_id` here---we need to wait for the join event + to be persisted first. Args: room_id: the ID of the room servers: other servers known to be in the room + device_lists_stream_id: the device_lists stream ID at the time when we first + joined the room. """ await self.db_pool.runInteraction( "store_partial_state_room", self._store_partial_state_room_txn, room_id, servers, + device_lists_stream_id, ) def _store_partial_state_room_txn( - self, txn: LoggingTransaction, room_id: str, servers: Collection[str] + self, + txn: LoggingTransaction, + room_id: str, + servers: Collection[str], + device_lists_stream_id: int, ) -> None: DatabasePool.simple_insert_txn( txn, table="partial_state_rooms", values={ "room_id": room_id, + "device_lists_stream_id": device_lists_stream_id, + # To be updated later once the join event is persisted. + "join_event_id": None, }, ) DatabasePool.simple_insert_many_txn( @@ -1809,6 +1827,36 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) + async def write_partial_state_rooms_join_event_id( + self, + room_id: str, + join_event_id: str, + ) -> None: + """Record the join event which resulted from a partial join. + + We do this separately to `store_partial_state_room` because we need to wait for + the join event to be persisted. Otherwise we violate a foreign key constraint. + """ + await self.db_pool.runInteraction( + "write_partial_state_rooms_join_event_id", + self._write_partial_state_rooms_join_event_id, + room_id, + join_event_id, + ) + + def _write_partial_state_rooms_join_event_id( + self, + txn: LoggingTransaction, + room_id: str, + join_event_id: str, + ) -> None: + DatabasePool.simple_update_txn( + txn, + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + updatevalues={"join_event_id": join_event_id}, + ) + async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ) -> None: diff --git a/synapse/storage/schema/main/delta/73/04partial_join_details.sql b/synapse/storage/schema/main/delta/73/04partial_join_details.sql new file mode 100644 index 0000000000..5fb2bfe1a2 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/04partial_join_details.sql @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- To ensure we correctly notify other homeservers about device list changes from our +-- users after a partial join transitions to a full join, we need to know when we began +-- the partial join. For now it's sufficient to know the device_list stream_id at the +-- time of the partial join, and the join event created for us during a partial join. +-- +-- Both columns are backwards compatible. +ALTER TABLE partial_state_rooms ADD COLUMN device_lists_stream_id BIGINT NOT NULL DEFAULT 0; +ALTER TABLE partial_state_rooms ADD COLUMN join_event_id TEXT REFERENCES events(event_id); -- cgit 1.5.1 From 29269d9d3f3419a3d92cdd80dae4a37e2d99a395 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 27 Sep 2022 15:55:43 -0500 Subject: Fix `have_seen_event` cache not being invalidated (#13863) Fix https://github.com/matrix-org/synapse/issues/13856 Fix https://github.com/matrix-org/synapse/issues/13865 > Discovered while trying to make Synapse fast enough for [this MSC2716 test for importing many batches](https://github.com/matrix-org/complement/pull/214#discussion_r741678240). As an example, disabling the `have_seen_event` cache saves 10 seconds for each `/messages` request in that MSC2716 Complement test because we're not making as many federation requests for `/state` (speeding up `have_seen_event` itself is related to https://github.com/matrix-org/synapse/issues/13625) > > But this will also make `/messages` faster in general so we can include it in the [faster `/messages` milestone](https://github.com/matrix-org/synapse/milestone/11). > > *-- https://github.com/matrix-org/synapse/issues/13856* ### The problem `_invalidate_caches_for_event` doesn't run in monolith mode which means we never even tried to clear the `have_seen_event` and other caches. And even in worker mode, it only runs on the workers, not the master (AFAICT). Additionally there was bug with the key being wrong so `_invalidate_caches_for_event` never invalidates the `have_seen_event` cache even when it does run. Because we were using the `@cachedList` wrong, it was putting items in the cache under keys like `((room_id, event_id),)` with a `set` in a `set` (ex. `(('!TnCIJPKzdQdUlIyXdQ:test', '$Iu0eqEBN7qcyF1S9B3oNB3I91v2o5YOgRNPwi_78s-k'),)`) and we we're trying to invalidate with just `(room_id, event_id)` which did nothing. --- changelog.d/13863.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 40 +++--- synapse/util/caches/descriptors.py | 6 + tests/storage/databases/main/test_events_worker.py | 152 ++++++++++++++------- tests/util/caches/test_descriptors.py | 33 ++++- 5 files changed, 165 insertions(+), 67 deletions(-) create mode 100644 changelog.d/13863.bugfix (limited to 'synapse') diff --git a/changelog.d/13863.bugfix b/changelog.d/13863.bugfix new file mode 100644 index 0000000000..74264a4fab --- /dev/null +++ b/changelog.d/13863.bugfix @@ -0,0 +1 @@ +Fix `have_seen_event` cache not being invalidated after we persist an event which causes inefficiency effects like extra `/state` federation calls. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 52914febf9..7cdc9fe98f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore): # the batches as big as possible. results: Set[str] = set() - for chunk in batch_iter(event_ids, 500): - r = await self._have_seen_events_dict( - [(room_id, event_id) for event_id in chunk] + for event_ids_chunk in batch_iter(event_ids, 500): + events_seen_dict = await self._have_seen_events_dict( + room_id, event_ids_chunk + ) + results.update( + eid for (eid, have_event) in events_seen_dict.items() if have_event ) - results.update(eid for ((_rid, eid), have_event) in r.items() if have_event) return results - @cachedList(cached_method_name="have_seen_event", list_name="keys") + @cachedList(cached_method_name="have_seen_event", list_name="event_ids") async def _have_seen_events_dict( - self, keys: Collection[Tuple[str, str]] - ) -> Dict[Tuple[str, str], bool]: + self, + room_id: str, + event_ids: Collection[str], + ) -> Dict[str, bool]: """Helper for have_seen_events Returns: - a dict {(room_id, event_id)-> bool} + a dict {event_id -> bool} """ # if the event cache contains the event, obviously we've seen it. cache_results = { - (rid, eid) - for (rid, eid) in keys - if await self._get_event_cache.contains((eid,)) + event_id + for event_id in event_ids + if await self._get_event_cache.contains((event_id,)) } results = dict.fromkeys(cache_results, True) - remaining = [k for k in keys if k not in cache_results] + remaining = [ + event_id for event_id in event_ids if event_id not in cache_results + ] if not remaining: return results @@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining] + txn.database_engine, "e.event_id", remaining ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update( - {(rid, eid): (eid in found_events) for (rid, eid) in remaining} - ) + results.update({eid: (eid in found_events) for eid in remaining}) await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) return results @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: - res = await self._have_seen_events_dict(((room_id, event_id),)) - return res[(room_id, event_id)] + res = await self._have_seen_events_dict(room_id, [event_id]) + return res[event_id] def _get_current_state_event_counts_txn( self, txn: LoggingTransaction, room_id: str diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 3909f1caea..0391966462 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -431,6 +431,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args + if num_args != self.num_args: + raise Exception( + "Number of args (%s) does not match underlying cache_method_name=%s (%s)." + % (self.num_args, self.cached_method_name, num_args) + ) + @functools.wraps(self.orig) def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 67401272ac..32a798d74b 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -35,66 +35,45 @@ from synapse.util import Clock from synapse.util.async_helpers import yieldable_gather_results from tests import unittest +from tests.test_utils.event_injection import create_event, inject_event class HaveSeenEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor, clock, hs): + self.hs = hs self.store: EventsWorkerStore = hs.get_datastores().main - # insert some test data - for rid in ("room1", "room2"): - self.get_success( - self.store.db_pool.simple_insert( - "rooms", - {"room_id": rid, "room_version": 4}, - ) - ) + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + self.room_id = self.helper.create_room_as(self.user, tok=self.token) self.event_ids: List[str] = [] - for idx, rid in enumerate( - ( - "room1", - "room1", - "room1", - "room2", - ) - ): - event_json = {"type": f"test {idx}", "room_id": rid} - event = make_event_from_dict(event_json, room_version=RoomVersions.V4) - event_id = event.event_id - - self.get_success( - self.store.db_pool.simple_insert( - "events", - { - "event_id": event_id, - "room_id": rid, - "topological_ordering": idx, - "stream_ordering": idx, - "type": event.type, - "processed": True, - "outlier": False, - }, + for i in range(3): + event = self.get_success( + inject_event( + hs, + room_version=RoomVersions.V7.identifier, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": f"foobarbaz{i}"}, ) ) - self.get_success( - self.store.db_pool.simple_insert( - "event_json", - { - "event_id": event_id, - "room_id": rid, - "json": json.dumps(event_json), - "internal_metadata": "{}", - "format_version": 3, - }, - ) - ) - self.event_ids.append(event_id) + + self.event_ids.append(event.event_id) def test_simple(self): with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) + self.store.have_seen_events( + self.room_id, [self.event_ids[0], "eventdoesnotexist"] + ) ) self.assertEqual(res, {self.event_ids[0]}) @@ -104,7 +83,9 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) + self.store.have_seen_events( + self.room_id, [self.event_ids[0], "eventdoesnotexist"] + ) ) self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) @@ -116,11 +97,86 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0]]) + self.store.have_seen_events(self.room_id, [self.event_ids[0]]) ) self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) + def test_persisting_event_invalidates_cache(self): + """ + Test to make sure that the `have_seen_event` cache + is invalidated after we persist an event and returns + the updated value. + """ + event, event_context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": "garply"}, + ) + ) + + with LoggingContext(name="test") as ctx: + # First, check `have_seen_event` for an event we have not seen yet + # to prime the cache with a `false` value. + res = self.get_success( + self.store.have_seen_events(event.room_id, [event.event_id]) + ) + self.assertEqual(res, set()) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Persist the event which should invalidate or prefill the + # `have_seen_event` cache so we don't return stale values. + persistence = self.hs.get_storage_controllers().persistence + self.get_success( + persistence.persist_event( + event, + event_context, + ) + ) + + with LoggingContext(name="test") as ctx: + # Check `have_seen_event` again and we should see the updated fact + # that we have now seen the event after persisting it. + res = self.get_success( + self.store.have_seen_events(event.room_id, [event.event_id]) + ) + self.assertEqual(res, {event.event_id}) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + def test_invalidate_cache_by_room_id(self): + """ + Test to make sure that all events associated with the given `(room_id,)` + are invalidated in the `have_seen_event` cache. + """ + with LoggingContext(name="test") as ctx: + # Prime the cache with some values + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Clear the cache with any events associated with the `room_id` + self.store.have_seen_event.invalidate((self.room_id,)) + + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # Since we cleared the cache, it should result in another db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + class EventCacheTestCase(unittest.HomeserverTestCase): """Test that the various layers of event cache works.""" diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 48e616ac74..90861fe522 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Set +from typing import Iterable, Set, Tuple from unittest import mock from twisted.internet import defer, reactor @@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.inner_context_was_finished, "Tried to restart a finished logcontext" ) self.assertEqual(current_context(), SENTINEL_CONTEXT) + + def test_num_args_mismatch(self): + """ + Make sure someone does not accidentally use @cachedList on a method with + a mismatch in the number args to the underlying single cache method. + """ + + class Cls: + @descriptors.cached(tree=True) + def fn(self, room_id, event_id): + pass + + # This is wrong ❌. `@cachedList` expects to be given the same number + # of arguments as the underlying cached function, just with one of + # the arguments being an iterable + @descriptors.cachedList(cached_method_name="fn", list_name="keys") + def list_fn(self, keys: Iterable[Tuple[str, str]]): + pass + + # Corrected syntax ✅ + # + # @cachedList(cached_method_name="fn", list_name="event_ids") + # async def list_fn( + # self, room_id: str, event_ids: Collection[str], + # ) + + obj = Cls() + + # Make sure this raises an error about the arg mismatch + with self.assertRaises(Exception): + obj.list_fn([("foo", "bar")]) -- cgit 1.5.1 From a2cf66a94d5dfd9d6496ac3e48ec9a22f17be69a Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 28 Sep 2022 02:39:03 -0700 Subject: Prepatory work for batching events to send (#13487) This PR begins work on batching up events during the creation of a room. The PR splits out the creation and sending/persisting of the events. The first three events in the creation of the room-creating the room, joining the creator to the room, and the power levels event are sent sequentially, while the subsequent events are created and collected to be sent at the end of the function. This is currently done by appending them to a list and then iterating over the list to send, the next step (after this PR) would be to send and persist the collected events as a batch. --- changelog.d/13487.misc | 1 + synapse/handlers/message.py | 175 ++++++++++++++++++++++++++-------------- synapse/handlers/room.py | 155 ++++++++++++++++++++++++----------- synapse/state/__init__.py | 63 +++++++++++++++ tests/rest/client/test_rooms.py | 4 +- 5 files changed, 290 insertions(+), 108 deletions(-) create mode 100644 changelog.d/13487.misc (limited to 'synapse') diff --git a/changelog.d/13487.misc b/changelog.d/13487.misc new file mode 100644 index 0000000000..761adc8b05 --- /dev/null +++ b/changelog.d/13487.misc @@ -0,0 +1 @@ +Speed up creation of DM rooms. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e07cda133a..062f93bc67 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -63,6 +63,7 @@ from synapse.types import ( MutableStateMap, Requester, RoomAlias, + StateMap, StreamToken, UserID, create_requester, @@ -567,9 +568,17 @@ class EventCreationHandler: outlier: bool = False, historical: bool = False, depth: Optional[int] = None, + state_map: Optional[StateMap[str]] = None, + for_batch: bool = False, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ - Given a dict from a client, create a new event. + Given a dict from a client, create a new event. If bool for_batch is true, will + create an event using the prev_event_ids, and will create an event context for + the event using the parameters state_map and current_state_group, thus these parameters + must be provided in this case if for_batch is True. The subsequently created event + and context are suitable for being batched up and bulk persisted to the database + with other similarly created events. Creates an FrozenEvent object, filling out auth_events, prev_events, etc. @@ -612,16 +621,27 @@ class EventCreationHandler: outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted back in time around some existing events. This is used to skip a few checks and mark the event as backfilled. + depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + state_map: A state map of previously created events, used only when creating events + for batch persisting + + for_batch: whether the event is being created for batch persisting to the db + + current_state_group: the current state group, used only for creating events for + batch persisting + Raises: ResourceLimitError if server is blocked to some resource being exceeded + Returns: Tuple of created event, Context """ @@ -693,6 +713,9 @@ class EventCreationHandler: auth_event_ids=auth_event_ids, state_event_ids=state_event_ids, depth=depth, + state_map=state_map, + for_batch=for_batch, + current_state_group=current_state_group, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -707,10 +730,14 @@ class EventCreationHandler: # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) - ) - prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) + if for_batch: + assert state_map is not None + prev_event_id = state_map.get((EventTypes.Member, event.sender)) + else: + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types([(EventTypes.Member, None)]) + ) + prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( await self.store.get_event(prev_event_id, allow_none=True) if prev_event_id @@ -1009,8 +1036,16 @@ class EventCreationHandler: auth_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + state_map: Optional[StateMap[str]] = None, + for_batch: bool = False, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: - """Create a new event for a local client + """Create a new event for a local client. If bool for_batch is true, will + create an event using the prev_event_ids, and will create an event context for + the event using the parameters state_map and current_state_group, thus these parameters + must be provided in this case if for_batch is True. The subsequently created event + and context are suitable for being batched up and bulk persisted to the database + with other similarly created events. Args: builder: @@ -1043,6 +1078,14 @@ class EventCreationHandler: Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + state_map: A state map of previously created events, used only when creating events + for batch persisting + + for_batch: whether the event is being created for batch persisting to the db + + current_state_group: the current state group, used only for creating events for + batch persisting + Returns: Tuple of created event, context """ @@ -1095,64 +1138,76 @@ class EventCreationHandler: builder.type == EventTypes.Create or prev_event_ids ), "Attempting to create a non-m.room.create event with no prev_events" - event = await builder.build( - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - depth=depth, - ) + if for_batch: + assert prev_event_ids is not None + assert state_map is not None + assert current_state_group is not None + auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) + event = await builder.build( + prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth + ) + context = await self.state.compute_event_context_for_batched( + event, state_map, current_state_group + ) + else: + event = await builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + depth=depth, + ) - # Pass on the outlier property from the builder to the event - # after it is created - if builder.internal_metadata.outlier: - event.internal_metadata.outlier = True - context = EventContext.for_outlier(self._storage_controllers) - elif ( - event.type == EventTypes.MSC2716_INSERTION - and state_event_ids - and builder.internal_metadata.is_historical() - ): - # Add explicit state to the insertion event so it has state to derive - # from even though it's floating with no `prev_events`. The rest of - # the batch can derive from this state and state_group. - # - # TODO(faster_joins): figure out how this works, and make sure that the - # old state is complete. - # https://github.com/matrix-org/synapse/issues/13003 - metadata = await self.store.get_metadata_for_events(state_event_ids) - - state_map_for_event: MutableStateMap[str] = {} - for state_id in state_event_ids: - data = metadata.get(state_id) - if data is None: - # We're trying to persist a new historical batch of events - # with the given state, e.g. via - # `RoomBatchSendEventRestServlet`. The state can be inferred - # by Synapse or set directly by the client. - # - # Either way, we should have persisted all the state before - # getting here. - raise Exception( - f"State event {state_id} not found in DB," - " Synapse should have persisted it before using it." - ) + # Pass on the outlier property from the builder to the event + # after it is created + if builder.internal_metadata.outlier: + event.internal_metadata.outlier = True + context = EventContext.for_outlier(self._storage_controllers) + elif ( + event.type == EventTypes.MSC2716_INSERTION + and state_event_ids + and builder.internal_metadata.is_historical() + ): + # Add explicit state to the insertion event so it has state to derive + # from even though it's floating with no `prev_events`. The rest of + # the batch can derive from this state and state_group. + # + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. + # https://github.com/matrix-org/synapse/issues/13003 + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map_for_event: MutableStateMap[str] = {} + for state_id in state_event_ids: + data = metadata.get(state_id) + if data is None: + # We're trying to persist a new historical batch of events + # with the given state, e.g. via + # `RoomBatchSendEventRestServlet`. The state can be inferred + # by Synapse or set directly by the client. + # + # Either way, we should have persisted all the state before + # getting here. + raise Exception( + f"State event {state_id} not found in DB," + " Synapse should have persisted it before using it." + ) - if data.state_key is None: - raise Exception( - f"Trying to set non-state event {state_id} as state" - ) + if data.state_key is None: + raise Exception( + f"Trying to set non-state event {state_id} as state" + ) - state_map_for_event[(data.event_type, data.state_key)] = state_id + state_map_for_event[(data.event_type, data.state_key)] = state_id - context = await self.state.compute_event_context( - event, - state_ids_before_event=state_map_for_event, - # TODO(faster_joins): check how MSC2716 works and whether we can have - # partial state here - # https://github.com/matrix-org/synapse/issues/13003 - partial_state=False, - ) - else: - context = await self.state.compute_event_context(event) + context = await self.state.compute_event_context( + event, + state_ids_before_event=state_map_for_event, + # TODO(faster_joins): check how MSC2716 works and whether we can have + # partial state here + # https://github.com/matrix-org/synapse/issues/13003 + partial_state=False, + ) + else: + context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 33e9a87002..09a1a82e6c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -716,7 +716,7 @@ class RoomCreationHandler: if ( self._server_notices_mxid is not None - and requester.user.to_string() == self._server_notices_mxid + and user_id == self._server_notices_mxid ): # allow the server notices mxid to create rooms is_requester_admin = True @@ -1042,7 +1042,9 @@ class RoomCreationHandler: creator_join_profile: Optional[JsonDict] = None, ratelimit: bool = True, ) -> Tuple[int, str, int]: - """Sends the initial events into a new room. + """Sends the initial events into a new room. Sends the room creation, membership, + and power level events into the room sequentially, then creates and batches up the + rest of the events to persist as a batch to the DB. `power_level_content_override` doesn't apply when initial state has power level state event content. @@ -1053,13 +1055,21 @@ class RoomCreationHandler: """ creator_id = creator.user.to_string() - event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} - depth = 1 + # the last event sent/persisted to the db last_sent_event_id: Optional[str] = None - - def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: + # the most recently created event + prev_event: List[str] = [] + # a map of event types, state keys -> event_ids. We collect these mappings this as events are + # created (but not persisted to the db) to determine state for future created events + # (as this info can't be pulled from the db) + state_map: MutableStateMap[str] = {} + # current_state_group of last event created. Used for computing event context of + # events to be batched + current_state_group = None + + def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: e = {"type": etype, "content": content} e.update(event_keys) @@ -1067,32 +1077,52 @@ class RoomCreationHandler: return e - async def send(etype: str, content: JsonDict, **kwargs: Any) -> int: - nonlocal last_sent_event_id + async def create_event( + etype: str, + content: JsonDict, + for_batch: bool, + **kwargs: Any, + ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: nonlocal depth + nonlocal prev_event - event = create(etype, content, **kwargs) - logger.debug("Sending %s in new room", etype) - # Allow these events to be sent even if the user is shadow-banned to - # allow the room creation to complete. - ( - sent_event, - last_stream_id, - ) = await self.event_creation_handler.create_and_send_nonmember_event( + event_dict = create_event_dict(etype, content, **kwargs) + + new_event, new_context = await self.event_creation_handler.create_event( creator, - event, + event_dict, + prev_event_ids=prev_event, + depth=depth, + state_map=state_map, + for_batch=for_batch, + current_state_group=current_state_group, + ) + depth += 1 + prev_event = [new_event.event_id] + state_map[(new_event.type, new_event.state_key)] = new_event.event_id + + return new_event, new_context + + async def send( + event: EventBase, + context: synapse.events.snapshot.EventContext, + creator: Requester, + ) -> int: + nonlocal last_sent_event_id + + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + event=event, + context=context, ratelimit=False, ignore_shadow_ban=True, - # Note: we don't pass state_event_ids here because this triggers - # an additional query per event to look them up from the events table. - prev_event_ids=[last_sent_event_id] if last_sent_event_id else [], - depth=depth, ) - last_sent_event_id = sent_event.event_id - depth += 1 + last_sent_event_id = ev.event_id - return last_stream_id + # we know it was persisted, so must have a stream ordering + assert ev.internal_metadata.stream_ordering + return ev.internal_metadata.stream_ordering try: config = self._presets_dict[preset_config] @@ -1102,9 +1132,13 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - await send(etype=EventTypes.Create, content=creation_content) + creation_event, creation_context = await create_event( + EventTypes.Create, creation_content, False + ) logger.debug("Sending %s in new room", EventTypes.Member) + await send(creation_event, creation_context, creator) + # Room create event must exist at this point assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( @@ -1119,14 +1153,22 @@ class RoomCreationHandler: depth=depth, ) last_sent_event_id = member_event_id + prev_event = [member_event_id] + + # update the depth and state map here as the membership event has been created + # through a different code path + depth += 1 + state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - last_sent_stream_id = await send( - etype=EventTypes.PowerLevels, content=pl_content + power_event, power_context = await create_event( + EventTypes.PowerLevels, pl_content, False ) + current_state_group = power_context._state_group + last_sent_stream_id = await send(power_event, power_context, creator) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1169,47 +1211,68 @@ class RoomCreationHandler: # apply those. if power_level_content_override: power_level_content.update(power_level_content_override) - - last_sent_stream_id = await send( - etype=EventTypes.PowerLevels, content=power_level_content + pl_event, pl_context = await create_event( + EventTypes.PowerLevels, + power_level_content, + False, ) + current_state_group = pl_context._state_group + last_sent_stream_id = await send(pl_event, pl_context, creator) + events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.CanonicalAlias, - content={"alias": room_alias.to_string()}, + room_alias_event, room_alias_context = await create_event( + EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) + current_state_group = room_alias_context._state_group + events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} + join_rules_event, join_rules_context = await create_event( + EventTypes.JoinRules, + {"join_rule": config["join_rules"]}, + True, ) + current_state_group = join_rules_context._state_group + events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.RoomHistoryVisibility, - content={"history_visibility": config["history_visibility"]}, + visibility_event, visibility_context = await create_event( + EventTypes.RoomHistoryVisibility, + {"history_visibility": config["history_visibility"]}, + True, ) + current_state_group = visibility_context._state_group + events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.GuestAccess, - content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, + guest_access_event, guest_access_context = await create_event( + EventTypes.GuestAccess, + {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, + True, ) + current_state_group = guest_access_context._state_group + events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): - last_sent_stream_id = await send( - etype=etype, state_key=state_key, content=content + event, context = await create_event( + etype, content, True, state_key=state_key ) + current_state_group = context._state_group + events_to_send.append((event, context)) if config["encrypted"]: - last_sent_stream_id = await send( - etype=EventTypes.RoomEncryption, + encryption_event, encryption_context = await create_event( + EventTypes.RoomEncryption, + {"algorithm": RoomEncryptionAlgorithms.DEFAULT}, + True, state_key="", - content={"algorithm": RoomEncryptionAlgorithms.DEFAULT}, ) + events_to_send.append((encryption_event, encryption_context)) + for event, context in events_to_send: + last_sent_stream_id = await send(event, context, creator) return last_sent_stream_id, last_sent_event_id, depth def _generate_room_id(self) -> str: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3787d35b24..6f3dd0463e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -420,6 +420,69 @@ class StateHandler: partial_state=partial_state, ) + async def compute_event_context_for_batched( + self, + event: EventBase, + state_ids_before_event: StateMap[str], + current_state_group: int, + ) -> EventContext: + """ + Generate an event context for an event that has not yet been persisted to the + database. Intended for use with events that are created to be persisted in a batch. + Args: + event: the event the context is being computed for + state_ids_before_event: a state map consisting of the state ids of the events + created prior to this event. + current_state_group: the current state group before the event. + """ + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None + + state_group_before_event = current_state_group + + # if the event is not state, we are set + if not event.is_state(): + return EventContext.with_state( + storage=self._storage_controllers, + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + state_delta_due_to_event={}, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + partial_state=False, + ) + + # otherwise, we'll need to create a new state group for after the event + key = (event.type, event.state_key) + + if state_ids_before_event is not None: + replaces = state_ids_before_event.get(key) + + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces + + delta_ids = {key: event.event_id} + + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=None, + ) + ) + + return EventContext.with_state( + storage=self._storage_controllers, + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + state_delta_due_to_event=delta_ids, + prev_group=state_group_before_event, + delta_ids=delta_ids, + partial_state=False, + ) + @measure_func() async def resolve_state_groups_for_events( self, room_id: str, event_ids: Collection[str], await_full_state: bool = True diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index c7eb88d33f..e281aef779 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -710,7 +710,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(44, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -723,7 +723,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(50, channel.resource_usage.db_txn_count) + self.assertEqual(38, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From 8ab16a92edd675453c78cfd9974081e374b0f998 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 28 Sep 2022 03:11:48 -0700 Subject: Persist CreateRoom events to DB in a batch (#13800) --- changelog.d/13800.misc | 1 + synapse/handlers/message.py | 663 +++++++++++++++++--------------- synapse/handlers/room.py | 21 +- synapse/handlers/room_batch.py | 3 +- synapse/handlers/room_member.py | 11 +- synapse/replication/http/__init__.py | 2 + synapse/replication/http/send_event.py | 4 +- synapse/replication/http/send_events.py | 171 ++++++++ tests/handlers/test_message.py | 10 +- tests/handlers/test_register.py | 4 +- tests/storage/test_event_chain.py | 8 +- tests/unittest.py | 4 +- 12 files changed, 563 insertions(+), 339 deletions(-) create mode 100644 changelog.d/13800.misc create mode 100644 synapse/replication/http/send_events.py (limited to 'synapse') diff --git a/changelog.d/13800.misc b/changelog.d/13800.misc new file mode 100644 index 0000000000..761adc8b05 --- /dev/null +++ b/changelog.d/13800.misc @@ -0,0 +1 @@ +Speed up creation of DM rooms. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 062f93bc67..00e7645ba5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -56,11 +56,13 @@ from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import ( MutableStateMap, + PersistedEventPosition, Requester, RoomAlias, StateMap, @@ -493,6 +495,7 @@ class EventCreationHandler: self.membership_types_to_include_profile_data_in.add(Membership.INVITE) self.send_event = ReplicationSendEventRestServlet.make_client(hs) + self.send_events = ReplicationSendEventsRestServlet.make_client(hs) self.request_ratelimiter = hs.get_request_ratelimiter() @@ -1016,8 +1019,7 @@ class EventCreationHandler: ev = await self.handle_new_client_event( requester=requester, - event=event, - context=context, + events_and_context=[(event, context)], ratelimit=ratelimit, ignore_shadow_ban=ignore_shadow_ban, ) @@ -1293,13 +1295,13 @@ class EventCreationHandler: async def handle_new_client_event( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ignore_shadow_ban: bool = False, ) -> EventBase: - """Processes a new event. + """Processes new events. Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. This includes deduplicating, checking auth, persisting, notifying users, sending to remote servers, etc. @@ -1309,8 +1311,7 @@ class EventCreationHandler: Args: requester - event - context + events_and_context: A list of one or more tuples of event, context to be persisted ratelimit extra_users: Any extra users to notify about event @@ -1328,62 +1329,63 @@ class EventCreationHandler: """ extra_users = extra_users or [] - # we don't apply shadow-banning to membership events here. Invites are blocked - # higher up the stack, and we allow shadow-banned users to send join and leave - # events as normal. - if ( - event.type != EventTypes.Member - and not ignore_shadow_ban - and requester.shadow_banned - ): - # We randomly sleep a bit just to annoy the requester. - await self.clock.sleep(random.randint(1, 10)) - raise ShadowBanError() + for event, context in events_and_context: + # we don't apply shadow-banning to membership events here. Invites are blocked + # higher up the stack, and we allow shadow-banned users to send join and leave + # events as normal. + if ( + event.type != EventTypes.Member + and not ignore_shadow_ban + and requester.shadow_banned + ): + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() - if event.is_state(): - prev_event = await self.deduplicate_state_event(event, context) - if prev_event is not None: - logger.info( - "Not bothering to persist state event %s duplicated by %s", - event.event_id, - prev_event.event_id, - ) - return prev_event + if event.is_state(): + prev_event = await self.deduplicate_state_event(event, context) + if prev_event is not None: + logger.info( + "Not bothering to persist state event %s duplicated by %s", + event.event_id, + prev_event.event_id, + ) + return prev_event - if event.internal_metadata.is_out_of_band_membership(): - # the only sort of out-of-band-membership events we expect to see here are - # invite rejections and rescinded knocks that we have generated ourselves. - assert event.type == EventTypes.Member - assert event.content["membership"] == Membership.LEAVE - else: - try: - validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context( - event, context - ) - except AuthError as err: - logger.warning("Denying new event %r because %s", event, err) - raise err + if event.internal_metadata.is_out_of_band_membership(): + # the only sort of out-of-band-membership events we expect to see here are + # invite rejections and rescinded knocks that we have generated ourselves. + assert event.type == EventTypes.Member + assert event.content["membership"] == Membership.LEAVE + else: + try: + validate_event_for_room_version(event) + await self._event_auth_handler.check_auth_rules_from_context( + event, context + ) + except AuthError as err: + logger.warning("Denying new event %r because %s", event, err) + raise err - # Ensure that we can round trip before trying to persist in db - try: - dump = json_encoder.encode(event.content) - json_decoder.decode(dump) - except Exception: - logger.exception("Failed to encode content: %r", event.content) - raise + # Ensure that we can round trip before trying to persist in db + try: + dump = json_encoder.encode(event.content) + json_decoder.decode(dump) + except Exception: + logger.exception("Failed to encode content: %r", event.content) + raise # We now persist the event (and update the cache in parallel, since we # don't want to block on it). + event, context = events_and_context[0] try: result, _ = await make_deferred_yieldable( gather_results( ( run_in_background( - self._persist_event, + self._persist_events, requester=requester, - event=event, - context=context, + events_and_context=events_and_context, ratelimit=ratelimit, extra_users=extra_users, ), @@ -1407,45 +1409,47 @@ class EventCreationHandler: return result - async def _persist_event( + async def _persist_events( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ) -> EventBase: - """Actually persists the event. Should only be called by + """Actually persists new events. Should only be called by `handle_new_client_event`, and see its docstring for documentation of - the arguments. + the arguments. Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. """ - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + for event, context in events_and_context: + # Skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). + if not event.internal_metadata.is_historical(): + with opentracing.start_active_span("calculate_push_actions"): + await self._bulk_push_rule_evaluator.action_for_event_by_user( + event, context + ) try: # If we're a worker we need to hit out to the master. - writer_instance = self._events_shard_config.get_instance(event.room_id) + first_event, _ = events_and_context[0] + writer_instance = self._events_shard_config.get_instance( + first_event.room_id + ) if writer_instance != self._instance_name: try: - result = await self.send_event( + result = await self.send_events( instance_name=writer_instance, - event_id=event.event_id, + events_and_context=events_and_context, store=self.store, requester=requester, - event=event, - context=context, ratelimit=ratelimit, extra_users=extra_users, ) @@ -1455,6 +1459,11 @@ class EventCreationHandler: raise stream_id = result["stream_id"] event_id = result["event_id"] + + # If we batch persisted events we return the last persisted event, otherwise + # we return the one event that was persisted + event, _ = events_and_context[-1] + if event_id != event.event_id: # If we get a different event back then it means that its # been de-duplicated, so we replace the given event with the @@ -1467,15 +1476,19 @@ class EventCreationHandler: event.internal_metadata.stream_ordering = stream_id return event - event = await self.persist_and_notify_client_event( - requester, event, context, ratelimit=ratelimit, extra_users=extra_users + event = await self.persist_and_notify_client_events( + requester, + events_and_context, + ratelimit=ratelimit, + extra_users=extra_users, ) return event except Exception: - # Ensure that we actually remove the entries in the push actions - # staging area, if we calculated them. - await self.store.remove_push_actions_from_staging(event.event_id) + for event, _ in events_and_context: + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + await self.store.remove_push_actions_from_staging(event.event_id) raise async def cache_joined_hosts_for_event( @@ -1569,23 +1582,26 @@ class EventCreationHandler: Codes.BAD_ALIAS, ) - async def persist_and_notify_client_event( + async def persist_and_notify_client_events( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ) -> EventBase: - """Called when we have fully built the event, have already - calculated the push actions for the event, and checked auth. + """Called when we have fully built the events, have already + calculated the push actions for the events, and checked auth. This should only be run on the instance in charge of persisting events. + Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. + Returns: - The persisted event. This may be different than the given event if - it was de-duplicated (e.g. because we had already persisted an - event with the same transaction ID.) + The persisted event, if one event is passed in, or the last event in the + list in the case of batch persisting. If only one event was persisted, the + returned event may be different than the given event if it was de-duplicated + (e.g. because we had already persisted an event with the same transaction ID.) Raises: PartialStateConflictError: if attempting to persist a partial state event in @@ -1593,277 +1609,297 @@ class EventCreationHandler: """ extra_users = extra_users or [] - assert self._storage_controllers.persistence is not None - assert self._events_shard_config.should_handle( - self._instance_name, event.room_id - ) + for event, context in events_and_context: + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) - if ratelimit: - # We check if this is a room admin redacting an event so that we - # can apply different ratelimiting. We do this by simply checking - # it's not a self-redaction (to avoid having to look up whether the - # user is actually admin or not). - is_admin_redaction = False - if event.type == EventTypes.Redaction: - assert event.redacts is not None + if ratelimit: + # We check if this is a room admin redacting an event so that we + # can apply different ratelimiting. We do this by simply checking + # it's not a self-redaction (to avoid having to look up whether the + # user is actually admin or not). + is_admin_redaction = False + if event.type == EventTypes.Redaction: + assert event.redacts is not None + + original_event = await self.store.get_event( + event.redacts, + redact_behaviour=EventRedactBehaviour.as_is, + get_prev_content=False, + allow_rejected=False, + allow_none=True, + ) - original_event = await self.store.get_event( - event.redacts, - redact_behaviour=EventRedactBehaviour.as_is, - get_prev_content=False, - allow_rejected=False, - allow_none=True, + is_admin_redaction = bool( + original_event and event.sender != original_event.sender + ) + + await self.request_ratelimiter.ratelimit( + requester, is_admin_redaction=is_admin_redaction ) - is_admin_redaction = bool( - original_event and event.sender != original_event.sender + # run checks/actions on event based on type + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + event.state_key, event.room_id ) + if current_membership != Membership.JOIN: + self._notifier.notify_user_joined_room( + event.event_id, event.room_id + ) - await self.request_ratelimiter.ratelimit( - requester, is_admin_redaction=is_admin_redaction - ) + await self._maybe_kick_guest_users(event, context) - if event.type == EventTypes.Member and event.membership == Membership.JOIN: - ( - current_membership, - _, - ) = await self.store.get_local_current_membership_for_user_in_room( - event.state_key, event.room_id - ) - if current_membership != Membership.JOIN: - self._notifier.notify_user_joined_room(event.event_id, event.room_id) + if event.type == EventTypes.CanonicalAlias: + # Validate a newly added alias or newly added alt_aliases. - await self._maybe_kick_guest_users(event, context) + original_alias = None + original_alt_aliases: object = [] - if event.type == EventTypes.CanonicalAlias: - # Validate a newly added alias or newly added alt_aliases. + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_alias_event = await self.store.get_event(original_event_id) - original_alias = None - original_alt_aliases: object = [] + if original_alias_event: + original_alias = original_alias_event.content.get("alias", None) + original_alt_aliases = original_alias_event.content.get( + "alt_aliases", [] + ) - original_event_id = event.unsigned.get("replaces_state") - if original_event_id: - original_event = await self.store.get_event(original_event_id) + # Check the alias is currently valid (if it has changed). + room_alias_str = event.content.get("alias", None) + directory_handler = self.hs.get_directory_handler() + if room_alias_str and room_alias_str != original_alias: + await self._validate_canonical_alias( + directory_handler, room_alias_str, event.room_id + ) - if original_event: - original_alias = original_event.content.get("alias", None) - original_alt_aliases = original_event.content.get("alt_aliases", []) - - # Check the alias is currently valid (if it has changed). - room_alias_str = event.content.get("alias", None) - directory_handler = self.hs.get_directory_handler() - if room_alias_str and room_alias_str != original_alias: - await self._validate_canonical_alias( - directory_handler, room_alias_str, event.room_id - ) + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, + "The alt_aliases property must be a list.", + Codes.INVALID_PARAM, + ) - # Check that alt_aliases is the proper form. - alt_aliases = event.content.get("alt_aliases", []) - if not isinstance(alt_aliases, (list, tuple)): - raise SynapseError( - 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM - ) + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + # TODO: check that the original_alt_aliases' entries are all strings + original_alt_aliases = [] + + # Check that each alias is currently valid. + new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + await self._validate_canonical_alias( + directory_handler, alias_str, event.room_id + ) - # If the old version of alt_aliases is of an unknown form, - # completely replace it. - if not isinstance(original_alt_aliases, (list, tuple)): - # TODO: check that the original_alt_aliases' entries are all strings - original_alt_aliases = [] + federation_handler = self.hs.get_federation_handler() - # Check that each alias is currently valid. - new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) - if new_alt_aliases: - for alias_str in new_alt_aliases: - await self._validate_canonical_alias( - directory_handler, alias_str, event.room_id + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.INVITE: + event.unsigned[ + "invite_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + membership_user_id=event.sender, ) - federation_handler = self.hs.get_federation_handler() + invitee = UserID.from_string(event.state_key) + if not self.hs.is_mine(invitee): + # TODO: Can we add signature from remote server in a nicer + # way? If we have been invited by a remote server, we need + # to get them to sign the event. - if event.type == EventTypes.Member: - if event.content["membership"] == Membership.INVITE: - event.unsigned[ - "invite_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, - membership_user_id=event.sender, - ) + returned_invite = await federation_handler.send_invite( + invitee.domain, event + ) + event.unsigned.pop("room_state", None) - invitee = UserID.from_string(event.state_key) - if not self.hs.is_mine(invitee): - # TODO: Can we add signature from remote server in a nicer - # way? If we have been invited by a remote server, we need - # to get them to sign the event. + # TODO: Make sure the signatures actually are correct. + event.signatures.update(returned_invite.signatures) - returned_invite = await federation_handler.send_invite( - invitee.domain, event + if event.content["membership"] == Membership.KNOCK: + event.unsigned[ + "knock_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, ) - event.unsigned.pop("room_state", None) - # TODO: Make sure the signatures actually are correct. - event.signatures.update(returned_invite.signatures) + if event.type == EventTypes.Redaction: + assert event.redacts is not None - if event.content["membership"] == Membership.KNOCK: - event.unsigned[ - "knock_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, + original_event = await self.store.get_event( + event.redacts, + redact_behaviour=EventRedactBehaviour.as_is, + get_prev_content=False, + allow_rejected=False, + allow_none=True, ) - if event.type == EventTypes.Redaction: - assert event.redacts is not None + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - original_event = await self.store.get_event( - event.redacts, - redact_behaviour=EventRedactBehaviour.as_is, - get_prev_content=False, - allow_rejected=False, - allow_none=True, - ) + # we can make some additional checks now if we have the original event. + if original_event: + if original_event.type == EventTypes.Create: + raise AuthError(403, "Redacting create events is not permitted") - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # we can make some additional checks now if we have the original event. - if original_event: - if original_event.type == EventTypes.Create: - raise AuthError(403, "Redacting create events is not permitted") - - if original_event.room_id != event.room_id: - raise SynapseError(400, "Cannot redact event from a different room") - - if original_event.type == EventTypes.ServerACL: - raise AuthError(403, "Redacting server ACL events is not permitted") - - # Add a little safety stop-gap to prevent people from trying to - # redact MSC2716 related events when they're in a room version - # which does not support it yet. We allow people to use MSC2716 - # events in existing room versions but only from the room - # creator since it does not require any changes to the auth - # rules and in effect, the redaction algorithm . In the - # supported room version, we add the `historical` power level to - # auth the MSC2716 related events and adjust the redaction - # algorthim to keep the `historical` field around (redacting an - # event should only strip fields which don't affect the - # structural protocol level). - is_msc2716_event = ( - original_event.type == EventTypes.MSC2716_INSERTION - or original_event.type == EventTypes.MSC2716_BATCH - or original_event.type == EventTypes.MSC2716_MARKER - ) - if not room_version_obj.msc2716_historical and is_msc2716_event: - raise AuthError( - 403, - "Redacting MSC2716 events is not supported in this room version", - ) + if original_event.room_id != event.room_id: + raise SynapseError( + 400, "Cannot redact event from a different room" + ) - event_types = event_auth.auth_types_for_event(event.room_version, event) - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types(event_types) - ) + if original_event.type == EventTypes.ServerACL: + raise AuthError( + 403, "Redacting server ACL events is not permitted" + ) - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_map = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} + # Add a little safety stop-gap to prevent people from trying to + # redact MSC2716 related events when they're in a room version + # which does not support it yet. We allow people to use MSC2716 + # events in existing room versions but only from the room + # creator since it does not require any changes to the auth + # rules and in effect, the redaction algorithm . In the + # supported room version, we add the `historical` power level to + # auth the MSC2716 related events and adjust the redaction + # algorthim to keep the `historical` field around (redacting an + # event should only strip fields which don't affect the + # structural protocol level). + is_msc2716_event = ( + original_event.type == EventTypes.MSC2716_INSERTION + or original_event.type == EventTypes.MSC2716_BATCH + or original_event.type == EventTypes.MSC2716_MARKER + ) + if not room_version_obj.msc2716_historical and is_msc2716_event: + raise AuthError( + 403, + "Redacting MSC2716 events is not supported in this room version", + ) - if event_auth.check_redaction( - room_version_obj, event, auth_events=auth_events - ): - # this user doesn't have 'redact' rights, so we need to do some more - # checks on the original event. Let's start by checking the original - # event exists. - if not original_event: - raise NotFoundError("Could not find event %s" % (event.redacts,)) - - if event.user_id != original_event.user_id: - raise AuthError(403, "You don't have permission to redact events") - - # all the checks are done. - event.internal_metadata.recheck_redaction = False - - if event.type == EventTypes.Create: - prev_state_ids = await context.get_prev_state_ids() - if prev_state_ids: - raise AuthError(403, "Changing the room create event is forbidden") - - if event.type == EventTypes.MSC2716_INSERTION: - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - create_event = await self.store.get_create_event_for_room(event.room_id) - room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) - - # Only check an insertion event if the room version - # supports it or the event is from the room creator. - if room_version_obj.msc2716_historical or ( - self.config.experimental.msc2716_enabled - and event.sender == room_creator - ): - next_batch_id = event.content.get( - EventContentFields.MSC2716_NEXT_BATCH_ID + event_types = event_auth.auth_types_for_event(event.room_version, event) + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types(event_types) ) - conflicting_insertion_event_id = None - if next_batch_id: - conflicting_insertion_event_id = ( - await self.store.get_insertion_event_id_by_batch_id( - event.room_id, next_batch_id + + auth_events_ids = self._event_auth_handler.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events_map.values() + } + + if event_auth.check_redaction( + room_version_obj, event, auth_events=auth_events + ): + # this user doesn't have 'redact' rights, so we need to do some more + # checks on the original event. Let's start by checking the original + # event exists. + if not original_event: + raise NotFoundError( + "Could not find event %s" % (event.redacts,) ) + + if event.user_id != original_event.user_id: + raise AuthError( + 403, "You don't have permission to redact events" + ) + + # all the checks are done. + event.internal_metadata.recheck_redaction = False + + if event.type == EventTypes.Create: + prev_state_ids = await context.get_prev_state_ids() + if prev_state_ids: + raise AuthError(403, "Changing the room create event is forbidden") + + if event.type == EventTypes.MSC2716_INSERTION: + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + create_event = await self.store.get_create_event_for_room(event.room_id) + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + + # Only check an insertion event if the room version + # supports it or the event is from the room creator. + if room_version_obj.msc2716_historical or ( + self.config.experimental.msc2716_enabled + and event.sender == room_creator + ): + next_batch_id = event.content.get( + EventContentFields.MSC2716_NEXT_BATCH_ID ) - if conflicting_insertion_event_id is not None: - # The current insertion event that we're processing is invalid - # because an insertion event already exists in the room with the - # same next_batch_id. We can't allow multiple because the batch - # pointing will get weird, e.g. we can't determine which insertion - # event the batch event is pointing to. - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Another insertion event already exists with the same next_batch_id", - errcode=Codes.INVALID_PARAM, - ) + conflicting_insertion_event_id = None + if next_batch_id: + conflicting_insertion_event_id = ( + await self.store.get_insertion_event_id_by_batch_id( + event.room_id, next_batch_id + ) + ) + if conflicting_insertion_event_id is not None: + # The current insertion event that we're processing is invalid + # because an insertion event already exists in the room with the + # same next_batch_id. We can't allow multiple because the batch + # pointing will get weird, e.g. we can't determine which insertion + # event the batch event is pointing to. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Another insertion event already exists with the same next_batch_id", + errcode=Codes.INVALID_PARAM, + ) - # Mark any `m.historical` messages as backfilled so they don't appear - # in `/sync` and have the proper decrementing `stream_ordering` as we import - backfilled = False - if event.internal_metadata.is_historical(): - backfilled = True + # Mark any `m.historical` messages as backfilled so they don't appear + # in `/sync` and have the proper decrementing `stream_ordering` as we import + backfilled = False + if event.internal_metadata.is_historical(): + backfilled = True - # Note that this returns the event that was persisted, which may not be - # the same as we passed in if it was deduplicated due transaction IDs. + assert self._storage_controllers.persistence is not None ( - event, - event_pos, + persisted_events, max_stream_token, - ) = await self._storage_controllers.persistence.persist_event( - event, context=context, backfilled=backfilled + ) = await self._storage_controllers.persistence.persist_events( + events_and_context, backfilled=backfilled ) - if self._ephemeral_events_enabled: - # If there's an expiry timestamp on the event, schedule its expiry. - self._message_handler.maybe_schedule_expiry(event) + for event in persisted_events: + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) - async def _notify() -> None: - try: - await self.notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users - ) - except Exception: - logger.exception( - "Error notifying about new room event %s", - event.event_id, - ) + stream_ordering = event.internal_metadata.stream_ordering + assert stream_ordering is not None + pos = PersistedEventPosition(self._instance_name, stream_ordering) + + async def _notify() -> None: + try: + await self.notifier.on_new_room_event( + event, pos, max_stream_token, extra_users=extra_users + ) + except Exception: + logger.exception( + "Error notifying about new room event %s", + event.event_id, + ) - run_in_background(_notify) + run_in_background(_notify) - if event.type == EventTypes.Message: - # We don't want to block sending messages on any presence code. This - # matters as sometimes presence code can take a while. - run_in_background(self._bump_active_time, requester.user) + if event.type == EventTypes.Message: + # We don't want to block sending messages on any presence code. This + # matters as sometimes presence code can take a while. + run_in_background(self._bump_active_time, requester.user) - return event + return persisted_events[-1] async def _maybe_kick_guest_users( self, event: EventBase, context: EventContext @@ -1952,8 +1988,7 @@ class EventCreationHandler: # shadow-banned user. await self.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], ratelimit=False, ignore_shadow_ban=True, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 09a1a82e6c..b220238e55 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -301,8 +301,7 @@ class RoomCreationHandler: # now send the tombstone await self.event_creation_handler.handle_new_client_event( requester=requester, - event=tombstone_event, - context=tombstone_context, + events_and_context=[(tombstone_event, tombstone_context)], ) state_filter = StateFilter.from_types( @@ -1057,8 +1056,10 @@ class RoomCreationHandler: creator_id = creator.user.to_string() event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 + # the last event sent/persisted to the db last_sent_event_id: Optional[str] = None + # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1112,8 +1113,7 @@ class RoomCreationHandler: ev = await self.event_creation_handler.handle_new_client_event( requester=creator, - event=event, - context=context, + events_and_context=[(event, context)], ratelimit=False, ignore_shadow_ban=True, ) @@ -1152,7 +1152,6 @@ class RoomCreationHandler: prev_event_ids=[last_sent_event_id], depth=depth, ) - last_sent_event_id = member_event_id prev_event = [member_event_id] # update the depth and state map here as the membership event has been created @@ -1168,7 +1167,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - last_sent_stream_id = await send(power_event, power_context, creator) + await send(power_event, power_context, creator) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1217,7 +1216,7 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - last_sent_stream_id = await send(pl_event, pl_context, creator) + await send(pl_event, pl_context, creator) events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: @@ -1271,9 +1270,11 @@ class RoomCreationHandler: ) events_to_send.append((encryption_event, encryption_context)) - for event, context in events_to_send: - last_sent_stream_id = await send(event, context, creator) - return last_sent_stream_id, last_sent_event_id, depth + last_event = await self.event_creation_handler.handle_new_client_event( + creator, events_to_send, ignore_shadow_ban=True + ) + assert last_event.internal_metadata.stream_ordering is not None + return last_event.internal_metadata.stream_ordering, last_event.event_id, depth def _generate_room_id(self) -> str: """Generates a random room ID. diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 1414e575d6..411a6fb22f 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -379,8 +379,7 @@ class RoomBatchHandler: await self.create_requester_for_user_id_from_app_service( event.sender, app_service_requester.app_service ), - event=event, - context=context, + events_and_context=[(event, context)], ) return event_ids diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8d01f4bf2b..88158822e0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -432,8 +432,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): with opentracing.start_active_span("handle_new_client_event"): result_event = await self.event_creation_handler.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], extra_users=[target], ratelimit=ratelimit, ) @@ -1252,7 +1251,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target_user], ratelimit=ratelimit + requester, + events_and_context=[(event, context)], + extra_users=[target_user], + ratelimit=ratelimit, ) prev_member_event_id = prev_state_ids.get( @@ -1860,8 +1862,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): result_event = await self.event_creation_handler.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], extra_users=[UserID.from_string(target_user)], ) # we know it was persisted, so must have a stream ordering diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 53aa7fa4c6..ac9a92240a 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -25,6 +25,7 @@ from synapse.replication.http import ( push, register, send_event, + send_events, state, streams, ) @@ -43,6 +44,7 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs: "HomeServer") -> None: send_event.register_servlets(hs, self) + send_events.register_servlets(hs, self) federation.register_servlets(hs, self) presence.register_servlets(hs, self) membership.register_servlets(hs, self) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 486f04723c..4215a1c1bc 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -141,8 +141,8 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - event = await self.event_creation_handler.persist_and_notify_client_event( - requester, event, context, ratelimit=ratelimit, extra_users=extra_users + event = await self.event_creation_handler.persist_and_notify_client_events( + requester, [(event, context)], ratelimit=ratelimit, extra_users=extra_users ) return ( diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py new file mode 100644 index 0000000000..8889bbb644 --- /dev/null +++ b/synapse/replication/http/send_events.py @@ -0,0 +1,171 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, List, Tuple + +from twisted.web.server import Request + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext +from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict, Requester, UserID +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore + +logger = logging.getLogger(__name__) + + +class ReplicationSendEventsRestServlet(ReplicationEndpoint): + """Handles batches of newly created events on workers, including persisting and + notifying. + + The API looks like: + + POST /_synapse/replication/send_events/:txn_id + + { + "events": [{ + "event": { .. serialized event .. }, + "room_version": .., // "1", "2", "3", etc: the version of the room + // containing the event + "event_format_version": .., // 1,2,3 etc: the event format version + "internal_metadata": { .. serialized internal_metadata .. }, + "outlier": true|false, + "rejected_reason": .., // The event.rejected_reason field + "context": { .. serialized event context .. }, + "requester": { .. serialized requester .. }, + "ratelimit": true, + }] + } + + 200 OK + + { "stream_id": 12345, "event_id": "$abcdef..." } + + Responds with a 409 when a `PartialStateConflictError` is raised due to an event + context that needs to be recomputed due to the un-partial stating of a room. + + """ + + NAME = "send_events" + PATH_ARGS = () + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.event_creation_handler = hs.get_event_creation_handler() + self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + events_and_context: List[Tuple[EventBase, EventContext]], + store: "DataStore", + requester: Requester, + ratelimit: bool, + extra_users: List[UserID], + ) -> JsonDict: + """ + Args: + store + requester + events_and_ctx + ratelimit + """ + serialized_events = [] + + for event, context in events_and_context: + serialized_context = await context.serialize(event, store) + serialized_event = { + "event": event.get_pdu_json(), + "room_version": event.room_version.identifier, + "event_format_version": event.format_version, + "internal_metadata": event.internal_metadata.get_dict(), + "outlier": event.internal_metadata.is_outlier(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + "requester": requester.serialize(), + "ratelimit": ratelimit, + "extra_users": [u.to_string() for u in extra_users], + } + serialized_events.append(serialized_event) + + payload = {"events": serialized_events} + + return payload + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + with Measure(self.clock, "repl_send_events_parse"): + payload = parse_json_object_from_request(request) + events_and_context = [] + events = payload["events"] + + for event_payload in events: + event_dict = event_payload["event"] + room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]] + internal_metadata = event_payload["internal_metadata"] + rejected_reason = event_payload["rejected_reason"] + + event = make_event_from_dict( + event_dict, room_ver, internal_metadata, rejected_reason + ) + event.internal_metadata.outlier = event_payload["outlier"] + + requester = Requester.deserialize( + self.store, event_payload["requester"] + ) + context = EventContext.deserialize( + self._storage_controllers, event_payload["context"] + ) + + ratelimit = event_payload["ratelimit"] + events_and_context.append((event, context)) + + extra_users = [ + UserID.from_string(u) for u in event_payload["extra_users"] + ] + + logger.info( + "Got batch of events to send, last ID of batch is: %s, sending into room: %s", + event.event_id, + event.room_id, + ) + + last_event = ( + await self.event_creation_handler.persist_and_notify_client_events( + requester, events_and_context, ratelimit, extra_users + ) + ) + + return ( + 200, + { + "stream_id": last_event.internal_metadata.stream_ordering, + "event_id": last_event.event_id, + }, + ) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + ReplicationSendEventsRestServlet(hs).register(http_server) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 986b50ce0c..99384837d0 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -105,7 +105,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): event1, context = self._create_duplicate_event(txn_id) ret_event1 = self.get_success( - self.handler.handle_new_client_event(self.requester, event1, context) + self.handler.handle_new_client_event( + self.requester, + events_and_context=[(event1, context)], + ) ) stream_id1 = ret_event1.internal_metadata.stream_ordering @@ -118,7 +121,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event2.event_id) ret_event2 = self.get_success( - self.handler.handle_new_client_event(self.requester, event2, context) + self.handler.handle_new_client_event( + self.requester, + events_and_context=[(event2, context)], + ) ) stream_id2 = ret_event2.internal_metadata.stream_ordering diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 86b3d51975..765df75d91 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -497,7 +497,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - event_creation_handler.handle_new_client_event(requester, event, context) + event_creation_handler.handle_new_client_event( + requester, events_and_context=[(event, context)] + ) ) # Register a second user, which won't be be in the room (or even have an invite) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index a0ce077a99..de9f4af2de 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -531,7 +531,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) self.get_success( - event_handler.handle_new_client_event(self.requester, event, context) + event_handler.handle_new_client_event( + self.requester, events_and_context=[(event, context)] + ) ) state1 = set(self.get_success(context.get_current_state_ids()).values()) @@ -549,7 +551,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) self.get_success( - event_handler.handle_new_client_event(self.requester, event, context) + event_handler.handle_new_client_event( + self.requester, events_and_context=[(event, context)] + ) ) state2 = set(self.get_success(context.get_current_state_ids()).values()) diff --git a/tests/unittest.py b/tests/unittest.py index 00cb023198..5116be338e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -734,7 +734,9 @@ class HomeserverTestCase(TestCase): event.internal_metadata.soft_failed = True self.get_success( - event_creator.handle_new_client_event(requester, event, context) + event_creator.handle_new_client_event( + requester, events_and_context=[(event, context)] + ) ) return event.event_id -- cgit 1.5.1 From 6caa3030835f879724c003a5b0dc66a6285451d8 Mon Sep 17 00:00:00 2001 From: Kateřina Churanová Date: Wed, 28 Sep 2022 14:31:53 +0200 Subject: fix: Push notifications for invite over federation (#13719) --- changelog.d/13719.bugfix | 1 + synapse/events/__init__.py | 4 ++++ synapse/handlers/federation.py | 13 ++++++++++--- synapse/handlers/federation_event.py | 1 + synapse/push/bulk_push_rule_evaluator.py | 10 +++++++--- synapse/push/push_rule_evaluator.py | 16 ++++++++-------- synapse/storage/controllers/persist_events.py | 10 ++++++---- synapse/storage/databases/main/events.py | 10 +++++----- 8 files changed, 42 insertions(+), 23 deletions(-) create mode 100644 changelog.d/13719.bugfix (limited to 'synapse') diff --git a/changelog.d/13719.bugfix b/changelog.d/13719.bugfix new file mode 100644 index 0000000000..4318f4daff --- /dev/null +++ b/changelog.d/13719.bugfix @@ -0,0 +1 @@ +Send invite push notifications for invite over federation. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index b2c9119fd0..030c3ca408 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -289,6 +289,10 @@ class _EventInternalMetadata: """ return self._dict.get("historical", False) + def is_notifiable(self) -> bool: + """Whether this event can trigger a push notification""" + return not self.is_outlier() or self.is_out_of_band_membership() + class EventBase(metaclass=abc.ABCMeta): @property diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 74580f60df..8f847ff845 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -149,6 +149,7 @@ class FederationHandler: self.http_client = hs.get_proxied_blacklisted_http_client() self._replication = hs.get_replication_data_handler() self._federation_event_handler = hs.get_federation_event_handler() + self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs @@ -956,9 +957,15 @@ class FederationHandler: ) context = EventContext.for_outlier(self._storage_controllers) - await self._federation_event_handler.persist_events_and_notify( - event.room_id, [(event, context)] - ) + + await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + try: + await self._federation_event_handler.persist_events_and_notify( + event.room_id, [(event, context)] + ) + except Exception: + await self.store.remove_push_actions_from_staging(event.event_id) + raise return event diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 2d7cde7506..3fac256881 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2170,6 +2170,7 @@ class FederationEventHandler: if instance != self._instance_name: # Limit the number of events sent over replication. We choose 200 # here as that is what we default to in `max_request_body_size(..)` + result = {} try: for batch in batch_iter(event_and_contexts, 200): result = await self._send_events( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 404379ef67..32313e3bcf 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -173,7 +173,11 @@ class BulkPushRuleEvaluator: async def _get_power_levels_and_sender_level( self, event: EventBase, context: EventContext - ) -> Tuple[dict, int]: + ) -> Tuple[dict, Optional[int]]: + # There are no power levels and sender levels possible to get from outlier + if event.internal_metadata.is_outlier(): + return {}, None + event_types = auth_types_for_event(event.room_version, event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types(event_types) @@ -250,8 +254,8 @@ class BulkPushRuleEvaluator: should increment the unread count, and insert the results into the event_push_actions_staging table. """ - if event.internal_metadata.is_outlier(): - # This can happen due to out of band memberships + if not event.internal_metadata.is_notifiable(): + # Push rules for events that aren't notifiable can't be processed by this return # Disable counting as unread unless the experimental configuration is diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 3c5632cd91..f8176c5a42 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -42,18 +42,18 @@ IS_GLOB = re.compile(r"[\?\*\[\]]") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") -def _room_member_count( - ev: EventBase, condition: Mapping[str, Any], room_member_count: int -) -> bool: +def _room_member_count(condition: Mapping[str, Any], room_member_count: int) -> bool: return _test_ineq_condition(condition, room_member_count) def _sender_notification_permission( - ev: EventBase, condition: Mapping[str, Any], - sender_power_level: int, + sender_power_level: Optional[int], power_levels: Dict[str, Union[int, Dict[str, int]]], ) -> bool: + if sender_power_level is None: + return False + notif_level_key = condition.get("key") if notif_level_key is None: return False @@ -129,7 +129,7 @@ class PushRuleEvaluatorForEvent: self, event: EventBase, room_member_count: int, - sender_power_level: int, + sender_power_level: Optional[int], power_levels: Dict[str, Union[int, Dict[str, int]]], relations: Dict[str, Set[Tuple[str, str]]], relations_match_enabled: bool, @@ -198,10 +198,10 @@ class PushRuleEvaluatorForEvent: elif condition["kind"] == "contains_display_name": return self._contains_display_name(display_name) elif condition["kind"] == "room_member_count": - return _room_member_count(self._event, condition, self._room_member_count) + return _room_member_count(condition, self._room_member_count) elif condition["kind"] == "sender_notification_permission": return _sender_notification_permission( - self._event, condition, self._sender_power_level, self._power_levels + condition, self._sender_power_level, self._power_levels ) elif ( condition["kind"] == "org.matrix.msc3772.relation_match" diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 709cb792ed..06e71a8053 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -423,16 +423,18 @@ class EventsPersistenceStorageController: for d in ret_vals: replaced_events.update(d) - events = [] + persisted_events = [] for event, _ in events_and_contexts: existing_event_id = replaced_events.get(event.event_id) if existing_event_id: - events.append(await self.main_store.get_event(existing_event_id)) + persisted_events.append( + await self.main_store.get_event(existing_event_id) + ) else: - events.append(event) + persisted_events.append(event) return ( - events, + persisted_events, self.main_store.get_room_max_token(), ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b59eb7478b..bb489b8189 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2134,13 +2134,13 @@ class PersistEventsStore: appear in events_and_context. """ - # Only non outlier events will have push actions associated with them, + # Only notifiable events will have push actions associated with them, # so let's filter them out. (This makes joining large rooms faster, as # these queries took seconds to process all the state events). - non_outlier_events = [ + notifiable_events = [ event for event, _ in events_and_contexts - if not event.internal_metadata.is_outlier() + if event.internal_metadata.is_notifiable() ] sql = """ @@ -2153,7 +2153,7 @@ class PersistEventsStore: WHERE event_id = ? """ - if non_outlier_events: + if notifiable_events: txn.execute_batch( sql, ( @@ -2163,7 +2163,7 @@ class PersistEventsStore: event.depth, event.event_id, ) - for event in non_outlier_events + for event in notifiable_events ), ) -- cgit 1.5.1 From 4b17a5ace846d82b09fccce79da77a8207a6765f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 28 Sep 2022 14:42:43 +0100 Subject: Handle remote device list updates during partial join (#13913) c.f. #12993 (comment), point 3 This stores all device list updates that we receive while partial joins are ongoing, and processes them once we have the full state. Note: We don't actually process the device lists in the same ways as if we weren't partially joined. Instead of updating the device list remote cache, we simply notify local users that a change in the remote user's devices has happened. I think this is safe as if the local user requests the keys for the remote user and we don't have them we'll simply fetch them as normal. --- changelog.d/13913.misc | 1 + synapse/handlers/device.py | 62 ++++++++++++++++++++++ synapse/handlers/federation.py | 4 ++ synapse/storage/databases/main/devices.py | 55 +++++++++++++++++++ synapse/storage/databases/main/room.py | 20 +++++++ .../delta/73/04pending_device_list_updates.sql | 28 ++++++++++ 6 files changed, 170 insertions(+) create mode 100644 changelog.d/13913.misc create mode 100644 synapse/storage/schema/main/delta/73/04pending_device_list_updates.sql (limited to 'synapse') diff --git a/changelog.d/13913.misc b/changelog.d/13913.misc new file mode 100644 index 0000000000..30b4401049 --- /dev/null +++ b/changelog.d/13913.misc @@ -0,0 +1 @@ +Faster remote room joins: correctly handle remote device list updates during a partial join. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index bad262731c..f2ef591103 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -309,6 +309,17 @@ class DeviceWorkerHandler: "self_signing_key": self_signing_key, } + async def handle_room_un_partial_stated(self, room_id: str) -> None: + """Handles sending appropriate device list updates in a room that has + gone from partial to full state. + """ + + # TODO(faster_joins): worker mode support + # https://github.com/matrix-org/synapse/issues/12994 + logger.error( + "Trying handling device list state for partial join: not supported on workers." + ) + class DeviceHandler(DeviceWorkerHandler): def __init__(self, hs: "HomeServer"): @@ -746,6 +757,15 @@ class DeviceHandler(DeviceWorkerHandler): finally: self._handle_new_device_update_is_processing = False + async def handle_room_un_partial_stated(self, room_id: str) -> None: + """Handles sending appropriate device list updates in a room that has + gone from partial to full state. + """ + + # We defer to the device list updater implementation as we're on the + # right worker. + await self.device_list_updater.handle_room_un_partial_stated(room_id) + def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] @@ -836,6 +856,16 @@ class DeviceListUpdater: ) return + # Check if we are partially joining any rooms. If so we need to store + # all device list updates so that we can handle them correctly once we + # know who is in the room. + partial_rooms = await self.store.get_partial_state_rooms_and_servers() + if partial_rooms: + await self.store.add_remote_device_list_to_pending( + user_id, + device_id, + ) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we @@ -1175,3 +1205,35 @@ class DeviceListUpdater: device_ids.append(verify_key.version) return device_ids + + async def handle_room_un_partial_stated(self, room_id: str) -> None: + """Handles sending appropriate device list updates in a room that has + gone from partial to full state. + """ + + pending_updates = ( + await self.store.get_pending_remote_device_list_updates_for_room(room_id) + ) + + for user_id, device_id in pending_updates: + logger.info( + "Got pending device list update in room %s: %s / %s", + room_id, + user_id, + device_id, + ) + position = await self.store.add_device_change_to_streams( + user_id, + [device_id], + room_ids=[room_id], + ) + + if not position: + # This should only happen if there are no updates, which + # shouldn't happen when we've passed in a non-empty set of + # device IDs. + continue + + self.device_handler.notifier.on_new_event( + StreamKeyType.DEVICE_LIST, position, rooms=[room_id] + ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8f847ff845..360ab6fee2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -149,6 +149,7 @@ class FederationHandler: self.http_client = hs.get_proxied_blacklisted_http_client() self._replication = hs.get_replication_data_handler() self._federation_event_handler = hs.get_federation_event_handler() + self._device_handler = hs.get_device_handler() self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( @@ -1631,6 +1632,9 @@ class FederationHandler: # https://github.com/matrix-org/synapse/issues/12994 await self.state_handler.update_current_state(room_id) + logger.info("Handling any pending device list updates") + await self._device_handler.handle_room_un_partial_stated(room_id) + logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) if success: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1151fb0cc3..1e562d4a40 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1995,3 +1995,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): add_device_list_outbound_pokes_txn, stream_ids, ) + + async def add_remote_device_list_to_pending( + self, user_id: str, device_id: str + ) -> None: + """Add a device list update to the table tracking remote device list + updates during partial joins. + """ + + async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + await self.db_pool.simple_upsert( + table="device_lists_remote_pending", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={"stream_id": stream_id}, + desc="add_remote_device_list_to_pending", + ) + + async def get_pending_remote_device_list_updates_for_room( + self, room_id: str + ) -> Collection[Tuple[str, str]]: + """Get the set of remote device list updates from the pending table for + the room. + """ + + min_device_stream_id = await self.db_pool.simple_select_one_onecol( + table="partial_state_rooms", + keyvalues={ + "room_id": room_id, + }, + retcol="device_lists_stream_id", + desc="get_pending_remote_device_list_updates_for_room_device", + ) + + sql = """ + SELECT user_id, device_id FROM device_lists_remote_pending AS d + INNER JOIN current_state_events AS c ON + type = 'm.room.member' + AND state_key = user_id + AND membership = 'join' + WHERE + room_id = ? AND stream_id > ? + """ + + def get_pending_remote_device_list_updates_for_room_txn( + txn: LoggingTransaction, + ) -> Collection[Tuple[str, str]]: + txn.execute(sql, (room_id, min_device_stream_id)) + return cast(Collection[Tuple[str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction( + "get_pending_remote_device_list_updates_for_room", + get_pending_remote_device_list_updates_for_room_txn, + ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 064c332fb7..672c9a03fc 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1217,6 +1217,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) + # We now delete anything from `device_lists_remote_pending` with a + # stream ID less than the minimum + # `partial_state_rooms.device_lists_stream_id`, as we no longer need them. + device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn( + txn, + table="partial_state_rooms", + keyvalues={}, + retcol="MIN(device_lists_stream_id)", + allow_none=True, + ) + if device_lists_stream_id is None: + # There are no rooms being currently partially joined, so we delete everything. + txn.execute("DELETE FROM device_lists_remote_pending") + else: + sql = """ + DELETE FROM device_lists_remote_pending + WHERE stream_id <= ? + """ + txn.execute(sql, (device_lists_stream_id,)) + @cached() async def is_partial_state_room(self, room_id: str) -> bool: """Checks if this room has partial state. diff --git a/synapse/storage/schema/main/delta/73/04pending_device_list_updates.sql b/synapse/storage/schema/main/delta/73/04pending_device_list_updates.sql new file mode 100644 index 0000000000..dbd78d677d --- /dev/null +++ b/synapse/storage/schema/main/delta/73/04pending_device_list_updates.sql @@ -0,0 +1,28 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stores remote device lists we have received for remote users while a partial +-- join is in progress. +-- +-- This allows us to replay any device list updates if it turns out the remote +-- user was in the partially joined room +CREATE TABLE device_lists_remote_pending( + stream_id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +-- We only keep the most recent update for a given user/device pair. +CREATE UNIQUE INDEX device_lists_remote_pending_user_device_id ON device_lists_remote_pending(user_id, device_id); -- cgit 1.5.1 From 7766bd5b354cd4ea1a33351ba320e54a14d3aeac Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 28 Sep 2022 10:58:25 -0400 Subject: Stop returning an unused column when handling new receipts. (#13933) --- changelog.d/13933.feature | 1 + synapse/storage/databases/main/event_push_actions.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13933.feature (limited to 'synapse') diff --git a/changelog.d/13933.feature b/changelog.d/13933.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13933.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index f4cdc2e399..3e4b4485d6 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -1053,7 +1053,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) sql = """ - SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering + SELECT r.room_id, r.user_id, e.stream_ordering FROM receipts_linearized AS r INNER JOIN events AS e USING (event_id) WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? @@ -1078,7 +1078,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # For each new read receipt we delete push actions from before it and # recalculate the summary. - for _, room_id, user_id, stream_ordering in rows: + for room_id, user_id, stream_ordering in rows: # Only handle our own read receipts. if not self.hs.is_mine_id(user_id): continue -- cgit 1.5.1 From 1386ce4735019ea6e918591509ee58a82c9c635c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 28 Sep 2022 11:01:41 -0400 Subject: Revert "Stop returning an unused column when handling new receipts. (#13933)" (#13935) This reverts commit 7766bd5b354cd4ea1a33351ba320e54a14d3aeac (#13933). The unused column is actually used, but much further down in the function. --- changelog.d/13933.feature | 1 - synapse/storage/databases/main/event_push_actions.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) delete mode 100644 changelog.d/13933.feature (limited to 'synapse') diff --git a/changelog.d/13933.feature b/changelog.d/13933.feature deleted file mode 100644 index d0cb902dff..0000000000 --- a/changelog.d/13933.feature +++ /dev/null @@ -1 +0,0 @@ -Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3e4b4485d6..f4cdc2e399 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -1053,7 +1053,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) sql = """ - SELECT r.room_id, r.user_id, e.stream_ordering + SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering FROM receipts_linearized AS r INNER JOIN events AS e USING (event_id) WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? @@ -1078,7 +1078,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # For each new read receipt we delete push actions from before it and # recalculate the summary. - for room_id, user_id, stream_ordering in rows: + for _, room_id, user_id, stream_ordering in rows: # Only handle our own read receipts. if not self.hs.is_mine_id(user_id): continue -- cgit 1.5.1 From df8b91ed2bba4995c59a5b067e3b252ab90c9a5e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 28 Sep 2022 15:26:16 -0500 Subject: Limit and filter the number of backfill points to get from the database (#13879) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There is no need to grab thousands of backfill points when we only need 5 to make the `/backfill` request with. We need to grab a few extra in case the first few aren't visible in the history. Previously, we grabbed thousands of backfill points from the database, then sorted and filtered them in the app. Fetching the 4.6k backfill points for `#matrix:matrix.org` from the database takes ~50ms - ~570ms so it's not like this saves a lot of time 🤷. But it might save us more time now that `get_backfill_points_in_room`/`get_insertion_event_backward_extremities_in_room` are more complicated after https://github.com/matrix-org/synapse/pull/13635 This PR moves the filtering and limiting to the SQL query so we just have less data to work with in the first place. Part of https://github.com/matrix-org/synapse/issues/13356 --- changelog.d/13879.misc | 1 + synapse/handlers/federation.py | 109 ++++++++++++--------- synapse/storage/databases/main/event_federation.py | 90 ++++++++++++++--- tests/storage/test_event_federation.py | 80 ++++++++++----- 4 files changed, 198 insertions(+), 82 deletions(-) create mode 100644 changelog.d/13879.misc (limited to 'synapse') diff --git a/changelog.d/13879.misc b/changelog.d/13879.misc new file mode 100644 index 0000000000..3cc2a2420f --- /dev/null +++ b/changelog.d/13879.misc @@ -0,0 +1 @@ +Only pull relevant backfill points from the database based on the current depth and limit (instead of all) every time we want to `/backfill`. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 360ab6fee2..500c1c16d0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -38,7 +38,7 @@ from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 from synapse import event_auth -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, CodeMessageException, @@ -211,7 +211,7 @@ class FederationHandler: current_depth: int, limit: int, *, - processing_start_time: int, + processing_start_time: Optional[int], ) -> bool: """ Checks whether the `current_depth` is at or approaching any backfill @@ -223,12 +223,23 @@ class FederationHandler: room_id: The room to backfill in. current_depth: The depth to check at for any upcoming backfill points. limit: The max number of events to request from the remote federated server. - processing_start_time: The time when `maybe_backfill` started - processing. Only used for timing. + processing_start_time: The time when `maybe_backfill` started processing. + Only used for timing. If `None`, no timing observation will be made. """ backwards_extremities = [ _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY) - for event_id, depth in await self.store.get_backfill_points_in_room(room_id) + for event_id, depth in await self.store.get_backfill_points_in_room( + room_id=room_id, + current_depth=current_depth, + # We only need to end up with 5 extremities combined with the + # insertion event extremities to make the `/backfill` request + # but fetch an order of magnitude more to make sure there is + # enough even after we filter them by whether visible in the + # history. This isn't fool-proof as all backfill points within + # our limit could be filtered out but seems like a good amount + # to try with at least. + limit=50, + ) ] insertion_events_to_be_backfilled: List[_BackfillPoint] = [] @@ -236,7 +247,12 @@ class FederationHandler: insertion_events_to_be_backfilled = [ _BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT) for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room( - room_id + room_id=room_id, + current_depth=current_depth, + # We only need to end up with 5 extremities combined with + # the backfill points to make the `/backfill` request ... + # (see the other comment above for more context). + limit=50, ) ] logger.debug( @@ -245,10 +261,6 @@ class FederationHandler: insertion_events_to_be_backfilled, ) - if not backwards_extremities and not insertion_events_to_be_backfilled: - logger.debug("Not backfilling as no extremeties found.") - return False - # we now have a list of potential places to backpaginate from. We prefer to # start with the most recent (ie, max depth), so let's sort the list. sorted_backfill_points: List[_BackfillPoint] = sorted( @@ -269,6 +281,33 @@ class FederationHandler: sorted_backfill_points, ) + # If we have no backfill points lower than the `current_depth` then + # either we can a) bail or b) still attempt to backfill. We opt to try + # backfilling anyway just in case we do get relevant events. + if not sorted_backfill_points and current_depth != MAX_DEPTH: + logger.debug( + "_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points." + ) + return await self._maybe_backfill_inner( + room_id=room_id, + # We use `MAX_DEPTH` so that we find all backfill points next + # time (all events are below the `MAX_DEPTH`) + current_depth=MAX_DEPTH, + limit=limit, + # We don't want to start another timing observation from this + # nested recursive call. The top-most call can record the time + # overall otherwise the smaller one will throw off the results. + processing_start_time=None, + ) + + # Even after recursing with `MAX_DEPTH`, we didn't find any + # backward extremities to backfill from. + if not sorted_backfill_points: + logger.debug( + "_maybe_backfill_inner: Not backfilling as no backward extremeties found." + ) + return False + # If we're approaching an extremity we trigger a backfill, otherwise we # no-op. # @@ -278,47 +317,16 @@ class FederationHandler: # chose more than one times the limit in case of failure, but choosing a # much larger factor will result in triggering a backfill request much # earlier than necessary. - # - # XXX: shouldn't we do this *after* the filter by depth below? Again, we don't - # care about events that have happened after our current position. - # - max_depth = sorted_backfill_points[0].depth - if current_depth - 2 * limit > max_depth: + max_depth_of_backfill_points = sorted_backfill_points[0].depth + if current_depth - 2 * limit > max_depth_of_backfill_points: logger.debug( "Not backfilling as we don't need to. %d < %d - 2 * %d", - max_depth, + max_depth_of_backfill_points, current_depth, limit, ) return False - # We ignore extremities that have a greater depth than our current depth - # as: - # 1. we don't really care about getting events that have happened - # after our current position; and - # 2. we have likely previously tried and failed to backfill from that - # extremity, so to avoid getting "stuck" requesting the same - # backfill repeatedly we drop those extremities. - # - # However, we need to check that the filtered extremities are non-empty. - # If they are empty then either we can a) bail or b) still attempt to - # backfill. We opt to try backfilling anyway just in case we do get - # relevant events. - # - filtered_sorted_backfill_points = [ - t for t in sorted_backfill_points if t.depth <= current_depth - ] - if filtered_sorted_backfill_points: - logger.debug( - "_maybe_backfill_inner: backfill points before current depth: %s", - filtered_sorted_backfill_points, - ) - sorted_backfill_points = filtered_sorted_backfill_points - else: - logger.debug( - "_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway." - ) - # For performance's sake, we only want to paginate from a particular extremity # if we can actually see the events we'll get. Otherwise, we'd just spend a lot # of resources to get redacted events. We check each extremity in turn and @@ -452,10 +460,15 @@ class FederationHandler: return False - processing_end_time = self.clock.time_msec() - backfill_processing_before_timer.observe( - (processing_end_time - processing_start_time) / 1000 - ) + # If we have the `processing_start_time`, then we can make an + # observation. We wouldn't have the `processing_start_time` in the case + # where `_maybe_backfill_inner` is recursively called to find any + # backfill points regardless of `current_depth`. + if processing_start_time is not None: + processing_end_time = self.clock.time_msec() + backfill_processing_before_timer.observe( + (processing_end_time - processing_start_time) / 1000 + ) success = await try_backfill(likely_domains) if success: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 3251fca6fb..17f2fd4458 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -726,17 +726,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas async def get_backfill_points_in_room( self, room_id: str, + current_depth: int, + limit: int, ) -> List[Tuple[str, int]]: """ - Gets the oldest events(backwards extremities) in the room along with the - approximate depth. Sorted by depth, highest to lowest (descending). + Get the backward extremities to backfill from in the room along with the + approximate depth. + + Only returns events that are at a depth lower than or + equal to the `current_depth`. Sorted by depth, highest to lowest (descending) + so the closest events to the `current_depth` are first in the list. + + We ignore extremities that are newer than the user's current scroll position + (ie, those with depth greater than `current_depth`) as: + 1. we don't really care about getting events that have happened + after our current position; and + 2. by the nature of paginating and scrolling back, we have likely + previously tried and failed to backfill from that extremity, so + to avoid getting "stuck" requesting the same backfill repeatedly + we drop those extremities. Args: room_id: Room where we want to find the oldest events + current_depth: The depth at the user's current scrollback position + limit: The max number of backfill points to return Returns: List of (event_id, depth) tuples. Sorted by depth, highest to lowest - (descending) + (descending) so the closest events to the `current_depth` are first + in the list. """ def get_backfill_points_in_room_txn( @@ -784,6 +802,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas * necessarily safe to assume that it will have been completed. */ AND edge.is_state is ? /* False */ + /** + * We only want backwards extremities that are older than or at + * the same position of the given `current_depth` (where older + * means less than the given depth) because we're looking backwards + * from the `current_depth` when backfilling. + * + * current_depth (ignore events that come after this, ignore 2-4) + * | + * ▼ + * [0]<--[1]<--[2]<--[3]<--[4] + */ + AND event.depth <= ? /* current_depth */ /** * Exponential back-off (up to the upper bound) so we don't retry the * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. @@ -798,11 +828,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) ) /** - * Sort from highest to the lowest depth. Then tie-break on - * alphabetical order of the event_ids so we get a consistent - * ordering which is nice when asserting things in tests. + * Sort from highest (closest to the `current_depth`) to the lowest depth + * because the closest are most relevant to backfill from first. + * Then tie-break on alphabetical order of the event_ids so we get a + * consistent ordering which is nice when asserting things in tests. */ ORDER BY event.depth DESC, backward_extrem.event_id DESC + LIMIT ? """ if isinstance(self.database_engine, PostgresEngine): @@ -817,9 +849,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ( room_id, False, + current_depth, self._clock.time_msec(), 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + limit, ), ) @@ -835,18 +869,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas async def get_insertion_event_backward_extremities_in_room( self, room_id: str, + current_depth: int, + limit: int, ) -> List[Tuple[str, int]]: """ Get the insertion events we know about that we haven't backfilled yet - along with the approximate depth. Sorted by depth, highest to lowest - (descending). + along with the approximate depth. Only returns insertion events that are + at a depth lower than or equal to the `current_depth`. Sorted by depth, + highest to lowest (descending) so the closest events to the + `current_depth` are first in the list. + + We ignore insertion events that are newer than the user's current scroll + position (ie, those with depth greater than `current_depth`) as: + 1. we don't really care about getting events that have happened + after our current position; and + 2. by the nature of paginating and scrolling back, we have likely + previously tried and failed to backfill from that insertion event, so + to avoid getting "stuck" requesting the same backfill repeatedly + we drop those insertion event. Args: room_id: Room where we want to find the oldest events + current_depth: The depth at the user's current scrollback position + limit: The max number of insertion event extremities to return Returns: List of (event_id, depth) tuples. Sorted by depth, highest to lowest - (descending) + (descending) so the closest events to the `current_depth` are first + in the list. """ def get_insertion_event_backward_extremities_in_room_txn( @@ -869,6 +919,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id WHERE insertion_event_extremity.room_id = ? + /** + * We only want extremities that are older than or at + * the same position of the given `current_depth` (where older + * means less than the given depth) because we're looking backwards + * from the `current_depth` when backfilling. + * + * current_depth (ignore events that come after this, ignore 2-4) + * | + * ▼ + * [0]<--[1]<--[2]<--[3]<--[4] + */ + AND event.depth <= ? /* current_depth */ /** * Exponential back-off (up to the upper bound) so we don't retry the * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc @@ -883,11 +945,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) ) /** - * Sort from highest to the lowest depth. Then tie-break on - * alphabetical order of the event_ids so we get a consistent - * ordering which is nice when asserting things in tests. + * Sort from highest (closest to the `current_depth`) to the lowest depth + * because the closest are most relevant to backfill from first. + * Then tie-break on alphabetical order of the event_ids so we get a + * consistent ordering which is nice when asserting things in tests. */ ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC + LIMIT ? """ if isinstance(self.database_engine, PostgresEngine): @@ -901,9 +965,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas sql % (least_function,), ( room_id, + current_depth, self._clock.time_msec(), 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + limit, ), ) return cast(List[Tuple[str, int]], txn.fetchall()) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 85739c464e..398f338b66 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -754,19 +754,31 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_backfill_points_in_room(self): """ - Test to make sure we get some backfill points + Test to make sure only backfill points that are older and come before + the `current_depth` are returned. """ setup_info = self._setup_room_for_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map + # Try at "B" backfill_points = self.get_success( - self.store.get_backfill_points_in_room(room_id) + self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertListEqual( backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] ) + # Try at "A" + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + # Event "2" has a depth of 2 but is not included here because we only + # know the approximate depth of 5 from our event "3". + self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"]) + def test_get_backfill_points_in_room_excludes_events_we_have_attempted( self, ): @@ -776,6 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): """ setup_info = self._setup_room_for_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map # Record some attempts to backfill these events which will make # `get_backfill_points_in_room` exclude them because we @@ -795,8 +808,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # No time has passed since we attempted to backfill ^ + # Try at "B" backfill_points = self.get_success( - self.store.get_backfill_points_in_room(room_id) + self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] # Only the backfill points that we didn't record earlier exist here. @@ -812,6 +826,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): """ setup_info = self._setup_room_for_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map # Record some attempts to backfill these events which will make # `get_backfill_points_in_room` exclude them because we @@ -839,26 +854,24 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # visible regardless. self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) - # Make sure that "b1" is not in the list because we've + # Try at "A" and make sure that "b1" is not in the list because we've # already attempted many times backfill_points = self.get_success( - self.store.get_backfill_points_in_room(room_id) + self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"]) + self.assertListEqual(backfill_event_ids, ["b3", "b2"]) # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and # see if we can now backfill it self.reactor.advance(datetime.timedelta(hours=20).total_seconds()) - # Try again after we advanced enough time and we should see "b3" again + # Try at "A" again after we advanced enough time and we should see "b3" again backfill_points = self.get_success( - self.store.get_backfill_points_in_room(room_id) + self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual( - backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] - ) + self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"]) def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo: """ @@ -938,19 +951,35 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_insertion_event_backward_extremities_in_room(self): """ - Test to make sure insertion event backward extremities are returned. + Test to make sure only insertion event backward extremities that are + older and come before the `current_depth` are returned. """ setup_info = self._setup_room_for_insertion_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map + # Try at "insertion_eventB" backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room(room_id) + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventB"], limit=100 + ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertListEqual( backfill_event_ids, ["insertion_eventB", "insertion_eventA"] ) + # Try at "insertion_eventA" + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventA"], limit=100 + ) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + # Event "2" has a depth of 2 but is not included here because we only + # know the approximate depth of 5 from our event "3". + self.assertListEqual(backfill_event_ids, ["insertion_eventA"]) + def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( self, ): @@ -961,6 +990,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): """ setup_info = self._setup_room_for_insertion_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map # Record some attempts to backfill these events which will make # `get_insertion_event_backward_extremities_in_room` exclude them @@ -973,8 +1003,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # No time has passed since we attempted to backfill ^ + # Try at "insertion_eventB" backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room(room_id) + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventB"], limit=100 + ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] # Only the backfill points that we didn't record earlier exist here. @@ -991,6 +1024,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): """ setup_info = self._setup_room_for_insertion_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map # Record some attempts to backfill these events which will make # `get_backfill_points_in_room` exclude them because we @@ -1027,13 +1061,15 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # because we haven't waited long enough for this many attempts. self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) - # Make sure that "insertion_eventA" is not in the list because we've - # already attempted many times + # Try at "insertion_eventA" and make sure that "insertion_eventA" is not + # in the list because we've already attempted many times backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room(room_id) + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventA"], limit=100 + ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["insertion_eventB"]) + self.assertListEqual(backfill_event_ids, []) # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and # see if we can now backfill it @@ -1042,12 +1078,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # Try at "insertion_eventA" again after we advanced enough time and we # should see "insertion_eventA" again backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room(room_id) + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventA"], limit=100 + ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual( - backfill_event_ids, ["insertion_eventB", "insertion_eventA"] - ) + self.assertListEqual(backfill_event_ids, ["insertion_eventA"]) @attr.s -- cgit 1.5.1 From 5f659d4a88e602ca8519984808dcf4df036c781b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 28 Sep 2022 23:22:35 +0100 Subject: Handle local device list updates during partial join (#13934) --- changelog.d/13934.misc | 1 + synapse/handlers/device.py | 84 ++++++++++++++++++++++++++++++- synapse/storage/databases/main/devices.py | 55 +++++++++++++++----- synapse/storage/databases/main/room.py | 16 ++++++ 4 files changed, 141 insertions(+), 15 deletions(-) create mode 100644 changelog.d/13934.misc (limited to 'synapse') diff --git a/changelog.d/13934.misc b/changelog.d/13934.misc new file mode 100644 index 0000000000..6610a9f567 --- /dev/null +++ b/changelog.d/13934.misc @@ -0,0 +1 @@ +Correctly handle sending local device list updates to remote servers during a partial join. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f2ef591103..03082fce42 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -762,10 +762,90 @@ class DeviceHandler(DeviceWorkerHandler): gone from partial to full state. """ - # We defer to the device list updater implementation as we're on the - # right worker. + # We defer to the device list updater to handle pending remote device + # list updates. await self.device_list_updater.handle_room_un_partial_stated(room_id) + # Replay local updates. + ( + join_event_id, + device_lists_stream_id, + ) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state( + room_id + ) + + # Get the local device list changes that have happened in the room since + # we started joining. If there are no updates there's nothing left to do. + changes = await self.store.get_device_list_changes_in_room( + room_id, device_lists_stream_id + ) + local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)} + if not local_changes: + return + + # Note: We have persisted the full state at this point, we just haven't + # cleared the `partial_room` flag. + join_state_ids = await self._state_storage.get_state_ids_for_event( + join_event_id, await_full_state=False + ) + current_state_ids = await self.store.get_partial_current_state_ids(room_id) + + # Now we need to work out all servers that might have been in the room + # at any point during our join. + + # First we look for any membership states that have changed between the + # initial join and now... + all_keys = set(join_state_ids) + all_keys.update(current_state_ids) + + potentially_changed_hosts = set() + for etype, state_key in all_keys: + if etype != EventTypes.Member: + continue + + prev = join_state_ids.get((etype, state_key)) + current = current_state_ids.get((etype, state_key)) + + if prev != current: + potentially_changed_hosts.add(get_domain_from_id(state_key)) + + # ... then we add all the hosts that are currently joined to the room... + current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id) + potentially_changed_hosts.update(current_hosts_in_room) + + # ... and finally we remove any hosts that we were told about, as we + # will have sent device list updates to those hosts when they happened. + known_hosts_at_join = await self.store.get_partial_state_servers_at_join( + room_id + ) + potentially_changed_hosts.difference_update(known_hosts_at_join) + + potentially_changed_hosts.discard(self.server_name) + + if not potentially_changed_hosts: + # Nothing to do. + return + + logger.info( + "Found %d changed hosts to send device list updates to", + len(potentially_changed_hosts), + ) + + for user_id, device_id in local_changes: + await self.store.add_device_list_outbound_pokes( + user_id=user_id, + device_id=device_id, + room_id=room_id, + stream_id=None, + hosts=potentially_changed_hosts, + context=None, + ) + + # Notify things that device lists need to be sent out. + self.notifier.notify_replication() + for host in potentially_changed_hosts: + self.federation_sender.send_device_messages(host, immediate=False) + def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1e562d4a40..18358eca46 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return changes + async def get_device_list_changes_in_room( + self, room_id: str, min_stream_id: int + ) -> Collection[Tuple[str, str]]: + """Get all device list changes that happened in the room since the given + stream ID. + + Returns: + Collection of user ID/device ID tuples of all devices that have + changed + """ + + sql = """ + SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room + WHERE room_id = ? AND stream_id > ? + """ + + def get_device_list_changes_in_room_txn( + txn: LoggingTransaction, + ) -> Collection[Tuple[str, str]]: + txn.execute(sql, (room_id, min_stream_id)) + return cast(Collection[Tuple[str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction( + "get_device_list_changes_in_room", + get_device_list_changes_in_room_txn, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__( @@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, room_id: str, - stream_id: int, + stream_id: Optional[int], hosts: Collection[str], context: Optional[Dict[str, str]], ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. - Marks the associated row in `device_lists_changes_in_room` as handled. + Marks the associated row in `device_lists_changes_in_room` as handled, + if `stream_id` is provided. """ def add_device_list_outbound_pokes_txn( @@ -1969,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context=context, ) - self.db_pool.simple_update_txn( - txn, - table="device_lists_changes_in_room", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "room_id": room_id, - }, - updatevalues={"converted_to_destinations": True}, - ) + if stream_id: + self.db_pool.simple_update_txn( + txn, + table="device_lists_changes_in_room", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "room_id": room_id, + }, + updatevalues={"converted_to_destinations": True}, + ) if not hosts: # If there are no hosts then we don't try and generate stream IDs. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 672c9a03fc..059eef5c22 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1256,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return entry is not None + async def get_join_event_id_and_device_lists_stream_id_for_partial_state( + self, room_id: str + ) -> Tuple[str, int]: + """Get the event ID of the initial join that started the partial + join, and the device list stream ID at the point we started the partial + join. + """ + + result = await self.db_pool.simple_select_one( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcols=("join_event_id", "device_lists_stream_id"), + desc="get_join_event_id_for_partial_state", + ) + return result["join_event_id"], result["device_lists_stream_id"] + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" -- cgit 1.5.1 From 73ecff7e9ed456c64368296858d17d4b393c9f9a Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 29 Sep 2022 10:00:02 +0000 Subject: Improve backfill robustness by trying more servers. (#13890) Co-authored-by: Eric Eastwood --- changelog.d/13890.misc | 1 + synapse/handlers/federation.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13890.misc (limited to 'synapse') diff --git a/changelog.d/13890.misc b/changelog.d/13890.misc new file mode 100644 index 0000000000..bf76cf7be7 --- /dev/null +++ b/changelog.d/13890.misc @@ -0,0 +1 @@ +Improve backfill robustness by trying more servers when we get a `4xx` error back. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 500c1c16d0..b866258298 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -417,6 +417,15 @@ class FederationHandler: async def try_backfill(domains: Collection[str]) -> bool: # TODO: Should we try multiple of these at a time? + + # Number of contacted remote homeservers that have denied our backfill + # request with a 4xx code. + denied_count = 0 + + # Maximum number of contacted remote homeservers that can deny our + # backfill request with 4xx codes before we give up. + max_denied_count = 5 + for dom in domains: # We don't want to ask our own server for information we don't have if dom == self.server_name: @@ -435,13 +444,33 @@ class FederationHandler: continue except HttpResponseException as e: if 400 <= e.code < 500: - raise e.to_synapse_error() + logger.warning( + "Backfill denied from %s because %s [%d/%d]", + dom, + e, + denied_count, + max_denied_count, + ) + denied_count += 1 + if denied_count >= max_denied_count: + return False + continue logger.info("Failed to backfill from %s because %s", dom, e) continue except CodeMessageException as e: if 400 <= e.code < 500: - raise + logger.warning( + "Backfill denied from %s because %s [%d/%d]", + dom, + e, + denied_count, + max_denied_count, + ) + denied_count += 1 + if denied_count >= max_denied_count: + return False + continue logger.info("Failed to backfill from %s because %s", dom, e) continue -- cgit 1.5.1 From 99a7e7e0230cba5d00ec204926edae89d4b6b8c3 Mon Sep 17 00:00:00 2001 From: Nicolas Werner <89468146+nico-famedly@users.noreply.github.com> Date: Thu, 29 Sep 2022 10:57:00 +0000 Subject: Always send default and rule_id to clients (#13904) --- changelog.d/13904.bugfix | 1 + synapse/push/clientformat.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) create mode 100644 changelog.d/13904.bugfix (limited to 'synapse') diff --git a/changelog.d/13904.bugfix b/changelog.d/13904.bugfix new file mode 100644 index 0000000000..397a3108ac --- /dev/null +++ b/changelog.d/13904.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.66 where some required fields in the pushrules sent to clients were not present anymore. Contributed by Nico. diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index ebc13beda1..7095ae83f9 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -102,10 +102,8 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]: # with PRIORITY_CLASS_INVERSE_MAP. raise ValueError("Unexpected template_name: %s" % (template_name,)) - if unscoped_rule_id: - templaterule["rule_id"] = unscoped_rule_id - if rule.default: - templaterule["default"] = True + templaterule["rule_id"] = unscoped_rule_id + templaterule["default"] = rule.default return templaterule -- cgit 1.5.1 From 568016929f3d22f632cb9145429fa45754a8d59f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 29 Sep 2022 07:07:31 -0400 Subject: Clarify that a method returns only unthreaded receipts. (#13937) By renaming it and updating the docstring. Additionally, refactors a method which is used only by tests. --- changelog.d/13937.feature | 1 + .../storage/databases/main/event_push_actions.py | 12 +--- synapse/storage/databases/main/receipts.py | 36 ++--------- tests/storage/test_receipts.py | 74 +++++++++++----------- 4 files changed, 47 insertions(+), 76 deletions(-) create mode 100644 changelog.d/13937.feature (limited to 'synapse') diff --git a/changelog.d/13937.feature b/changelog.d/13937.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13937.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index f4cdc2e399..7e0ffef7d3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -366,14 +366,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, ) -> NotifCounts: # Get the stream ordering of the user's latest receipt in the room. - result = self.get_last_receipt_for_user_txn( + result = self.get_last_unthreaded_receipt_for_user_txn( txn, user_id, room_id, - receipt_types=( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ), + receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) if result: @@ -574,10 +571,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ), + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) sql = f""" diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 52fe0db924..246f78ac1f 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -135,34 +135,7 @@ class ReceiptsWorkerStore(SQLBaseStore): """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() - async def get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_types: Collection[str] - ) -> Optional[str]: - """ - Fetch the event ID for the latest receipt in a room with one of the given receipt types. - - Args: - user_id: The user to fetch receipts for. - room_id: The room ID to fetch the receipt for. - receipt_type: The receipt types to fetch. - - Returns: - The latest receipt, if one exists. - """ - result = await self.db_pool.runInteraction( - "get_last_receipt_event_id_for_user", - self.get_last_receipt_for_user_txn, - user_id, - room_id, - receipt_types, - ) - if not result: - return None - - event_id, _ = result - return event_id - - def get_last_receipt_for_user_txn( + def get_last_unthreaded_receipt_for_user_txn( self, txn: LoggingTransaction, user_id: str, @@ -170,13 +143,13 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_types: Collection[str], ) -> Optional[Tuple[str, int]]: """ - Fetch the event ID and stream_ordering for the latest receipt in a room - with one of the given receipt types. + Fetch the event ID and stream_ordering for the latest unthreaded receipt + in a room with one of the given receipt types. Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. - receipt_type: The receipt types to fetch. + receipt_types: The receipt types to fetch. Returns: The event ID and stream ordering of the latest receipt, if one exists. @@ -193,6 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore): WHERE {clause} AND user_id = ? AND room_id = ? + AND thread_id IS NULL ORDER BY stream_ordering DESC LIMIT 1 """ diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 9459ee1705..81253d0361 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Collection, Optional from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -84,6 +85,33 @@ class ReceiptTestCase(HomeserverTestCase): ) ) + def get_last_unthreaded_receipt( + self, receipt_types: Collection[str], room_id: Optional[str] = None + ) -> Optional[str]: + """ + Fetch the event ID for the latest unthreaded receipt in the test room for the test user. + + Args: + receipt_types: The receipt types to fetch. + + Returns: + The latest receipt, if one exists. + """ + result = self.get_success( + self.store.db_pool.runInteraction( + "get_last_receipt_event_id_for_user", + self.store.get_last_unthreaded_receipt_for_user_txn, + OUR_USER_ID, + room_id or self.room_id1, + receipt_types, + ) + ) + if not result: + return None + + event_id, _ = result + return event_id + def test_return_empty_with_no_data(self) -> None: res = self.get_success( self.store.get_receipts_for_user( @@ -107,16 +135,10 @@ class ReceiptTestCase(HomeserverTestCase): ) self.assertEqual(res, {}) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id1, - [ - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) + self.assertEqual(res, None) def test_get_receipts_for_user(self) -> None: @@ -228,29 +250,17 @@ class ReceiptTestCase(HomeserverTestCase): ) # Test we get the latest event when we want both private and public receipts - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) self.assertEqual(res, event1_2_id) # Test we get the older event when we want only public receipt - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ]) self.assertEqual(res, event1_1_id) # Test we get the latest event when we want only the private receipt - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE]) self.assertEqual(res, event1_2_id) # Test receipt updating @@ -259,11 +269,7 @@ class ReceiptTestCase(HomeserverTestCase): self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ]) self.assertEqual(res, event1_2_id) # Send some events into the second room @@ -282,11 +288,7 @@ class ReceiptTestCase(HomeserverTestCase): {}, ) ) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id2, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2 ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From e5fdf16d4680b00ca8120ddb697bd14ab89fdf0c Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Thu, 29 Sep 2022 12:22:27 +0100 Subject: Expose MSC3882 only be under an unstable endpoint. (#13868) --- changelog.d/13868.misc | 1 + synapse/rest/client/login_token_request.py | 4 +++- tests/rest/client/test_login_token_request.py | 16 +++++++++------- 3 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 changelog.d/13868.misc (limited to 'synapse') diff --git a/changelog.d/13868.misc b/changelog.d/13868.misc new file mode 100644 index 0000000000..d7a99c042a --- /dev/null +++ b/changelog.d/13868.misc @@ -0,0 +1 @@ +Fix unstable MSC3882 endpoint being incorrectly available on stable API versions. \ No newline at end of file diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index ca5c54bf17..277b20fb63 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -47,7 +47,9 @@ class LoginTokenRequestServlet(RestServlet): } """ - PATTERNS = client_patterns("/login/token$") + PATTERNS = client_patterns( + "/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True + ) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index d5bb16c98d..c2e1e08811 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -22,6 +22,8 @@ from synapse.util import Clock from tests import unittest from tests.unittest import override_config +endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token" + class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): @@ -45,18 +47,18 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): self.password = "password" def test_disabled(self) -> None: - channel = self.make_request("POST", "/login/token", {}, access_token=None) + channel = self.make_request("POST", endpoint, {}, access_token=None) self.assertEqual(channel.code, 400) self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", "/login/token", {}, access_token=token) + channel = self.make_request("POST", endpoint, {}, access_token=token) self.assertEqual(channel.code, 400) @override_config({"experimental_features": {"msc3882_enabled": True}}) def test_require_auth(self) -> None: - channel = self.make_request("POST", "/login/token", {}, access_token=None) + channel = self.make_request("POST", endpoint, {}, access_token=None) self.assertEqual(channel.code, 401) @override_config({"experimental_features": {"msc3882_enabled": True}}) @@ -64,7 +66,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): user_id = self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", "/login/token", {}, access_token=token) + channel = self.make_request("POST", endpoint, {}, access_token=token) self.assertEqual(channel.code, 401) self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) @@ -79,7 +81,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): }, } - channel = self.make_request("POST", "/login/token", uia, access_token=token) + channel = self.make_request("POST", endpoint, uia, access_token=token) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["expires_in"], 300) @@ -100,7 +102,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): user_id = self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", "/login/token", {}, access_token=token) + channel = self.make_request("POST", endpoint, {}, access_token=token) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["expires_in"], 300) @@ -127,6 +129,6 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", "/login/token", {}, access_token=token) + channel = self.make_request("POST", endpoint, {}, access_token=token) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["expires_in"], 15) -- cgit 1.5.1 From 8625ad80994d6049a778b5d1ef65c8d1b1042c74 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 29 Sep 2022 07:22:41 -0400 Subject: Explicit cast to enforce type hints. (#13939) --- changelog.d/13939.feature | 1 + synapse/storage/databases/main/event_push_actions.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/13939.feature (limited to 'synapse') diff --git a/changelog.d/13939.feature b/changelog.d/13939.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13939.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7e0ffef7d3..3fdf128d9e 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -1068,7 +1068,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas limit, ), ) - rows = txn.fetchall() + rows = cast(List[Tuple[int, str, str, int]], txn.fetchall()) # For each new read receipt we delete push actions from before it and # recalculate the summary. @@ -1113,18 +1113,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # We always update `event_push_summary_last_receipt_stream_id` to # ensure that we don't rescan the same receipts for remote users. - upper_limit = max_receipts_stream_id + receipts_last_processed_stream_id = max_receipts_stream_id if len(rows) >= limit: # If we pulled out a limited number of rows we only update the # position to the last receipt we processed, so we continue # processing the rest next iteration. - upper_limit = rows[-1][0] + receipts_last_processed_stream_id = rows[-1][0] self.db_pool.simple_update_txn( txn, table="event_push_summary_last_receipt_stream_id", keyvalues={}, - updatevalues={"stream_id": upper_limit}, + updatevalues={"stream_id": receipts_last_processed_stream_id}, ) return len(rows) < limit -- cgit 1.5.1 From be76cd8200b18f3c68b895f85ac7ef5b0ddc2466 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 29 Sep 2022 14:23:24 +0100 Subject: Allow admins to require a manual approval process before new accounts can be used (using MSC3866) (#13556) --- changelog.d/13556.feature | 1 + synapse/_scripts/synapse_port_db.py | 2 +- synapse/api/constants.py | 11 ++ synapse/api/errors.py | 16 ++ synapse/config/experimental.py | 19 +++ synapse/handlers/admin.py | 5 + synapse/handlers/auth.py | 11 ++ synapse/handlers/register.py | 8 + synapse/replication/http/register.py | 5 + synapse/rest/admin/users.py | 43 ++++- synapse/rest/client/login.py | 37 +++- synapse/rest/client/register.py | 22 ++- synapse/storage/databases/main/__init__.py | 9 +- synapse/storage/databases/main/registration.py | 150 +++++++++++++++-- .../main/delta/73/03users_approved_column.sql | 20 +++ tests/rest/admin/test_user.py | 186 ++++++++++++++++++++- tests/rest/client/test_auth.py | 33 +++- tests/rest/client/test_login.py | 41 +++++ tests/rest/client/test_register.py | 32 +++- tests/rest/client/utils.py | 12 +- tests/storage/test_registration.py | 102 ++++++++++- 21 files changed, 731 insertions(+), 34 deletions(-) create mode 100644 changelog.d/13556.feature create mode 100644 synapse/storage/schema/main/delta/73/03users_approved_column.sql (limited to 'synapse') diff --git a/changelog.d/13556.feature b/changelog.d/13556.feature new file mode 100644 index 0000000000..f9d63db6c0 --- /dev/null +++ b/changelog.d/13556.feature @@ -0,0 +1 @@ +Allow server admins to require a manual approval process before new accounts can be used (using [MSC3866](https://github.com/matrix-org/matrix-spec-proposals/pull/3866)). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 450ba462ba..5fa599e70e 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -107,7 +107,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], - "users": ["shadow_banned"], + "users": ["shadow_banned", "approved"], "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], "device_lists_changes_in_room": ["converted_to_destinations"], diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c178ddf070..c031903b1a 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -269,3 +269,14 @@ class PublicRoomsFilterFields: GENERIC_SEARCH_TERM: Final = "generic_search_term" ROOM_TYPES: Final = "room_types" + + +class ApprovalNoticeMedium: + """Identifier for the medium this server will use to serve notice of approval for a + specific user's registration. + + As defined in https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/m_not_approved/proposals/3866-user-not-approved-error.md + """ + + NONE = "org.matrix.msc3866.none" + EMAIL = "org.matrix.msc3866.email" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 1c6b53aa24..c606207569 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -106,6 +106,8 @@ class Codes(str, Enum): # Part of MSC3895. UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE" + USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL" + class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes. @@ -566,6 +568,20 @@ class UnredactedContentDeletedError(SynapseError): return cs_error(self.msg, self.errcode, **extra) +class NotApprovedError(SynapseError): + def __init__( + self, + msg: str, + approval_notice_medium: str, + ): + super().__init__( + code=403, + msg=msg, + errcode=Codes.USER_AWAITING_APPROVAL, + additional_fields={"approval_notice_medium": approval_notice_medium}, + ) + + def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict": """Utility method for constructing an error response for client-server interactions. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 933779c23a..31834fb27d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -14,10 +14,25 @@ from typing import Any +import attr + from synapse.config._base import Config from synapse.types import JsonDict +@attr.s(auto_attribs=True, frozen=True, slots=True) +class MSC3866Config: + """Configuration for MSC3866 (mandating approval for new users)""" + + # Whether the base support for the approval process is enabled. This includes the + # ability for administrators to check and update the approval of users, even if no + # approval is currently required. + enabled: bool = False + # Whether to require that new users are approved by an admin before their account + # can be used. Note that this setting is ignored if 'enabled' is false. + require_approval_for_new_accounts: bool = False + + class ExperimentalConfig(Config): """Config section for enabling experimental features""" @@ -97,6 +112,10 @@ class ExperimentalConfig(Config): # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) + # MSC3866: M_USER_AWAITING_APPROVAL error code + raw_msc3866_config = experimental.get("msc3866", {}) + self.msc3866 = MSC3866Config(**raw_msc3866_config) + # MSC3881: Remotely toggle push notifications for another client self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index cf9f19608a..f2989cc4a2 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -32,6 +32,7 @@ class AdminHandler: self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def get_whois(self, user: UserID) -> JsonDict: connections = [] @@ -75,6 +76,10 @@ class AdminHandler: "is_guest", } + if self._msc3866_enabled: + # Only include the approved flag if support for MSC3866 is enabled. + user_info_to_return.add("approved") + # Restrict returned keys to a known set. user_info_dict = { key: value diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index eacd631ee0..f5f0e0e7a7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1009,6 +1009,17 @@ class AuthHandler: return res[0] return None + async def is_user_approved(self, user_id: str) -> bool: + """Checks if a user is approved and therefore can be allowed to log in. + + Args: + user_id: the user to check the approval status of. + + Returns: + A boolean that is True if the user is approved, False otherwise. + """ + return await self.store.is_user_approved(user_id) + async def _find_user_id_and_pwd_hash( self, user_id: str ) -> Optional[Tuple[str, str]]: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cfcadb34db..ca1c7a1866 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -220,6 +220,7 @@ class RegistrationHandler: by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, auth_provider_id: Optional[str] = None, + approved: bool = False, ) -> str: """Registers a new client on the server. @@ -246,6 +247,8 @@ class RegistrationHandler: user_agent_ips: Tuples of user-agents and IP addresses used during the registration process. auth_provider_id: The SSO IdP the user used, if any. + approved: True if the new user should be considered already + approved by an administrator. Returns: The registered user_id. Raises: @@ -307,6 +310,7 @@ class RegistrationHandler: user_type=user_type, address=address, shadow_banned=shadow_banned, + approved=approved, ) profile = await self.store.get_profileinfo(localpart) @@ -695,6 +699,7 @@ class RegistrationHandler: user_type: Optional[str] = None, address: Optional[str] = None, shadow_banned: bool = False, + approved: bool = False, ) -> None: """Register user in the datastore. @@ -713,6 +718,7 @@ class RegistrationHandler: api.constants.UserTypes, or None for a normal user. address: the IP address used to perform the registration. shadow_banned: Whether to shadow-ban the user + approved: Whether to mark the user as approved by an administrator """ if self.hs.config.worker.worker_app: await self._register_client( @@ -726,6 +732,7 @@ class RegistrationHandler: user_type=user_type, address=address, shadow_banned=shadow_banned, + approved=approved, ) else: await self.store.register_user( @@ -738,6 +745,7 @@ class RegistrationHandler: admin=admin, user_type=user_type, shadow_banned=shadow_banned, + approved=approved, ) # Only call the account validity module(s) on the main process, to avoid diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 6c8f8388fd..61abb529c8 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -51,6 +51,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type: Optional[str], address: Optional[str], shadow_banned: bool, + approved: bool, ) -> JsonDict: """ Args: @@ -68,6 +69,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint): or None for a normal user. address: the IP address used to perform the regitration. shadow_banned: Whether to shadow-ban the user + approved: Whether the user should be considered already approved by an + administrator. """ return { "password_hash": password_hash, @@ -79,6 +82,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "user_type": user_type, "address": address, "shadow_banned": shadow_banned, + "approved": approved, } async def _handle_request( # type: ignore[override] @@ -99,6 +103,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type=content["user_type"], address=content["address"], shadow_banned=content["shadow_banned"], + approved=content["approved"], ) return 200, {} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 1274773d7e..15ac2059aa 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet): self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet): guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) + # If support for MSC3866 is not enabled, apply no filtering based on the + # `approved` column. + if self._msc3866_enabled: + approved = parse_boolean(request, "approved", default=True) + else: + approved = True + order_by = parse_string( request, "order_by", @@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet): direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) users, total = await self.store.get_users_paginate( - start, limit, user_id, name, guests, deactivated, order_by, direction + start, + limit, + user_id, + name, + guests, + deactivated, + order_by, + direction, + approved, ) + + # If support for MSC3866 is not enabled, don't show the approval flag. + if not self._msc3866_enabled: + for user in users: + del user["approved"] + ret = {"users": users, "total": total} if (start + limit) < total: ret["next_token"] = str(start + len(users)) @@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet): self.deactivate_account_handler = hs.get_deactivate_account_handler() self.registration_handler = hs.get_registration_handler() self.pusher_pool = hs.get_pusherpool() + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def on_GET( self, request: SynapseRequest, user_id: str @@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet): HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" ) + approved: Optional[bool] = None + if "approved" in body and self._msc3866_enabled: + approved = body["approved"] + if not isinstance(approved, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'approved' parameter is not of type boolean", + ) + # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: new_external_ids = [ @@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet): if "user_type" in body: await self.store.set_user_type(target_user, user_type) + if approved is not None: + await self.store.update_user_approval_status(target_user, approved) + user = await self.admin_handler.get_user(target_user) assert user is not None @@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet): if password is not None: password_hash = await self.auth_handler.hash(password) + new_user_approved = True + if self._msc3866_enabled and approved is not None: + new_user_approved = approved + user_id = await self.registration_handler.register_user( localpart=target_user.localpart, password_hash=password_hash, @@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet): default_display_name=displayname, user_type=user_type, by_admin=True, + approved=new_user_approved, ) if threepids is not None: @@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet): user_type=user_type, default_display_name=displayname, by_admin=True, + approved=True, ) result = await register._create_registration_details(user_id, body) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 0437c87d8d..f554586ac3 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -28,7 +28,14 @@ from typing import ( from typing_extensions import TypedDict -from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError +from synapse.api.constants import ApprovalNoticeMedium +from synapse.api.errors import ( + Codes, + InvalidClientTokenError, + LoginError, + NotApprovedError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.api.urls import CLIENT_API_PREFIX from synapse.appservice import ApplicationService @@ -55,11 +62,11 @@ logger = logging.getLogger(__name__) class LoginResponse(TypedDict, total=False): user_id: str - access_token: str + access_token: Optional[str] home_server: str expires_in_ms: Optional[int] refresh_token: Optional[str] - device_id: str + device_id: Optional[str] well_known: Optional[Dict[str, Any]] @@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet): hs.config.registration.refreshable_access_token_lifetime is not None ) + # Whether we need to check if the user has been approved or not. + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet): except KeyError: raise SynapseError(400, "Missing JSON keys.") + if self._require_approval: + approved = await self.auth_handler.is_user_approved(result["user_id"]) + if not approved: + raise NotApprovedError( + msg="This account is pending approval by a server administrator.", + approval_notice_medium=ApprovalNoticeMedium.NONE, + ) + well_known_data = self._well_known_builder.get_well_known() if well_known_data: result["well_known"] = well_known_data @@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) + if self._require_approval: + approved = await self.auth_handler.is_user_approved(user_id) + if not approved: + # If the user isn't approved (and needs to be) we won't allow them to + # actually log in, so we don't want to create a device/access token. + return LoginResponse( + user_id=user_id, + home_server=self.hs.hostname, + ) + initial_display_name = login_submission.get("initial_device_display_name") ( device_id, diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 20bab20c8f..de810ae3ec 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -21,10 +21,15 @@ from twisted.web.server import Request import synapse import synapse.api.auth import synapse.types -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.constants import ( + APP_SERVICE_REGISTRATION_TYPE, + ApprovalNoticeMedium, + LoginType, +) from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, + NotApprovedError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -414,6 +419,11 @@ class RegisterRestServlet(RestServlet): hs.config.registration.inhibit_user_in_use_error ) + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler ) @@ -734,6 +744,12 @@ class RegisterRestServlet(RestServlet): access_token=return_dict.get("access_token"), ) + if self._require_approval: + raise NotApprovedError( + msg="This account needs to be approved by an administrator before it can be used.", + approval_notice_medium=ApprovalNoticeMedium.NONE, + ) + return 200, return_dict async def _do_appservice_registration( @@ -778,7 +794,9 @@ class RegisterRestServlet(RestServlet): "user_id": user_id, "home_server": self.hs.hostname, } - if not params.get("inhibit_login", False): + # We don't want to log the user in if we're going to deny them access because + # they need to be approved first. + if not params.get("inhibit_login", False) and not self._require_approval: device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") ( diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0843f10340..a62b4abd4e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -203,6 +203,7 @@ class DataStore( deactivated: bool = False, order_by: str = UserSortOrder.USER_ID.value, direction: str = "f", + approved: bool = True, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -217,6 +218,7 @@ class DataStore( deactivated: whether to include deactivated users order_by: the sort order of the returned list direction: sort ascending or descending + approved: whether to include approved users Returns: A tuple of a list of mappings from user to information and a count of total users. """ @@ -249,6 +251,11 @@ class DataStore( if not deactivated: filters.append("deactivated = 0") + if not approved: + # We ignore NULL values for the approved flag because these should only + # be already existing users that we consider as already approved. + filters.append("approved IS FALSE") + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" sql_base = f""" @@ -262,7 +269,7 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index ac821878b0..2996d6bb4d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: """Deprecated: use get_userinfo_by_id instead""" - return await self.db_pool.simple_select_one( - table="users", - keyvalues={"name": user_id}, - retcols=[ - "name", - "password_hash", - "is_guest", - "admin", - "consent_version", - "consent_ts", - "consent_server_notice_sent", - "appservice_id", - "creation_ts", - "user_type", - "deactivated", - "shadow_banned", - ], - allow_none=True, + + def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: + # We could technically use simple_select_one here, but it would not perform + # the COALESCEs (unless hacked into the column names), which could yield + # confusing results. + txn.execute( + """ + SELECT + name, password_hash, is_guest, admin, consent_version, consent_ts, + consent_server_notice_sent, appservice_id, creation_ts, user_type, + deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, + COALESCE(approved, TRUE) AS approved + FROM users + WHERE name = ? + """, + (user_id,), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + if len(rows) == 0: + return None + + return rows[0] + + row = await self.db_pool.runInteraction( desc="get_user_by_id", + func=get_user_by_id_txn, ) + if row is not None: + # If we're using SQLite our boolean values will be integers. Because we + # present some of this data as is to e.g. server admins via REST APIs, we + # want to make sure we're returning the right type of data. + # Note: when adding a column name to this list, be wary of NULLable columns, + # since NULL values will be turned into False. + boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"] + for column in boolean_columns: + if not isinstance(row[column], bool): + row[column] = bool(row[column]) + + return row + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: """Get a UserInfo object for a user by user ID. @@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return res if res else False + @cached() + async def is_user_approved(self, user_id: str) -> bool: + """Checks if a user is approved and therefore can be allowed to log in. + + If the user's 'approved' column is NULL, we consider it as true given it means + the user was registered when support for an approval flow was either disabled + or nonexistent. + + Args: + user_id: the user to check the approval status of. + + Returns: + A boolean that is True if the user is approved, False otherwise. + """ + + def is_user_approved_txn(txn: LoggingTransaction) -> bool: + txn.execute( + """ + SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ? + """, + (user_id,), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + # We cast to bool because the value returned by the database engine might + # be an integer if we're using SQLite. + return bool(rows[0]["approved"]) + + return await self.db_pool.runInteraction( + desc="is_user_pending_approval", + func=is_user_approved_txn, + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + def update_user_approval_status_txn( + self, txn: LoggingTransaction, user_id: str, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean is turned into an int because the column is a smallint. + + Args: + txn: the current database transaction. + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"approved": approved}, + ) + + # Invalidate the caches of methods that read the value of the 'approved' flag. + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,)) + class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__( @@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + # If support for MSC3866 is enabled and configured to require approval for new + # account, we will create new users with an 'approved' flag set to false. + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + async def add_access_token_to_user( self, user_id: str, @@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin: bool = False, user_type: Optional[str] = None, shadow_banned: bool = False, + approved: bool = False, ) -> None: """Attempts to register an account. @@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): or None for a normal user. shadow_banned: Whether the user is shadow-banned, i.e. they may be told their requests succeeded but we ignore them. + approved: Whether to consider the user has already been approved by an + administrator. Raises: StoreError if the user_id could not be registered. @@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin, user_type, shadow_banned, + approved, ) def _register_user( @@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin: bool, user_type: Optional[str], shadow_banned: bool, + approved: bool, ) -> None: user_id_obj = UserID.from_string(user_id) now = int(self._clock.time()) + user_approved = approved or not self._require_approval + try: if was_guest: # Ensure that the guest user actually exists @@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "admin": 1 if admin else 0, "user_type": user_type, "shadow_banned": shadow_banned, + "approved": user_approved, }, ) else: @@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "admin": 1 if admin else 0, "user_type": user_type, "shadow_banned": shadow_banned, + "approved": user_approved, }, ) @@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) + async def update_user_approval_status( + self, user_id: UserID, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean will be turned into an int (in update_user_approval_status_txn) + because the column is a smallint. + + Args: + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + await self.db_pool.runInteraction( + "update_user_approval_status", + self.update_user_approval_status_txn, + user_id.to_string(), + approved, + ) + def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/schema/main/delta/73/03users_approved_column.sql b/synapse/storage/schema/main/delta/73/03users_approved_column.sql new file mode 100644 index 0000000000..5328d592ea --- /dev/null +++ b/synapse/storage/schema/main/delta/73/03users_approved_column.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to the users table to track whether the user needs to be approved by an +-- administrator. +-- A NULL column means the user was created before this feature was supported by Synapse, +-- and should be considered as TRUE. +ALTER TABLE users ADD COLUMN approved BOOLEAN; diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 1847e6ad6b..4c1ce33463 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -25,10 +25,10 @@ from parameterized import parameterized, parameterized_class from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import UserTypes +from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions -from synapse.rest.client import devices, login, logout, profile, room, sync +from synapse.rest.client import devices, login, logout, profile, register, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer from synapse.types import JsonDict, UserID @@ -578,6 +578,16 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. @@ -623,6 +633,16 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + # invalid approved + channel = self.make_request( + "GET", + self.url + "?approved=not_bool", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + # unkown order_by channel = self.make_request( "GET", @@ -841,6 +861,69 @@ class UsersListTestCase(unittest.HomeserverTestCase): self._order_test([self.admin_user, user1, user2], "creation_ts", "f") self._order_test([user2, user1, self.admin_user], "creation_ts", "b") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_filter_out_approved(self) -> None: + """Tests that the endpoint can filter out approved users.""" + # Create our users. + self._create_users(2) + + # Get the list of users. + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + # Exclude the admin, because we don't want to accidentally un-approve the admin. + non_admin_user_ids = [ + user["name"] + for user in channel.json_body["users"] + if user["name"] != self.admin_user + ] + + self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids) + + # Select a user and un-approve them. We do this rather than the other way around + # because, since these users are created by an admin, we consider them already + # approved. + not_approved_user = non_admin_user_ids[0] + + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{not_approved_user}", + {"approved": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + # Now get the list of users again, this time filtering out approved users. + channel = self.make_request( + "GET", + self.url + "?approved=false", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + non_admin_user_ids = [ + user["name"] + for user in channel.json_body["users"] + if user["name"] != self.admin_user + ] + + # We should only have our unapproved user now. + self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) + self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def _order_test( self, expected_user_list: List[str], @@ -1272,6 +1355,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets, login.register_servlets, sync.register_servlets, + register.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -2536,6 +2620,104 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_approve_account(self) -> None: + """Tests that approving an account correctly sets the approved flag for the user.""" + url = self.url_prefix % "@bob:test" + + # Create the user using the client-server API since otherwise the user will be + # marked as approved automatically. + channel = self.make_request( + "POST", + "register", + { + "username": "bob", + "password": "test", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + + # Get user + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(False, channel.json_body["approved"]) + + # Approve user + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content={"approved": True}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(True, channel.json_body["approved"]) + + # Check that the user is now approved + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(True, channel.json_body["approved"]) + + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_register_approved(self) -> None: + url = self.url_prefix % "@bob:test" + + # Create user + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content={"password": "abc123", "approved": True}, + ) + + self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(1, channel.json_body["approved"]) + + # Get user + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(1, channel.json_body["approved"]) + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 05355c7fb6..090cef5216 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin -from synapse.api.constants import LoginType +from synapse.api.constants import ApprovalNoticeMedium, LoginType +from synapse.api.errors import Codes from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase): body={"auth": {"session": session_id}}, ) + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config( + { + "oidc_config": TEST_OIDC_CONFIG, + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + }, + } + ) + def test_sso_not_approved(self) -> None: + """Tests that if we register a user via SSO while requiring approval for new + accounts, we still raise the correct error before logging the user in. + """ + login_resp = self.helper.login_via_oidc("username", expected_status=403) + + self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) + self.assertEqual( + ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"] + ) + + # Check that we didn't register a device for the user during the login attempt. + devices = self.get_success( + self.hs.get_datastores().main.get_devices_by_user("@username:test") + ) + + self.assertEqual(len(devices), 0) + class RefreshAuthTests(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e2a4d98275..e801ba8c8b 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin +from synapse.api.constants import ApprovalNoticeMedium, LoginType +from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet @@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), + register.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + + params = { + "type": LoginType.PASSWORD, + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request("POST", LOGIN_URL, params) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") class MultiSSOTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index b781875d52..11cf3939d8 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -22,7 +22,11 @@ import pkg_resources from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.constants import ( + APP_SERVICE_REGISTRATION_TYPE, + ApprovalNoticeMedium, + LoginType, +) from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync @@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + class AccountValidityTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index dd26145bf8..c249a42bb6 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -543,8 +543,12 @@ class RestHelper: return channel.json_body - def login_via_oidc(self, remote_user_id: str) -> JsonDict: - """Log in (as a new user) via OIDC + def login_via_oidc( + self, + remote_user_id: str, + expected_status: int = 200, + ) -> JsonDict: + """Log in via OIDC Returns the result of the final token login. @@ -578,7 +582,9 @@ class RestHelper: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == HTTPStatus.OK + assert ( + channel.code == expected_status + ), f"unexpected status in response: {channel.code}" return channel.json_body def auth_via_oidc( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 853a93afab..05ea802008 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError from synapse.server import HomeServer +from synapse.types import JsonDict, UserID from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RegistrationStoreTestCase(HomeserverTestCase): @@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): "user_type": None, "deactivated": 0, "shadow_banned": 0, + "approved": 1, }, (self.get_success(self.store.get_user_by_id(self.user_id))), ) @@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase): ThreepidValidationError, ) self.assertEqual(e.value.msg, "Validation token not found or has expired", e) + + +class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): + def default_config(self) -> JsonDict: + config = super().default_config() + + # If there's already some config for this feature in the default config, it + # means we're overriding it with @override_config. In this case we don't want + # to do anything more with it. + msc3866_config = config.get("experimental_features", {}).get("msc3866") + if msc3866_config is not None: + return config + + # Require approval for all new accounts. + config["experimental_features"] = { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.user_id = "@my-user:test" + self.pwhash = "{xx1}123456789" + + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": False, + } + } + } + ) + def test_approval_not_required(self) -> None: + """Tests that if we don't require approval for new accounts, newly created + accounts are automatically marked as approved. + """ + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user is not None + self.assertTrue(user["approved"]) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) + + def test_approval_required(self) -> None: + """Tests that if we require approval for new accounts, newly created accounts + are not automatically marked as approved. + """ + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user is not None + self.assertFalse(user["approved"]) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertFalse(approved) + + def test_override(self) -> None: + """Tests that if we require approval for new accounts, but we explicitly say the + new user should be considered approved, they're marked as approved. + """ + self.get_success( + self.store.register_user( + self.user_id, + self.pwhash, + approved=True, + ) + ) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + self.assertIsNotNone(user) + assert user is not None + self.assertEqual(user["approved"], 1) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) + + def test_approve_user(self) -> None: + """Tests that approving the user updates their approval status.""" + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertFalse(approved) + + self.get_success( + self.store.update_user_approval_status( + UserID.from_string(self.user_id), True + ) + ) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) -- cgit 1.5.1 From a466164647b969efd2e85168144cd75693443c05 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 29 Sep 2022 14:55:12 +0100 Subject: Optimise get_rooms_for_user (drop with_stream_ordering) (#13787) --- changelog.d/13787.misc | 1 + synapse/handlers/device.py | 6 +- synapse/handlers/sync.py | 14 +--- synapse/storage/_base.py | 1 + synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/roommember.py | 117 +++++++++++++-------------- tests/handlers/test_sync.py | 1 + 7 files changed, 66 insertions(+), 75 deletions(-) create mode 100644 changelog.d/13787.misc (limited to 'synapse') diff --git a/changelog.d/13787.misc b/changelog.d/13787.misc new file mode 100644 index 0000000000..a9b93717f0 --- /dev/null +++ b/changelog.d/13787.misc @@ -0,0 +1 @@ +Optimise get rooms for user calls. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 03082fce42..f9cc5bddbc 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -273,11 +273,9 @@ class DeviceWorkerHandler: possibly_left = possibly_changed | possibly_left # Double check if we still share rooms with the given user. - users_rooms = await self.store.get_rooms_for_users_with_stream_ordering( - possibly_left - ) + users_rooms = await self.store.get_rooms_for_users(possibly_left) for changed_user_id, entries in users_rooms.items(): - if any(e.room_id in room_ids for e in entries): + if any(rid in room_ids for rid in entries): possibly_left.discard(changed_user_id) else: possibly_joined.discard(changed_user_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e75fc6b947..4abb9b6127 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1490,16 +1490,14 @@ class SyncHandler: since_token.device_list_key ) if changed_users is not None: - result = await self.store.get_rooms_for_users_with_stream_ordering( - changed_users - ) + result = await self.store.get_rooms_for_users(changed_users) for changed_user_id, entries in result.items(): # Check if the changed user shares any rooms with the user, # or if the changed user is the syncing user (as we always # want to include device list updates of their own devices). if user_id == changed_user_id or any( - e.room_id in joined_rooms for e in entries + rid in joined_rooms for rid in entries ): users_that_have_changed.add(changed_user_id) else: @@ -1533,13 +1531,9 @@ class SyncHandler: newly_left_users.update(left_users) # Remove any users that we still share a room with. - left_users_rooms = ( - await self.store.get_rooms_for_users_with_stream_ordering( - newly_left_users - ) - ) + left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) for user_id, entries in left_users_rooms.items(): - if any(e.room_id in joined_rooms for e in entries): + if any(rid in joined_rooms for rid in entries): newly_left_users.discard(user_id) return DeviceListUpdates( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 313e8aca7d..bf42aeb8d1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -94,6 +94,7 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache( "get_rooms_for_user_with_stream_ordering", (user_id,) ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,)) # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index db6ce83a2b..3b8ed1f7ee 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -205,6 +205,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_rooms_for_user_with_stream_ordering.invalidate( (data.state_key,) ) + self.get_rooms_for_user.invalidate((data.state_key,)) else: raise Exception("Unknown events stream row type %s" % (row.type,)) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 8ada3cdac3..982e1f08e3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Callable, Collection, Dict, FrozenSet, @@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList -from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -600,58 +598,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): for room_id, instance, stream_id in txn ) - @cachedList( - cached_method_name="get_rooms_for_user_with_stream_ordering", - list_name="user_ids", - ) - async def get_rooms_for_users_with_stream_ordering( - self, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - """A batched version of `get_rooms_for_user_with_stream_ordering`. - - Returns: - Map from user_id to set of rooms that is currently in. - """ - return await self.db_pool.runInteraction( - "get_rooms_for_users_with_stream_ordering", - self._get_rooms_for_users_with_stream_ordering_txn, - user_ids, - ) - - def _get_rooms_for_users_with_stream_ordering_txn( - self, txn: LoggingTransaction, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - - clause, args = make_in_list_sql_clause( - self.database_engine, - "c.state_key", - user_ids, - ) - - sql = f""" - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND c.membership = ? - AND {clause} - """ - - txn.execute(sql, [Membership.JOIN] + args) - - result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { - user_id: set() for user_id in user_ids - } - for user_id, room_id, instance, stream_id in txn: - result[user_id].add( - GetRoomsForUserWithStreamOrdering( - room_id, PersistedEventPosition(instance, stream_id) - ) - ) - - return {user_id: frozenset(v) for user_id, v in result.items()} - async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] ) -> Set[str]: @@ -693,19 +639,68 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {row[0] for row in txn} - @cancellable - async def get_rooms_for_user( - self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None - ) -> FrozenSet[str]: + @cached(max_entries=500000, iterable=True) + async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. """ - rooms = await self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate + rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate( + (user_id,), + None, + update_metrics=False, + ) + if rooms: + return frozenset(r.room_id for r in rooms) + + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_for_user", ) - return frozenset(r.room_id for r in rooms) + + return frozenset(room_ids) + + @cachedList( + cached_method_name="get_rooms_for_user", + list_name="user_ids", + ) + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched version of `get_rooms_for_user`. + + Returns: + Map from user_id to set of rooms that is currently in. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="state_key", + iterable=user_ids, + retcols=( + "state_key", + "room_id", + ), + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + desc="get_rooms_for_users", + ) + + user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} + + for row in rows: + user_rooms[row["state_key"]].add(row["room_id"]) + + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} @cached(max_entries=10000) async def does_pair_of_users_share_a_room( diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index e3f38fbcc5..ab5c101eb7 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + self.store.get_rooms_for_user.invalidate_all() self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() -- cgit 1.5.1 From ebd9e2dac6495a1857617d1a76c9259a988f8bb4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 29 Sep 2022 16:12:09 +0100 Subject: Implement push rule evaluation in Rust. (#13838) --- changelog.d/13838.misc | 1 + rust/Cargo.toml | 4 +- rust/benches/evaluator.rs | 149 ++++++++++++ rust/benches/glob.rs | 40 ++++ rust/build.rs | 2 +- rust/src/push/base_rules.rs | 1 + rust/src/push/evaluator.rs | 374 +++++++++++++++++++++++++++++++ rust/src/push/mod.rs | 28 ++- rust/src/push/utils.rs | 215 ++++++++++++++++++ stubs/synapse/synapse_rust/push.pyi | 19 +- synapse/push/bulk_push_rule_evaluator.py | 44 ++-- synapse/push/httppusher.py | 39 +++- synapse/push/push_rule_evaluator.py | 361 ----------------------------- tests/push/test_push_rule_evaluator.py | 20 +- 14 files changed, 894 insertions(+), 403 deletions(-) create mode 100644 changelog.d/13838.misc create mode 100644 rust/benches/evaluator.rs create mode 100644 rust/benches/glob.rs create mode 100644 rust/src/push/evaluator.rs create mode 100644 rust/src/push/utils.rs delete mode 100644 synapse/push/push_rule_evaluator.py (limited to 'synapse') diff --git a/changelog.d/13838.misc b/changelog.d/13838.misc new file mode 100644 index 0000000000..28bddb7059 --- /dev/null +++ b/changelog.d/13838.misc @@ -0,0 +1 @@ +Port push rules to using Rust. diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 44263bf77e..cffaa5b51b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -11,7 +11,9 @@ rust-version = "1.58.1" [lib] name = "synapse" -crate-type = ["cdylib"] +# We generate a `cdylib` for Python and a standard `lib` for running +# tests/benchmarks. +crate-type = ["lib", "cdylib"] [package.metadata.maturin] # This is where we tell maturin where to place the built library. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs new file mode 100644 index 0000000000..ed411461d1 --- /dev/null +++ b/rust/benches/evaluator.rs @@ -0,0 +1,149 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(test)] +use synapse::push::{ + evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, +}; +use test::Bencher; + +extern crate test; + +#[bench] +fn bench_match_exact(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "room_id".into(), + pattern: Some("!room:server".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_match_word(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "content.body".into(), + pattern: Some("test".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_match_word_miss(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "content.body".into(), + pattern: Some("foobar".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(!matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_eval_message(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let rules = + FilteredPushRules::py_new(PushRules::new(Vec::new()), Default::default(), false, false); + + b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); +} diff --git a/rust/benches/glob.rs b/rust/benches/glob.rs new file mode 100644 index 0000000000..b6697d9285 --- /dev/null +++ b/rust/benches/glob.rs @@ -0,0 +1,40 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(test)] + +use synapse::push::utils::{glob_to_regex, GlobMatchType}; +use test::Bencher; + +extern crate test; + +#[bench] +fn bench_whole(b: &mut Bencher) { + b.iter(|| glob_to_regex("test", GlobMatchType::Whole)); +} + +#[bench] +fn bench_word(b: &mut Bencher) { + b.iter(|| glob_to_regex("test", GlobMatchType::Word)); +} + +#[bench] +fn bench_whole_wildcard_run(b: &mut Bencher) { + b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole)); +} + +#[bench] +fn bench_word_wildcard_run(b: &mut Bencher) { + b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole)); +} diff --git a/rust/build.rs b/rust/build.rs index 2117975e56..ef370e6b41 100644 --- a/rust/build.rs +++ b/rust/build.rs @@ -22,7 +22,7 @@ fn main() -> Result<(), std::io::Error> { for entry in entries { if entry.is_dir() { - dirs.push(entry) + dirs.push(entry); } else { paths.push(entry.to_str().expect("valid rust paths").to_string()); } diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 7c62bc4849..bb59676bde 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -262,6 +262,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ priority_class: 1, conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { rel_type: Cow::Borrowed("m.thread"), + event_type_pattern: None, sender: None, sender_type: Some(Cow::Borrowed("user_id")), })]), diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs new file mode 100644 index 0000000000..efe88ec76e --- /dev/null +++ b/rust/src/push/evaluator.rs @@ -0,0 +1,374 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + borrow::Cow, + collections::{BTreeMap, BTreeSet}, +}; + +use anyhow::{Context, Error}; +use lazy_static::lazy_static; +use log::warn; +use pyo3::prelude::*; +use regex::Regex; + +use super::{ + utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, + Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, +}; + +lazy_static! { + /// Used to parse the `is` clause in the room member count condition. + static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex"); +} + +/// Allows running a set of push rules against a particular event. +#[pyclass] +pub struct PushRuleEvaluator { + /// A mapping of "flattened" keys to string values in the event, e.g. + /// includes things like "type" and "content.msgtype". + flattened_keys: BTreeMap, + + /// The "content.body", if any. + body: String, + + /// The number of users in the room. + room_member_count: u64, + + /// The `notifications` section of the current power levels in the room. + notification_power_levels: BTreeMap, + + /// The relations related to the event as a mapping from relation type to + /// set of sender/event type 2-tuples. + relations: BTreeMap>, + + /// Is running "relation" conditions enabled? + relation_match_enabled: bool, + + /// The power level of the sender of the event, or None if event is an + /// outlier. + sender_power_level: Option, +} + +#[pymethods] +impl PushRuleEvaluator { + /// Create a new `PushRuleEvaluator`. See struct docstring for details. + #[new] + pub fn py_new( + flattened_keys: BTreeMap, + room_member_count: u64, + sender_power_level: Option, + notification_power_levels: BTreeMap, + relations: BTreeMap>, + relation_match_enabled: bool, + ) -> Result { + let body = flattened_keys + .get("content.body") + .cloned() + .unwrap_or_default(); + + Ok(PushRuleEvaluator { + flattened_keys, + body, + room_member_count, + notification_power_levels, + relations, + relation_match_enabled, + sender_power_level, + }) + } + + /// Run the evaluator with the given push rules, for the given user ID and + /// display name of the user. + /// + /// Passing in None will skip evaluating rules matching user ID and display + /// name. + /// + /// Returns the set of actions, if any, that match (filtering out any + /// `dont_notify` actions). + pub fn run( + &self, + push_rules: &FilteredPushRules, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Vec { + 'outer: for (push_rule, enabled) in push_rules.iter() { + if !enabled { + continue; + } + + for condition in push_rule.conditions.iter() { + match self.match_condition(condition, user_id, display_name) { + Ok(true) => {} + Ok(false) => continue 'outer, + Err(err) => { + warn!("Condition match failed {err}"); + continue 'outer; + } + } + } + + let actions = push_rule + .actions + .iter() + // Filter out "dont_notify" actions, as we don't store them. + .filter(|a| **a != Action::DontNotify) + .cloned() + .collect(); + + return actions; + } + + Vec::new() + } + + /// Check if the given condition matches. + fn matches( + &self, + condition: Condition, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> bool { + match self.match_condition(&condition, user_id, display_name) { + Ok(true) => true, + Ok(false) => false, + Err(err) => { + warn!("Condition match failed {err}"); + false + } + } + } +} + +impl PushRuleEvaluator { + /// Match a given `Condition` for a push rule. + pub fn match_condition( + &self, + condition: &Condition, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Result { + let known_condition = match condition { + Condition::Known(known) => known, + Condition::Unknown(_) => { + return Ok(false); + } + }; + + let result = match known_condition { + KnownCondition::EventMatch(event_match) => { + self.match_event_match(event_match, user_id)? + } + KnownCondition::ContainsDisplayName => { + if let Some(dn) = display_name { + if !dn.is_empty() { + get_glob_matcher(dn, GlobMatchType::Word)?.is_match(&self.body)? + } else { + // We specifically ignore empty display names, as otherwise + // they would always match. + false + } + } else { + false + } + } + KnownCondition::RoomMemberCount { is } => { + if let Some(is) = is { + self.match_member_count(is)? + } else { + false + } + } + KnownCondition::SenderNotificationPermission { key } => { + if let Some(sender_power_level) = &self.sender_power_level { + let required_level = self + .notification_power_levels + .get(key.as_ref()) + .copied() + .unwrap_or(50); + + *sender_power_level >= required_level + } else { + false + } + } + KnownCondition::RelationMatch { + rel_type, + event_type_pattern, + sender, + sender_type, + } => { + self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)? + } + }; + + Ok(result) + } + + /// Evaluates a relation condition. + fn match_relations( + &self, + rel_type: &str, + sender: &Option>, + sender_type: &Option>, + user_id: Option<&str>, + event_type_pattern: &Option>, + ) -> Result { + // First check if relation matching is enabled... + if !self.relation_match_enabled { + return Ok(false); + } + + // ... and if there are any relations to match against. + let relations = if let Some(relations) = self.relations.get(rel_type) { + relations + } else { + return Ok(false); + }; + + // Extract the sender pattern from the condition + let sender_pattern = if let Some(sender) = sender { + Some(sender.as_ref()) + } else if let Some(sender_type) = sender_type { + if sender_type == "user_id" { + if let Some(user_id) = user_id { + Some(user_id) + } else { + return Ok(false); + } + } else { + warn!("Unrecognized sender_type: {sender_type}"); + return Ok(false); + } + } else { + None + }; + + let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern { + Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) + } else { + None + }; + + let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern { + Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) + } else { + None + }; + + for (relation_sender, event_type) in relations { + if let Some(pattern) = &mut sender_compiled_pattern { + if !pattern.is_match(relation_sender)? { + continue; + } + } + + if let Some(pattern) = &mut type_compiled_pattern { + if !pattern.is_match(event_type)? { + continue; + } + } + + return Ok(true); + } + + Ok(false) + } + + /// Evaluates a `event_match` condition. + fn match_event_match( + &self, + event_match: &EventMatchCondition, + user_id: Option<&str>, + ) -> Result { + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) { + haystack + } else { + return Ok(false); + }; + + // For the content.body we match against "words", but for everything + // else we match against the entire value. + let match_type = if event_match.key == "content.body" { + GlobMatchType::Word + } else { + GlobMatchType::Whole + }; + + let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; + compiled_pattern.is_match(haystack) + } + + /// Match the member count against an 'is' condition + /// The `is` condition can be things like '>2', '==3' or even just '4'. + fn match_member_count(&self, is: &str) -> Result { + let captures = INEQUALITY_EXPR.captures(is).context("bad 'is' clause")?; + let ineq = captures.get(1).map_or("==", |m| m.as_str()); + let rhs: u64 = captures + .get(2) + .context("missing number")? + .as_str() + .parse()?; + + let matches = match ineq { + "" | "==" => self.room_member_count == rhs, + "<" => self.room_member_count < rhs, + ">" => self.room_member_count > rhs, + ">=" => self.room_member_count >= rhs, + "<=" => self.room_member_count <= rhs, + _ => false, + }; + + Ok(matches) + } +} + +#[test] +fn push_rule_evaluator() { + let mut flattened_keys = BTreeMap::new(); + flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); + + let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); + assert_eq!(result.len(), 3); +} diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index de6764e7c5..30fffc31ad 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -42,7 +42,6 @@ //! //! The set of "base rules" are the list of rules that every user has by default. A //! user can modify their copy of the push rules in one of three ways: -//! //! 1. Adding a new push rule of a certain kind //! 2. Changing the actions of a base rule //! 3. Enabling/disabling a base rule. @@ -58,12 +57,16 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use anyhow::{Context, Error}; use log::warn; use pyo3::prelude::*; -use pythonize::pythonize; +use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; use serde_json::Value; +use self::evaluator::PushRuleEvaluator; + mod base_rules; +pub mod evaluator; +pub mod utils; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { @@ -71,6 +74,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; + child_module.add_class::()?; child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?; m.add_submodule(child_module)?; @@ -274,6 +278,8 @@ pub enum KnownCondition { #[serde(rename = "org.matrix.msc3772.relation_match")] RelationMatch { rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none", rename = "type")] + event_type_pattern: Option>, #[serde(skip_serializing_if = "Option::is_none")] sender: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -287,20 +293,26 @@ impl IntoPy for Condition { } } +impl<'source> FromPyObject<'source> for Condition { + fn extract(ob: &'source PyAny) -> PyResult { + Ok(depythonize(ob)?) + } +} + /// The body of a [`Condition::EventMatch`] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct EventMatchCondition { - key: Cow<'static, str>, + pub key: Cow<'static, str>, #[serde(skip_serializing_if = "Option::is_none")] - pattern: Option>, + pub pattern: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pattern_type: Option>, + pub pattern_type: Option>, } /// The collection of push rules for a user. #[derive(Debug, Clone, Default)] #[pyclass(frozen)] -struct PushRules { +pub struct PushRules { /// Custom push rules that override a base rule. overridden_base_rules: HashMap, PushRule>, @@ -319,7 +331,7 @@ struct PushRules { #[pymethods] impl PushRules { #[new] - fn new(rules: Vec) -> PushRules { + pub fn new(rules: Vec) -> PushRules { let mut push_rules: PushRules = Default::default(); for rule in rules { @@ -396,7 +408,7 @@ pub struct FilteredPushRules { #[pymethods] impl FilteredPushRules { #[new] - fn py_new( + pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, msc3786_enabled: bool, diff --git a/rust/src/push/utils.rs b/rust/src/push/utils.rs new file mode 100644 index 0000000000..8759340473 --- /dev/null +++ b/rust/src/push/utils.rs @@ -0,0 +1,215 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::bail; +use anyhow::Context; +use anyhow::Error; +use lazy_static::lazy_static; +use regex; +use regex::Regex; +use regex::RegexBuilder; + +lazy_static! { + /// Matches runs of non-wildcard characters followed by wildcard characters. + static ref WILDCARD_RUN: Regex = Regex::new(r"([^\?\*]*)([\?\*]*)").expect("valid regex"); +} + +/// Extract the localpart from a Matrix style ID +pub(crate) fn get_localpart_from_id(id: &str) -> Result<&str, Error> { + let (localpart, _) = id + .split_once(':') + .with_context(|| format!("ID does not contain colon: {id}"))?; + + // We need to strip off the first character, which is the ID type. + if localpart.is_empty() { + bail!("Invalid ID {id}"); + } + + Ok(&localpart[1..]) +} + +/// Used by `glob_to_regex` to specify what to match the regex against. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GlobMatchType { + /// The generated regex will match against the entire input. + Whole, + /// The generated regex will match against words. + Word, +} + +/// Convert a "glob" style expression to a regex, anchoring either to the entire +/// input or to individual words. +pub fn glob_to_regex(glob: &str, match_type: GlobMatchType) -> Result { + let mut chunks = Vec::new(); + + // Patterns with wildcards must be simplified to avoid performance cliffs + // - The glob `?**?**?` is equivalent to the glob `???*` + // - The glob `???*` is equivalent to the regex `.{3,}` + for captures in WILDCARD_RUN.captures_iter(glob) { + if let Some(chunk) = captures.get(1) { + chunks.push(regex::escape(chunk.as_str())); + } + + if let Some(wildcards) = captures.get(2) { + if wildcards.as_str() == "" { + continue; + } + + let question_marks = wildcards.as_str().chars().filter(|c| *c == '?').count(); + + if wildcards.as_str().contains('*') { + chunks.push(format!(".{{{question_marks},}}")); + } else { + chunks.push(format!(".{{{question_marks}}}")); + } + } + } + + let joined = chunks.join(""); + + let regex_str = match match_type { + GlobMatchType::Whole => format!(r"\A{joined}\z"), + + // `^|\W` and `\W|$` handle the case where `pattern` starts or ends with a non-word + // character. + GlobMatchType::Word => format!(r"(?:^|\b|\W){joined}(?:\b|\W|$)"), + }; + + Ok(RegexBuilder::new(®ex_str) + .case_insensitive(true) + .build()?) +} + +/// Compiles the glob into a `Matcher`. +pub fn get_glob_matcher(glob: &str, match_type: GlobMatchType) -> Result { + // There are a number of shortcuts we can make if the glob doesn't contain a + // wild card. + let matcher = if glob.contains(['*', '?']) { + let regex = glob_to_regex(glob, match_type)?; + Matcher::Regex(regex) + } else if match_type == GlobMatchType::Whole { + // If there aren't any wildcards and we're matching the whole thing, + // then we simply can do a case-insensitive string match. + Matcher::Whole(glob.to_lowercase()) + } else { + // Otherwise, if we're matching against words then can first check + // if the haystack contains the glob at all. + Matcher::Word { + word: glob.to_lowercase(), + regex: None, + } + }; + + Ok(matcher) +} + +/// Matches against a glob +pub enum Matcher { + /// Plain regex matching. + Regex(Regex), + + /// Case-insensitive equality. + Whole(String), + + /// Word matching. `regex` is a cache of calling [`glob_to_regex`] on word. + Word { word: String, regex: Option }, +} + +impl Matcher { + /// Checks if the glob matches the given haystack. + pub fn is_match(&mut self, haystack: &str) -> Result { + // We want to to do case-insensitive matching, so we convert to + // lowercase first. + let haystack = haystack.to_lowercase(); + + match self { + Matcher::Regex(regex) => Ok(regex.is_match(&haystack)), + Matcher::Whole(whole) => Ok(whole == &haystack), + Matcher::Word { word, regex } => { + // If we're looking for a literal word, then we first check if + // the haystack contains the word as a substring. + if !haystack.contains(&*word) { + return Ok(false); + } + + // If it does contain the word as a substring, then we need to + // check if it is an actual word by testing it against the regex. + let regex = if let Some(regex) = regex { + regex + } else { + let compiled_regex = glob_to_regex(word, GlobMatchType::Word)?; + regex.insert(compiled_regex) + }; + + Ok(regex.is_match(&haystack)) + } + } + } +} + +#[test] +fn test_get_domain_from_id() { + get_localpart_from_id("").unwrap_err(); + get_localpart_from_id(":").unwrap_err(); + get_localpart_from_id(":asd").unwrap_err(); + get_localpart_from_id("::as::asad").unwrap_err(); + + assert_eq!(get_localpart_from_id("@test:foo").unwrap(), "test"); + assert_eq!(get_localpart_from_id("@:").unwrap(), ""); + assert_eq!(get_localpart_from_id("@test:foo:907").unwrap(), "test"); +} + +#[test] +fn tset_glob() -> Result<(), Error> { + assert_eq!( + glob_to_regex("simple", GlobMatchType::Whole)?.as_str(), + r"\Asimple\z" + ); + assert_eq!( + glob_to_regex("simple*", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{0,}\z" + ); + assert_eq!( + glob_to_regex("simple?", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{1}\z" + ); + assert_eq!( + glob_to_regex("simple?*?*", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{2,}\z" + ); + assert_eq!( + glob_to_regex("simple???", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{3}\z" + ); + + assert_eq!( + glob_to_regex("escape.", GlobMatchType::Whole)?.as_str(), + r"\Aescape\.\z" + ); + + assert!(glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simple")); + assert!(!glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple?", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simple")); + + assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("some simple.")); + assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("simple")); + assert!(!glob_to_regex("simple", GlobMatchType::Word)?.is_match("simples")); + + assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("Some @user:foo test")); + assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("@user:foo")); + + Ok(()) +} diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 93c4e69d42..fffb8419c6 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -1,4 +1,4 @@ -from typing import Any, Collection, Dict, Mapping, Sequence, Tuple, Union +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union from synapse.types import JsonDict @@ -35,3 +35,20 @@ class FilteredPushRules: def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... + +class PushRuleEvaluator: + def __init__( + self, + flattened_keys: Mapping[str, str], + room_member_count: int, + sender_power_level: Optional[int], + notification_power_levels: Mapping[str, int], + relations: Mapping[str, Set[Tuple[str, str]]], + relation_match_enabled: bool, + ): ... + def run( + self, + push_rules: FilteredPushRules, + user_id: Optional[str], + display_name: Optional[str], + ) -> Collection[dict]: ... diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 32313e3bcf..60f3129005 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -17,6 +17,7 @@ import itertools import logging from typing import ( TYPE_CHECKING, + Any, Collection, Dict, Iterable, @@ -37,13 +38,11 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule +from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state -from .push_rule_evaluator import PushRuleEvaluatorForEvent - if TYPE_CHECKING: from synapse.server import HomeServer @@ -290,11 +289,11 @@ class BulkPushRuleEvaluator: if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id - evaluator = PushRuleEvaluatorForEvent( - event, + evaluator = PushRuleEvaluator( + _flatten_dict(event), room_member_count, sender_power_level, - power_levels, + power_levels.get("notifications", {}), relations, self._relations_match_enabled, ) @@ -338,17 +337,10 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - for rule, enabled in rules.rules(): - if not enabled: - continue - - matches = evaluator.check_conditions(rule.conditions, uid, display_name) - if matches: - actions = [x for x in rule.actions if x != "dont_notify"] - if actions and "notify" in actions: - # Push rules say we should notify the user of this event - actions_by_user[uid] = actions - break + actions = evaluator.run(rules, uid, display_name) + if "notify" in actions: + # Push rules say we should notify the user of this event + actions_by_user[uid] = actions # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist @@ -365,3 +357,21 @@ MemberMap = Dict[str, Optional[EventIdMembership]] Rule = Dict[str, dict] RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] + + +def _flatten_dict( + d: Union[EventBase, Mapping[str, Any]], + prefix: Optional[List[str]] = None, + result: Optional[Dict[str, str]] = None, +) -> Dict[str, str]: + if prefix is None: + prefix = [] + if result is None: + result = {} + for key, value in d.items(): + if isinstance(value, str): + result[".".join(prefix + [key])] = value.lower() + elif isinstance(value, Mapping): + _flatten_dict(value, prefix=(prefix + [key]), result=result) + + return result diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index e96fb45e9f..b048b03a74 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union from prometheus_client import Counter @@ -28,7 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.storage.databases.main.event_push_actions import HttpPushAction -from . import push_rule_evaluator, push_tools +from . import push_tools if TYPE_CHECKING: from synapse.server import HomeServer @@ -56,6 +56,39 @@ http_badges_failed_counter = Counter( ) +def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: + """ + Converts a list of actions into a `tweaks` dict (which can then be passed to + the push gateway). + + This function ignores all actions other than `set_tweak` actions, and treats + absent `value`s as `True`, which agrees with the only spec-defined treatment + of absent `value`s (namely, for `highlight` tweaks). + + Args: + actions: list of actions + e.g. [ + {"set_tweak": "a", "value": "AAA"}, + {"set_tweak": "b", "value": "BBB"}, + {"set_tweak": "highlight"}, + "notify" + ] + + Returns: + dictionary of tweaks for those actions + e.g. {"a": "AAA", "b": "BBB", "highlight": True} + """ + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if "set_tweak" in a: + # value is allowed to be absent in which case the value assumed + # should be True. + tweaks[a["set_tweak"]] = a.get("value", True) + return tweaks + + class HttpPusher(Pusher): INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes MAX_BACKOFF_SEC = 60 * 60 @@ -281,7 +314,7 @@ class HttpPusher(Pusher): if "notify" not in push_action.actions: return True - tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) + tweaks = tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( self.hs.get_datastores().main, self.user_id, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py deleted file mode 100644 index f8176c5a42..0000000000 --- a/synapse/push/push_rule_evaluator.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re -from typing import ( - Any, - Dict, - List, - Mapping, - Optional, - Pattern, - Sequence, - Set, - Tuple, - Union, -) - -from matrix_common.regex import glob_to_regex, to_word_pattern - -from synapse.events import EventBase -from synapse.types import UserID -from synapse.util.caches.lrucache import LruCache - -logger = logging.getLogger(__name__) - - -GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]") -IS_GLOB = re.compile(r"[\?\*\[\]]") -INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") - - -def _room_member_count(condition: Mapping[str, Any], room_member_count: int) -> bool: - return _test_ineq_condition(condition, room_member_count) - - -def _sender_notification_permission( - condition: Mapping[str, Any], - sender_power_level: Optional[int], - power_levels: Dict[str, Union[int, Dict[str, int]]], -) -> bool: - if sender_power_level is None: - return False - - notif_level_key = condition.get("key") - if notif_level_key is None: - return False - - notif_levels = power_levels.get("notifications", {}) - assert isinstance(notif_levels, dict) - room_notif_level = notif_levels.get(notif_level_key, 50) - - return sender_power_level >= room_notif_level - - -def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool: - if "is" not in condition: - return False - m = INEQUALITY_EXPR.match(condition["is"]) - if not m: - return False - ineq = m.group(1) - rhs = m.group(2) - if not rhs.isdigit(): - return False - rhs_int = int(rhs) - - if ineq == "" or ineq == "==": - return number == rhs_int - elif ineq == "<": - return number < rhs_int - elif ineq == ">": - return number > rhs_int - elif ineq == ">=": - return number >= rhs_int - elif ineq == "<=": - return number <= rhs_int - else: - return False - - -def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: - """ - Converts a list of actions into a `tweaks` dict (which can then be passed to - the push gateway). - - This function ignores all actions other than `set_tweak` actions, and treats - absent `value`s as `True`, which agrees with the only spec-defined treatment - of absent `value`s (namely, for `highlight` tweaks). - - Args: - actions: list of actions - e.g. [ - {"set_tweak": "a", "value": "AAA"}, - {"set_tweak": "b", "value": "BBB"}, - {"set_tweak": "highlight"}, - "notify" - ] - - Returns: - dictionary of tweaks for those actions - e.g. {"a": "AAA", "b": "BBB", "highlight": True} - """ - tweaks = {} - for a in actions: - if not isinstance(a, dict): - continue - if "set_tweak" in a: - # value is allowed to be absent in which case the value assumed - # should be True. - tweaks[a["set_tweak"]] = a.get("value", True) - return tweaks - - -class PushRuleEvaluatorForEvent: - def __init__( - self, - event: EventBase, - room_member_count: int, - sender_power_level: Optional[int], - power_levels: Dict[str, Union[int, Dict[str, int]]], - relations: Dict[str, Set[Tuple[str, str]]], - relations_match_enabled: bool, - ): - self._event = event - self._room_member_count = room_member_count - self._sender_power_level = sender_power_level - self._power_levels = power_levels - self._relations = relations - self._relations_match_enabled = relations_match_enabled - - # Maps strings of e.g. 'content.body' -> event["content"]["body"] - self._value_cache = _flatten_dict(event) - - # Maps cache keys to final values. - self._condition_cache: Dict[str, bool] = {} - - def check_conditions( - self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str] - ) -> bool: - """ - Returns true if a user's conditions/user ID/display name match the event. - - Args: - conditions: The user's conditions to match. - uid: The user's MXID. - display_name: The display name. - - Returns: - True if all conditions match the event, False otherwise. - """ - for cond in conditions: - _cache_key = cond.get("_cache_key", None) - if _cache_key: - res = self._condition_cache.get(_cache_key, None) - if res is False: - return False - elif res is True: - continue - - res = self.matches(cond, uid, display_name) - if _cache_key: - self._condition_cache[_cache_key] = bool(res) - - if not res: - return False - - return True - - def matches( - self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str] - ) -> bool: - """ - Returns true if a user's condition/user ID/display name match the event. - - Args: - condition: The user's condition to match. - uid: The user's MXID. - display_name: The display name, or None if there is not one. - - Returns: - True if the condition matches the event, False otherwise. - """ - if condition["kind"] == "event_match": - return self._event_match(condition, user_id) - elif condition["kind"] == "contains_display_name": - return self._contains_display_name(display_name) - elif condition["kind"] == "room_member_count": - return _room_member_count(condition, self._room_member_count) - elif condition["kind"] == "sender_notification_permission": - return _sender_notification_permission( - condition, self._sender_power_level, self._power_levels - ) - elif ( - condition["kind"] == "org.matrix.msc3772.relation_match" - and self._relations_match_enabled - ): - return self._relation_match(condition, user_id) - else: - # XXX This looks incorrect -- we have reached an unknown condition - # kind and are unconditionally returning that it matches. Note - # that it seems possible to provide a condition to the /pushrules - # endpoint with an unknown kind, see _rule_tuple_from_request_object. - return True - - def _event_match(self, condition: Mapping, user_id: str) -> bool: - """ - Check an "event_match" push rule condition. - - Args: - condition: The "event_match" push rule condition to match. - user_id: The user's MXID. - - Returns: - True if the condition matches the event, False otherwise. - """ - pattern = condition.get("pattern", None) - - if not pattern: - pattern_type = condition.get("pattern_type", None) - if pattern_type == "user_id": - pattern = user_id - elif pattern_type == "user_localpart": - pattern = UserID.from_string(user_id).localpart - - if not pattern: - logger.warning("event_match condition with no pattern") - return False - - # XXX: optimisation: cache our pattern regexps - if condition["key"] == "content.body": - body = self._event.content.get("body", None) - if not body or not isinstance(body, str): - return False - - return _glob_matches(pattern, body, word_boundary=True) - else: - haystack = self._value_cache.get(condition["key"], None) - if haystack is None: - return False - - return _glob_matches(pattern, haystack) - - def _contains_display_name(self, display_name: Optional[str]) -> bool: - """ - Check an "event_match" push rule condition. - - Args: - display_name: The display name, or None if there is not one. - - Returns: - True if the display name is found in the event body, False otherwise. - """ - if not display_name: - return False - - body = self._event.content.get("body", None) - if not body or not isinstance(body, str): - return False - - # Similar to _glob_matches, but do not treat display_name as a glob. - r = regex_cache.get((display_name, False, True), None) - if not r: - r1 = re.escape(display_name) - r1 = to_word_pattern(r1) - r = re.compile(r1, flags=re.IGNORECASE) - regex_cache[(display_name, False, True)] = r - - return bool(r.search(body)) - - def _relation_match(self, condition: Mapping, user_id: str) -> bool: - """ - Check an "relation_match" push rule condition. - - Args: - condition: The "event_match" push rule condition to match. - user_id: The user's MXID. - - Returns: - True if the condition matches the event, False otherwise. - """ - rel_type = condition.get("rel_type") - if not rel_type: - logger.warning("relation_match condition missing rel_type") - return False - - sender_pattern = condition.get("sender") - if sender_pattern is None: - sender_type = condition.get("sender_type") - if sender_type == "user_id": - sender_pattern = user_id - type_pattern = condition.get("type") - - # If any other relations matches, return True. - for sender, event_type in self._relations.get(rel_type, ()): - if sender_pattern and not _glob_matches(sender_pattern, sender): - continue - if type_pattern and not _glob_matches(type_pattern, event_type): - continue - # All values must have matched. - return True - - # No relations matched. - return False - - -# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( - 50000, "regex_push_cache" -) - - -def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: - """Tests if value matches glob. - - Args: - glob - value: String to test against glob. - word_boundary: Whether to match against word boundaries or entire - string. Defaults to False. - """ - - try: - r = regex_cache.get((glob, True, word_boundary), None) - if not r: - r = glob_to_regex(glob, word_boundary=word_boundary) - regex_cache[(glob, True, word_boundary)] = r - return bool(r.search(value)) - except re.error: - logger.warning("Failed to parse glob to regex: %r", glob) - return False - - -def _flatten_dict( - d: Union[EventBase, Mapping[str, Any]], - prefix: Optional[List[str]] = None, - result: Optional[Dict[str, str]] = None, -) -> Dict[str, str]: - if prefix is None: - prefix = [] - if result is None: - result = {} - for key, value in d.items(): - if isinstance(value, str): - result[".".join(prefix + [key])] = value.lower() - elif isinstance(value, Mapping): - _flatten_dict(value, prefix=(prefix + [key]), result=result) - - return result diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 718f489577..b8308cbc05 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -23,11 +23,12 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService from synapse.events import FrozenEvent -from synapse.push import push_rule_evaluator -from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent +from synapse.push.bulk_push_rule_evaluator import _flatten_dict +from synapse.push.httppusher import tweaks_for_actions from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex +from synapse.synapse_rust.push import PushRuleEvaluator from synapse.types import JsonDict from synapse.util import Clock @@ -41,7 +42,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): content: JsonDict, relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, relations_match_enabled: bool = False, - ) -> PushRuleEvaluatorForEvent: + ) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -56,12 +57,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count = 0 sender_power_level = 0 power_levels: Dict[str, Union[int, Dict[str, int]]] = {} - return PushRuleEvaluatorForEvent( - event, + return PushRuleEvaluator( + _flatten_dict(event), room_member_count, sender_power_level, - power_levels, - relations or set(), + power_levels.get("notifications", {}), + relations or {}, relations_match_enabled, ) @@ -293,7 +294,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ] self.assertEqual( - push_rule_evaluator.tweaks_for_actions(actions), + tweaks_for_actions(actions), {"sound": "default", "highlight": True}, ) @@ -304,9 +305,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): evaluator = self._get_evaluator( {}, {"m.annotation": {("@user:test", "m.reaction")}} ) - condition = {"kind": "relation_match"} - # Oddly, an unknown condition always matches. - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) # A push rule evaluator with the experimental rule enabled. evaluator = self._get_evaluator( -- cgit 1.5.1 From 15754d720feb3af88d97a2dafd0b05633abf42f5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 29 Sep 2022 19:10:47 +0100 Subject: Update UPSERT comment now that native upserts are the default (#13924) --- changelog.d/13924.misc | 1 + synapse/storage/database.py | 60 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 10 deletions(-) create mode 100644 changelog.d/13924.misc (limited to 'synapse') diff --git a/changelog.d/13924.misc b/changelog.d/13924.misc new file mode 100644 index 0000000000..7770b6f03f --- /dev/null +++ b/changelog.d/13924.misc @@ -0,0 +1 @@ +Update an innaccurate comment in Synapse's upsert database helper. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6cc88aad32..bb28ded1b5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1141,17 +1141,57 @@ class DatabasePool: desc: str = "simple_upsert", lock: bool = True, ) -> bool: - """ + """Insert a row with values + insertion_values; on conflict, update with values. + + All of our supported databases accept the nonstandard "upsert" statement in + their dialect of SQL. We call this a "native upsert". The syntax looks roughly + like: + + INSERT INTO table VALUES (values + insertion_values) + ON CONFLICT (keyvalues) + DO UPDATE SET (values); -- overwrite `values` columns only + + If (values) is empty, the resulting query is slighlty simpler: + + INSERT INTO table VALUES (insertion_values) + ON CONFLICT (keyvalues) + DO NOTHING; -- do not overwrite any columns + + This function is a helper to build such queries. + + In order for upserts to make sense, the database must be able to determine when + an upsert CONFLICTs with an existing row. Postgres and SQLite ensure this by + requiring that a unique index exist on the column names used to detect a + conflict (i.e. `keyvalues.keys()`). + + If there is no such index, we can "emulate" an upsert with a SELECT followed + by either an INSERT or an UPDATE. This is unsafe: we cannot make the same + atomicity guarantees that a native upsert can and are very vulnerable to races + and crashes. Therefore if we wish to upsert without an appropriate unique index, + we must either: + + 1. Acquire a table-level lock before the emulated upsert (`lock=True`), or + 2. VERY CAREFULLY ensure that we are the only thread and worker which will be + writing to this table, in which case we can proceed without a lock + (`lock=False`). + + Generally speaking, you should use `lock=True`. If the table in question has a + unique index[*], this class will use a native upsert (which is atomic and so can + ignore the `lock` argument). Otherwise this class will use an emulated upsert, + in which case we want the safer option unless we been VERY CAREFUL. + + [*]: Some tables have unique indices added to them in the background. Those + tables `T` are keys in the dictionary UNIQUE_INDEX_BACKGROUND_UPDATES, + where `T` maps to the background update that adds a unique index to `T`. + This dictionary is maintained by hand. + + At runtime, we constantly check to see if each of these background updates + has run. If so, we deem the coresponding table safe to upsert into, because + we can now use a native insert to do so. If not, we deem the table unsafe + to upsert into and require an emulated upsert. - `lock` should generally be set to True (the default), but can be set - to False if either of the following are true: - 1. there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - 2. we somehow know that we are the only thread which will be updating - this table. - As an additional note, this parameter only matters for old SQLite versions - because we will use native upserts otherwise. + Tables that do not appear in this dictionary are assumed to have an + appropriate unique index and therefore be safe to upsert into. Args: table: The table to upsert into -- cgit 1.5.1 From 6f0c3e669da458e838e7b4b165a13e8a5312d6d0 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 29 Sep 2022 21:16:08 +0100 Subject: Don't require `setuptools_rust` at runtime (#13952) --- changelog.d/13952.bugfix | 1 + synapse/util/check_dependencies.py | 17 ++++++++++++++++- tests/util/test_check_dependencies.py | 20 ++++++++++++++++++-- 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 changelog.d/13952.bugfix (limited to 'synapse') diff --git a/changelog.d/13952.bugfix b/changelog.d/13952.bugfix new file mode 100644 index 0000000000..a6af20f051 --- /dev/null +++ b/changelog.d/13952.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.68.0 where Synapse would require `setuptools_rust` at runtime, even though the package is only required at build time. diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 66f1da7502..3b1e205700 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -66,6 +66,21 @@ def _is_dev_dependency(req: Requirement) -> bool: ) +def _should_ignore_runtime_requirement(req: Requirement) -> bool: + # This is a build-time dependency. Irritatingly, `poetry build` ignores the + # requirements listed in the [build-system] section of pyproject.toml, so in order + # to support `poetry install --no-dev` we have to mark it as a runtime dependency. + # See discussion on https://github.com/python-poetry/poetry/issues/6154 (it sounds + # like the poetry authors don't consider this a bug?) + # + # In any case, workaround this by ignoring setuptools_rust here. (It might be + # slightly cleaner to put `setuptools_rust` in a `build` extra or similar, but for + # now let's do something quick and dirty. + if req.name == "setuptools_rust": + return True + return False + + class Dependency(NamedTuple): requirement: Requirement must_be_installed: bool @@ -77,7 +92,7 @@ def _generic_dependencies() -> Iterable[Dependency]: assert requirements is not None for raw_requirement in requirements: req = Requirement(raw_requirement) - if _is_dev_dependency(req): + if _is_dev_dependency(req) or _should_ignore_runtime_requirement(req): continue # https://packaging.pypa.io/en/latest/markers.html#usage notes that diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index 5d1aa025d1..6913de24b9 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -40,7 +40,10 @@ class TestDependencyChecker(TestCase): def mock_installed_package( self, distribution: Optional[DummyDistribution] ) -> Generator[None, None, None]: - """Pretend that looking up any distribution yields the given `distribution`.""" + """Pretend that looking up any package yields the given `distribution`. + + If `distribution = None`, we pretend that the package is not installed. + """ def mock_distribution(name: str): if distribution is None: @@ -81,7 +84,7 @@ class TestDependencyChecker(TestCase): self.assertRaises(DependencyException, check_requirements) def test_checks_ignore_dev_dependencies(self) -> None: - """Bot generic and per-extra checks should ignore dev dependencies.""" + """Both generic and per-extra checks should ignore dev dependencies.""" with patch( "synapse.util.check_dependencies.metadata.requires", return_value=["dummypkg >= 1; extra == 'mypy'"], @@ -142,3 +145,16 @@ class TestDependencyChecker(TestCase): with self.mock_installed_package(new_release_candidate): # should not raise check_requirements() + + def test_setuptools_rust_ignored(self) -> None: + """Test a workaround for a `poetry build` problem. Reproduces #13926.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["setuptools_rust >= 1.3"], + ): + with self.mock_installed_package(None): + # should not raise, even if setuptools_rust is not installed + check_requirements() + with self.mock_installed_package(old): + # We also ignore old versions of setuptools_rust + check_requirements() -- cgit 1.5.1 From 1cc2ca81badb9c5161d219dfc9a273a338adedd2 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 30 Sep 2022 11:27:21 +0100 Subject: Add missing version information in the ModuleApi (#13947) --- changelog.d/13947.feature | 1 + synapse/module_api/__init__.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/13947.feature (limited to 'synapse') diff --git a/changelog.d/13947.feature b/changelog.d/13947.feature new file mode 100644 index 0000000000..a0b3cfe18c --- /dev/null +++ b/changelog.d/13947.feature @@ -0,0 +1 @@ +Add cache invalidation across workers to module API. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 59755bff6d..b7b2d3b8c5 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -842,6 +842,8 @@ class ModuleApi: however invalidation that needs to go to other workers needs to call `invalidate_cache` on the module API instead. + Added in Synapse v1.69.0. + Args: cached_function: The cached function that will be registered to receive invalidation locally and from other workers. @@ -856,6 +858,8 @@ class ModuleApi: """Invalidate a cache entry of a cached function across workers. The cached function needs to be registered on all workers first with `register_cached_function`. + Added in Synapse v1.69.0. + Args: cached_function: The cached function that needs an invalidation keys: keys of the entry to invalidate, usually matching the arguments of the -- cgit 1.5.1 From e8f30a76caa4394ebb3e77c56df951e3626b3fdd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 11:54:53 +0100 Subject: Fix overflows in /messages backfill calculation (#13936) * Reproduce bug * Compute `least_function` first * Substitute `least_function` with an f-string * Bugfix: avoid overflow Co-authored-by: Eric Eastwood --- changelog.d/13936.feature | 1 + synapse/storage/databases/main/event_federation.py | 82 ++++++++++++++-------- tests/storage/test_event_federation.py | 61 ++++++++++++---- 3 files changed, 103 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13936.feature (limited to 'synapse') diff --git a/changelog.d/13936.feature b/changelog.d/13936.feature new file mode 100644 index 0000000000..d86bf7ed80 --- /dev/null +++ b/changelog.d/13936.feature @@ -0,0 +1 @@ +Exponentially backoff from backfilling the same event over and over. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 17f2fd4458..6b9a629edd 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -73,13 +73,30 @@ pdus_pruned_from_federation_queue = Counter( logger = logging.getLogger(__name__) -BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS: int = int( - datetime.timedelta(days=7).total_seconds() -) -BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS: int = int( - datetime.timedelta(hours=1).total_seconds() +# Parameters controlling exponential backoff between backfill failures. +# After the first failure to backfill, we wait 2 hours before trying again. If the +# second attempt fails, we wait 4 hours before trying again. If the third attempt fails, +# we wait 8 hours before trying again, ... and so on. +# +# Each successive backoff period is twice as long as the last. However we cap this +# period at a maximum of 2^8 = 256 hours: a little over 10 days. (This is the smallest +# power of 2 which yields a maximum backoff period of at least 7 days---which was the +# original maximum backoff period.) Even when we hit this cap, we will continue to +# make backfill attempts once every 10 days. +BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS = 8 +BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS = int( + datetime.timedelta(hours=1).total_seconds() * 1000 ) +# We need a cap on the power of 2 or else the backoff period +# 2^N * (milliseconds per hour) +# will overflow when calcuated within the database. We ensure overflow does not occur +# by checking that the largest backoff period fits in a 32-bit signed integer. +_LONGEST_BACKOFF_PERIOD_MILLISECONDS = ( + 2**BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS +) * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS +assert 0 < _LONGEST_BACKOFF_PERIOD_MILLISECONDS <= ((2**31) - 1) + # All the info we need while iterating the DAG while backfilling @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -767,7 +784,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # persisted in our database yet (meaning we don't know their depth # specifically). So we need to look for the approximate depth from # the events connected to the current backwards extremeties. - sql = """ + + if isinstance(self.database_engine, PostgresEngine): + least_function = "LEAST" + elif isinstance(self.database_engine, Sqlite3Engine): + least_function = "MIN" + else: + raise RuntimeError("Unknown database engine") + + sql = f""" SELECT backward_extrem.event_id, event.depth FROM events AS event /** * Get the edge connections from the event_edges table @@ -825,7 +850,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas */ AND ( failed_backfill_attempt_info.event_id IS NULL - OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) + OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + ( + (1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */)) + * ? /* step */ + ) ) /** * Sort from highest (closest to the `current_depth`) to the lowest depth @@ -837,22 +865,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas LIMIT ? """ - if isinstance(self.database_engine, PostgresEngine): - least_function = "least" - elif isinstance(self.database_engine, Sqlite3Engine): - least_function = "min" - else: - raise RuntimeError("Unknown database engine") - txn.execute( - sql % (least_function,), + sql, ( room_id, False, current_depth, self._clock.time_msec(), - 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, - 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS, limit, ), ) @@ -902,7 +923,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas def get_insertion_event_backward_extremities_in_room_txn( txn: LoggingTransaction, room_id: str ) -> List[Tuple[str, int]]: - sql = """ + if isinstance(self.database_engine, PostgresEngine): + least_function = "LEAST" + elif isinstance(self.database_engine, Sqlite3Engine): + least_function = "MIN" + else: + raise RuntimeError("Unknown database engine") + + sql = f""" SELECT insertion_event_extremity.event_id, event.depth /* We only want insertion events that are also marked as backwards extremities */ @@ -942,7 +970,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas */ AND ( failed_backfill_attempt_info.event_id IS NULL - OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) + OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + ( + (1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */)) + * ? /* step */ + ) ) /** * Sort from highest (closest to the `current_depth`) to the lowest depth @@ -954,21 +985,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas LIMIT ? """ - if isinstance(self.database_engine, PostgresEngine): - least_function = "least" - elif isinstance(self.database_engine, Sqlite3Engine): - least_function = "min" - else: - raise RuntimeError("Unknown database engine") - txn.execute( - sql % (least_function,), + sql, ( room_id, current_depth, self._clock.time_msec(), - 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, - 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS, limit, ), ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 398f338b66..59b8910907 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -766,9 +766,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual( - backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] - ) + self.assertEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]) # Try at "A" backfill_points = self.get_success( @@ -814,7 +812,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] # Only the backfill points that we didn't record earlier exist here. - self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"]) + self.assertEqual(backfill_event_ids, ["b6", "2", "b1"]) def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration( self, @@ -860,7 +858,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["b3", "b2"]) + self.assertEqual(backfill_event_ids, ["b3", "b2"]) # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and # see if we can now backfill it @@ -871,7 +869,48 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"]) + self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"]) + + def test_get_backfill_points_in_room_works_after_many_failed_pull_attempts_that_could_naively_overflow( + self, + ) -> None: + """ + A test that reproduces #13929 (Postgres only). + + Test to make sure we can still get backfill points after many failed pull + attempts that cause us to backoff to the limit. Even if the backoff formula + would tell us to wait for more seconds than can be expressed in a 32 bit + signed int. + """ + setup_info = self._setup_room_for_backfill_tests() + room_id = setup_info.room_id + depth_map = setup_info.depth_map + + # Pretend that we have tried and failed 10 times to backfill event b1. + for _ in range(10): + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + + # If the backoff periods grow without limit: + # After the first failed attempt, we would have backed off for 1 << 1 = 2 hours. + # After the second failed attempt we would have backed off for 1 << 2 = 4 hours, + # so after the 10th failed attempt we should backoff for 1 << 10 == 1024 hours. + # Wait 1100 hours just so we have a nice round number. + self.reactor.advance(datetime.timedelta(hours=1100).total_seconds()) + + # 1024 hours in milliseconds is 1024 * 3600000, which exceeds the largest 32 bit + # signed integer. The bug we're reproducing is that this overflow causes an + # error in postgres preventing us from fetching a set of backwards extremities + # to retry fetching. + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) + ) + + # We should aim to fetch all backoff points: b1's latest backoff period has + # expired, and we haven't tried the rest. + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"]) def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo: """ @@ -965,9 +1004,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual( - backfill_event_ids, ["insertion_eventB", "insertion_eventA"] - ) + self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"]) # Try at "insertion_eventA" backfill_points = self.get_success( @@ -1011,7 +1048,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] # Only the backfill points that we didn't record earlier exist here. - self.assertListEqual(backfill_event_ids, ["insertion_eventB"]) + self.assertEqual(backfill_event_ids, ["insertion_eventB"]) def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( self, @@ -1069,7 +1106,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, []) + self.assertEqual(backfill_event_ids, []) # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and # see if we can now backfill it @@ -1083,7 +1120,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["insertion_eventA"]) + self.assertEqual(backfill_event_ids, ["insertion_eventA"]) @attr.s -- cgit 1.5.1 From 3dfc4a08dc2e77178f2c2af68dc14b32da2d8b8f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Sep 2022 13:15:32 +0100 Subject: Fix performance regression in `get_users_in_room` (#13972) Fixes #13942. Introduced in #13575. Basically, let's only get the ordered set of hosts out of the DB if we need an ordered set of hosts. Since we split the function up the caching won't be as good, but I think it will still be fine as e.g. multiple backfill requests for the same room will hit the cache. --- changelog.d/13972.bugfix | 1 + synapse/handlers/federation.py | 4 +- synapse/handlers/room.py | 4 +- synapse/storage/controllers/state.py | 30 ++++--- synapse/storage/databases/main/roommember.py | 129 +++++++++++++++------------ 5 files changed, 98 insertions(+), 70 deletions(-) create mode 100644 changelog.d/13972.bugfix (limited to 'synapse') diff --git a/changelog.d/13972.bugfix b/changelog.d/13972.bugfix new file mode 100644 index 0000000000..4c1e19ef8c --- /dev/null +++ b/changelog.d/13972.bugfix @@ -0,0 +1 @@ +Fix a performance regression in the `get_users_in_room` database query. Introduced in v1.67.0. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b866258298..986ffed3d5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -412,7 +412,9 @@ class FederationHandler: # First we try hosts that are already in the room. # TODO: HEURISTIC ALERT. likely_domains = ( - await self._storage_controllers.state.get_current_hosts_in_room(room_id) + await self._storage_controllers.state.get_current_hosts_in_room_ordered( + room_id + ) ) async def try_backfill(domains: Collection[str]) -> bool: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b220238e55..57ab05ad25 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1540,7 +1540,9 @@ class TimestampLookupHandler: ) likely_domains = ( - await self._storage_controllers.state.get_current_hosts_in_room(room_id) + await self._storage_controllers.state.get_current_hosts_in_room_ordered( + room_id + ) ) # Loop through each homeserver candidate until we get a succesful response diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index bb60130afe..2b31ce54bb 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -23,7 +23,7 @@ from typing import ( List, Mapping, Optional, - Sequence, + Set, Tuple, ) @@ -529,7 +529,18 @@ class StateStorageController: ) return state_map.get(key) - async def get_current_hosts_in_room(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + """Get current hosts in room based on current state. + + Blocks until we have full state for the given room. This only happens for rooms + with partial state. + """ + + await self._partial_state_room_tracker.await_full_state(room_id) + + return await self.stores.main.get_current_hosts_in_room(room_id) + + async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: """Get current hosts in room based on current state. Blocks until we have full state for the given room. This only happens for rooms @@ -542,11 +553,11 @@ class StateStorageController: await self._partial_state_room_tracker.await_full_state(room_id) - return await self.stores.main.get_current_hosts_in_room(room_id) + return await self.stores.main.get_current_hosts_in_room_ordered(room_id) async def get_current_hosts_in_room_or_partial_state_approximation( self, room_id: str - ) -> Sequence[str]: + ) -> Collection[str]: """Get approximation of current hosts in room based on current state. For rooms with full state, this is equivalent to `get_current_hosts_in_room`, @@ -566,14 +577,9 @@ class StateStorageController: ) hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id) - hosts_from_state_set = set(hosts_from_state) - - # First take the list of hosts based on the current state. - # For rooms with partial state, this will be missing most hosts. - hosts = list(hosts_from_state) - # Then add in the list of hosts in the room at the time we joined. - # This will be an empty list for rooms with full state. - hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set) + + hosts = set(hosts_at_join) + hosts.update(hosts_from_state) return hosts diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 982e1f08e3..2337289d88 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -146,42 +146,37 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) async def get_users_in_room(self, room_id: str) -> List[str]: - """ - Returns a list of users in the room sorted by longest in the room first - (aka. with the lowest depth). This is done to match the sort in - `get_current_hosts_in_room()` and so we can re-use the cache but it's - not horrible to have here either. - - Uses `m.room.member`s in the room state at the current forward extremities to - determine which users are in the room. + """Returns a list of users in the room. Will return inaccurate results for rooms with partial state, since the state for the forward extremities of those rooms will exclude most members. We may also calculate room state incorrectly for such rooms and believe that a member is or is not in the room when the opposite is true. """ - return await self.db_pool.runInteraction( - "get_users_in_room", self.get_users_in_room_txn, room_id + return await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "room_id": room_id, + "membership": Membership.JOIN, + }, + retcol="state_key", + desc="get_users_in_room", ) def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]: - """ - Returns a list of users in the room sorted by longest in the room first - (aka. with the lowest depth). This is done to match the sort in - `get_current_hosts_in_room()` and so we can re-use the cache but it's - not horrible to have here either. - """ - sql = """ - SELECT c.state_key FROM current_state_events as c - /* Get the depth of the event from the events table */ - INNER JOIN events AS e USING (event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ? - /* Sorted by lowest depth first */ - ORDER BY e.depth ASC; - """ + """Returns a list of users in the room.""" - txn.execute(sql, (room_id, Membership.JOIN)) - return [r[0] for r in txn] + return self.db_pool.simple_select_onecol_txn( + txn, + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "room_id": room_id, + "membership": Membership.JOIN, + }, + retcol="state_key", + ) @cached() def get_user_in_room_with_profile( @@ -931,7 +926,44 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + """Get current hosts in room based on current state.""" + + # First we check if we already have `get_users_in_room` in the cache, as + # we can just calculate result from that + users = self.get_users_in_room.cache.get_immediate( + (room_id,), None, update_metrics=False + ) + if users is not None: + return {get_domain_from_id(u) for u in users} + + if isinstance(self.database_engine, Sqlite3Engine): + # If we're using SQLite then let's just always use + # `get_users_in_room` rather than funky SQL. + users = await self.get_users_in_room(room_id) + return {get_domain_from_id(u) for u in users} + + # For PostgreSQL we can use a regex to pull out the domains from the + # joined users in `current_state_events` via regex. + + def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ + SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$') + FROM current_state_events + WHERE + type = 'm.room.member' + AND membership = 'join' + AND room_id = ? + """ + txn.execute(sql, (room_id,)) + return {d for d, in txn} + + return await self.db_pool.runInteraction( + "get_current_hosts_in_room", get_current_hosts_in_room_txn + ) + + @cached(iterable=True, max_entries=10000) + async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: """ Get current hosts in room based on current state. @@ -939,48 +971,33 @@ class RoomMemberWorkerStore(EventsWorkerStore): longest is good because they're most likely to have anything we ask about. - Uses `m.room.member`s in the room state at the current forward extremities to - determine which hosts are in the room. + For SQLite the returned list is not ordered, as SQLite doesn't support + the appropriate SQL. - Will return inaccurate results for rooms with partial state, since the state for - the forward extremities of those rooms will exclude most members. We may also - calculate room state incorrectly for such rooms and believe that a host is or - is not in the room when the opposite is true. + Uses `m.room.member`s in the room state at the current forward + extremities to determine which hosts are in the room. + + Will return inaccurate results for rooms with partial state, since the + state for the forward extremities of those rooms will exclude most + members. We may also calculate room state incorrectly for such rooms and + believe that a host is or is not in the room when the opposite is true. Returns: Returns a list of servers sorted by longest in the room first. (aka. sorted by join with the lowest depth first). """ - # First we check if we already have `get_users_in_room` in the cache, as - # we can just calculate result from that - users = self.get_users_in_room.cache.get_immediate( - (room_id,), None, update_metrics=False - ) - if users is None and isinstance(self.database_engine, Sqlite3Engine): + if isinstance(self.database_engine, Sqlite3Engine): # If we're using SQLite then let's just always use # `get_users_in_room` rather than funky SQL. - users = await self.get_users_in_room(room_id) - if users is not None: - # Because `users` is sorted from lowest -> highest depth, the list - # of domains will also be sorted that way. - domains: List[str] = [] - # We use a `Set` just for fast lookups - domain_set: Set[str] = set() - for u in users: - if ":" not in u: - continue - domain = get_domain_from_id(u) - if domain not in domain_set: - domain_set.add(domain) - domains.append(domain) - return domains + domains = await self.get_current_hosts_in_room(room_id) + return list(domains) # For PostgreSQL we can use a regex to pull out the domains from the # joined users in `current_state_events` via regex. - def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]: + def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]: # Returns a list of servers currently joined in the room sorted by # longest in the room first (aka. with the lowest depth). The # heuristic of sorting by servers who have been in the room the @@ -1008,7 +1025,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [d for d, in txn if d is not None] return await self.db_pool.runInteraction( - "get_current_hosts_in_room", get_current_hosts_in_room_txn + "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn ) async def get_joined_hosts( -- cgit 1.5.1 From 5507bfa769e61f5ef507c6172b8e798a87ac84b1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 14:23:37 +0100 Subject: Discourage automatic replies to Synapse's emails (#13957) Co-authored-by: Patrick Cloke --- changelog.d/13957.feature | 1 + synapse/handlers/send_email.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) create mode 100644 changelog.d/13957.feature (limited to 'synapse') diff --git a/changelog.d/13957.feature b/changelog.d/13957.feature new file mode 100644 index 0000000000..4080147357 --- /dev/null +++ b/changelog.d/13957.feature @@ -0,0 +1 @@ +Ask mail servers receiving emails from Synapse to not send automatic reply (e.g. out-of-office responses). diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index e2844799e8..804cc6e81e 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -187,6 +187,19 @@ class SendEmailHandler: multipart_msg["To"] = email_address multipart_msg["Date"] = email.utils.formatdate() multipart_msg["Message-ID"] = email.utils.make_msgid() + # Discourage automatic responses to Synapse's emails. + # Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted" + # header is present with any value other than "no". See + # https://www.rfc-editor.org/rfc/rfc3834.html#section-5.1 + multipart_msg["Auto-Submitted"] = "auto-generated" + # Also include a Microsoft-Exchange specific header: + # https://learn.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxcmail/ced68690-498a-4567-9d14-5c01f974d8b1 + # which suggests it can take the value "All" to "suppress all auto-replies", + # or a comma separated list of auto-reply classes to suppress. + # The following stack overflow question has a little more context: + # https://stackoverflow.com/a/25324691/5252017 + # https://stackoverflow.com/a/61646381/5252017 + multipart_msg["X-Auto-Response-Suppress"] = "All" multipart_msg.attach(text_part) multipart_msg.attach(html_part) -- cgit 1.5.1 From 285b9e9b6c3558718e7d4f513062e277948ac35d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Sep 2022 14:27:00 +0100 Subject: Speed up calculating push actions in large rooms (#13973) We move the expensive check of visibility to after calculating push actions, avoiding the expensive check for users who won't get pushed anyway. I think this should have a big impact on rooms with large numbers of local users that have pushed disabled. --- changelog.d/13973.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 25 ++++++---- tests/push/test_push_rule_evaluator.py | 82 +++++++++++++++++++++++++++++++- 3 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 changelog.d/13973.misc (limited to 'synapse') diff --git a/changelog.d/13973.misc b/changelog.d/13973.misc new file mode 100644 index 0000000000..58150a2b35 --- /dev/null +++ b/changelog.d/13973.misc @@ -0,0 +1 @@ +Speed up calculating push actions in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 60f3129005..7bfe380543 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -303,20 +303,10 @@ class BulkPushRuleEvaluator: event.room_id, users ) - # This is a check for the case where user joins a room without being - # allowed to see history, and then the server receives a delayed event - # from before the user joined, which they should not be pushed for - uids_with_visibility = await filter_event_for_clients_with_state( - self.store, users, event, context - ) - for uid, rules in rules_by_user.items(): if event.sender == uid: continue - if uid not in uids_with_visibility: - continue - display_name = None profile = profiles.get(uid) if profile: @@ -342,6 +332,21 @@ class BulkPushRuleEvaluator: # Push rules say we should notify the user of this event actions_by_user[uid] = actions + # This is a check for the case where user joins a room without being + # allowed to see history, and then the server receives a delayed event + # from before the user joined, which they should not be pushed for + # + # We do this *after* calculating the push actions as a) its unlikely + # that we'll filter anyone out and b) for large rooms its likely that + # most users will have push disabled and so the set of users to check is + # much smaller. + uids_with_visibility = await filter_event_for_clients_with_state( + self.store, actions_by_user.keys(), event, context + ) + + for user_id in set(actions_by_user).difference(uids_with_visibility): + actions_by_user.pop(user_id, None) + # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index b8308cbc05..8804f0e0d3 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -19,17 +19,18 @@ import frozendict from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService from synapse.events import FrozenEvent from synapse.push.bulk_push_rule_evaluator import _flatten_dict from synapse.push.httppusher import tweaks_for_actions +from synapse.rest import admin from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.synapse_rust.push import PushRuleEvaluator -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest @@ -437,3 +438,80 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): ) self.assertEqual(len(users_with_push_actions), 0) + + +class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.main_store = homeserver.get_datastores().main + + self.user_id1 = self.register_user("user1", "password") + self.tok1 = self.login(self.user_id1, "password") + self.user_id2 = self.register_user("user2", "password") + self.tok2 = self.login(self.user_id2, "password") + + self.room_id = self.helper.create_room_as(tok=self.tok1) + + # We want to test history visibility works correctly. + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + {"history_visibility": HistoryVisibility.JOINED}, + tok=self.tok1, + ) + + def get_notif_count(self, user_id: str) -> int: + return self.get_success( + self.main_store.db_pool.simple_select_one_onecol( + table="event_push_actions", + keyvalues={"user_id": user_id}, + retcol="COALESCE(SUM(notif), 0)", + desc="get_staging_notif_count", + ) + ) + + def test_plain_message(self) -> None: + """Test that sending a normal message in a room will trigger a + notification + """ + + # Have user2 join the room and cle + self.helper.join(self.room_id, self.user_id2, tok=self.tok2) + + # They start off with no notifications, but get them when messages are + # sent. + self.assertEqual(self.get_notif_count(self.user_id2), 0) + + user1 = UserID.from_string(self.user_id1) + self.create_and_send_event(self.room_id, user1) + + self.assertEqual(self.get_notif_count(self.user_id2), 1) + + def test_delayed_message(self) -> None: + """Test that a delayed message that was from before a user joined + doesn't cause a notification for the joined user. + """ + user1 = UserID.from_string(self.user_id1) + + # Send a message before user2 joins + event_id1 = self.create_and_send_event(self.room_id, user1) + + # Have user2 join the room + self.helper.join(self.room_id, self.user_id2, tok=self.tok2) + + # They start off with no notifications + self.assertEqual(self.get_notif_count(self.user_id2), 0) + + # Send another message that references the event before the join to + # simulate a "delayed" event + self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1]) + + # user2 should not be notified about it, because they can't see it. + self.assertEqual(self.get_notif_count(self.user_id2), 0) -- cgit 1.5.1 From 6d543d6d9f56e39199b7e460d0081b02d61f12be Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 16:34:47 +0100 Subject: Update mypy and mypy-zope (#13925) * Update mypy and mypy-zope * Unignore assigning to LogRecord attributes Presumably https://github.com/python/typeshed/pull/8064 makes this ok Cherry-picked from #13521 * Remove unused ignores due to mypy ParamSpec fixes https://github.com/python/mypy/pull/12668 Cherry-picked from #13521 * Remove additional unused ignores * Fix new mypy complaints related to `assertGreater` Presumably due to https://github.com/python/typeshed/pull/8077 * Changelog * Reword changelog Co-authored-by: Patrick Cloke Co-authored-by: Patrick Cloke --- changelog.d/13925.misc | 1 + poetry.lock | 59 +++++++++++++++--------------- scripts-dev/check_pydantic_models.py | 5 +-- synapse/app/_base.py | 4 +- synapse/logging/context.py | 20 +++++----- synapse/logging/opentracing.py | 4 +- synapse/storage/database.py | 22 +++-------- synapse/storage/databases/main/search.py | 2 +- tests/storage/test_monthly_active_users.py | 6 +++ tests/utils.py | 4 +- 10 files changed, 60 insertions(+), 67 deletions(-) create mode 100644 changelog.d/13925.misc (limited to 'synapse') diff --git a/changelog.d/13925.misc b/changelog.d/13925.misc new file mode 100644 index 0000000000..f490ab122e --- /dev/null +++ b/changelog.d/13925.misc @@ -0,0 +1 @@ +Update mypy (0.950 -> 0.981) and mypy-zope (0.3.7 -> 0.3.11). diff --git a/poetry.lock b/poetry.lock index 0f6d1cfa69..63ef8573a0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -573,11 +573,11 @@ python-versions = "*" [[package]] name = "mypy" -version = "0.950" +version = "0.981" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] mypy-extensions = ">=0.4.3" @@ -600,14 +600,14 @@ python-versions = "*" [[package]] name = "mypy-zope" -version = "0.3.7" +version = "0.3.11" description = "Plugin for mypy to support zope interfaces" category = "dev" optional = false python-versions = "*" [package.dependencies] -mypy = "0.950" +mypy = "0.981" "zope.interface" = "*" "zope.schema" = "*" @@ -2162,37 +2162,38 @@ msgpack = [ {file = "msgpack-1.0.3.tar.gz", hash = "sha256:51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e"}, ] mypy = [ - {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, - {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, - {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, - {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, - {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, - {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, - {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, - {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, - {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, - {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, - {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, - {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, - {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, - {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = "sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, - {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, - {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, - {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, - {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, - {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, + {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, + {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, + {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"}, + {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"}, + {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"}, + {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"}, + {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"}, + {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"}, + {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"}, + {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"}, + {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"}, + {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"}, + {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"}, + {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"}, + {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"}, + {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"}, + {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"}, + {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"}, + {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"}, + {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"}, + {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"}, + {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"}, + {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"}, + {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"}, ] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] mypy-zope = [ - {file = "mypy-zope-0.3.7.tar.gz", hash = "sha256:9da171e78e8ef7ac8922c86af1a62f1b7f3244f121020bd94a2246bc3f33c605"}, - {file = "mypy_zope-0.3.7-py3-none-any.whl", hash = "sha256:9c7637d066e4d1bafa0651abc091c752009769098043b236446e6725be2bc9c2"}, + {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"}, + {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"}, ] netaddr = [ {file = "netaddr-0.8.0-py2.py3-none-any.whl", hash = "sha256:9666d0232c32d2656e5e5f8d735f58fd6c7457ce52fc21c98d45f2af78f990ac"}, diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index d0fb811bdb..9f2b7ded5b 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -88,10 +88,9 @@ def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 - if "strict" not in kwargs: # type: ignore[attr-defined] + if "strict" not in kwargs: raise MissingStrictInConstrainedTypeException(factory.__name__) - if not kwargs["strict"]: # type: ignore[index] + if not kwargs["strict"]: raise MissingStrictInConstrainedTypeException(factory.__name__) return factory(*args, **kwargs) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9a24bed0a0..000912e86e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -98,9 +98,7 @@ def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - # This type-ignore should be redundant once we use a mypy release with - # https://github.com/python/mypy/pull/12668. - _sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type] + _sighup_callbacks.append((func, args, kwargs)) def start_worker_reactor( diff --git a/synapse/logging/context.py b/synapse/logging/context.py index fd9cb97920..6a08ffed64 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -586,7 +586,7 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - record.request = self._default_request # type: ignore + record.request = self._default_request # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for @@ -594,21 +594,21 @@ class LoggingContextFilter(logging.Filter): if context is not None: # Logging is interested in the request ID. Note that for backwards # compatibility this is stored as the "request" on the record. - record.request = str(context) # type: ignore + record.request = str(context) # Add some data from the HTTP request. request = context.request if request is None: return True - record.ip_address = request.ip_address # type: ignore - record.site_tag = request.site_tag # type: ignore - record.requester = request.requester # type: ignore - record.authenticated_entity = request.authenticated_entity # type: ignore - record.method = request.method # type: ignore - record.url = request.url # type: ignore - record.protocol = request.protocol # type: ignore - record.user_agent = request.user_agent # type: ignore + record.ip_address = request.ip_address + record.site_tag = request.site_tag + record.requester = request.requester + record.authenticated_entity = request.authenticated_entity + record.method = request.method + record.url = request.url + record.protocol = request.protocol + record.user_agent = request.user_agent return True diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index ca2735dd6d..8ce5a2a338 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -992,9 +992,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: # FIXME: We could update this to handle any type of function by ignoring the # first argument only if it's named `self` or `cls`. This isn't fool-proof # but handles the idiomatic cases. - for i, arg in enumerate(args[1:], start=1): # type: ignore[index] + for i, arg in enumerate(args[1:], start=1): set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg)) - set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index] + set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bb28ded1b5..a252f8eaa0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -290,8 +290,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.after_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.after_callbacks.append((callback, args, kwargs)) def async_call_after( self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs @@ -312,8 +311,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.async_after_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.async_after_callbacks.append((callback, args, kwargs)) def call_on_exception( self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs @@ -331,8 +329,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.exception_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.exception_callbacks.append((callback, args, kwargs)) def fetchone(self) -> Optional[Tuple]: return self.txn.fetchone() @@ -421,10 +418,7 @@ class LoggingTransaction: sql = self.database_engine.convert_param_style(sql) if args: try: - # The type-ignore should be redundant once mypy releases a version with - # https://github.com/python/mypy/pull/12668. (`args` might be empty, - # (but we'll catch the index error if so.) - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index] + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) except Exception: # Don't let logging failures stop SQL from working pass @@ -655,9 +649,7 @@ class DatabasePool: # For now, we just log an error, and hope that it works on the first attempt. # TODO: raise an exception. - # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see - # https://github.com/python/mypy/pull/12668 - for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated] + for i, arg in enumerate(args): if inspect.isgenerator(arg): logger.error( "Programming error: generator passed to new_transaction as " @@ -665,9 +657,7 @@ class DatabasePool: i, func, ) - # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see - # https://github.com/python/mypy/pull/12668 - for name, val in kwargs.items(): # type: ignore[attr-defined] + for name, val in kwargs.items(): if inspect.isgenerator(val): logger.error( "Programming error: generator passed to new_transaction as " diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f6e24b68d2..1b79acf955 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -641,7 +641,7 @@ class SearchStore(SearchBackgroundUpdateStore): raise Exception("Unrecognized database engine") # mypy expects to append only a `str`, not an `int` - args.append(limit) # type: ignore[arg-type] + args.append(limit) results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index e8b4a5644b..3da8221109 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -96,8 +96,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Test each of the registered users is marked as active timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) + # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater. + # Check that timestamp really is an int. + assert timestamp is not None self.assertGreater(timestamp, 0) timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) + assert timestamp is not None self.assertGreater(timestamp, 0) # Test that users with reserved 3pids are not removed from the MAU table @@ -166,9 +170,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(user_id2)) result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + assert result is not None self.assertGreater(result, 0) result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) + assert result is not None self.assertNotEqual(result, 0) @override_config({"max_mau_value": 5}) diff --git a/tests/utils.py b/tests/utils.py index 65db437697..045a8b5fa7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -270,9 +270,7 @@ class MockClock: *args: P.args, **kwargs: P.kwargs, ) -> None: - # This type-ignore should be redundant once we use a mypy release with - # https://github.com/python/mypy/pull/12668. - self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type] + self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: if timer.expired: -- cgit 1.5.1 From 8e52cb0bce4c4e42a0f151f16e51529b7aba8f7d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 16:37:48 +0100 Subject: Revert "Update mypy and mypy-zope (#13925)" This reverts commit 6d543d6d9f56e39199b7e460d0081b02d61f12be. --- changelog.d/13925.misc | 1 - poetry.lock | 59 +++++++++++++++--------------- scripts-dev/check_pydantic_models.py | 5 ++- synapse/app/_base.py | 4 +- synapse/logging/context.py | 20 +++++----- synapse/logging/opentracing.py | 4 +- synapse/storage/database.py | 22 ++++++++--- synapse/storage/databases/main/search.py | 2 +- tests/storage/test_monthly_active_users.py | 6 --- tests/utils.py | 4 +- 10 files changed, 67 insertions(+), 60 deletions(-) delete mode 100644 changelog.d/13925.misc (limited to 'synapse') diff --git a/changelog.d/13925.misc b/changelog.d/13925.misc deleted file mode 100644 index f490ab122e..0000000000 --- a/changelog.d/13925.misc +++ /dev/null @@ -1 +0,0 @@ -Update mypy (0.950 -> 0.981) and mypy-zope (0.3.7 -> 0.3.11). diff --git a/poetry.lock b/poetry.lock index 63ef8573a0..0f6d1cfa69 100644 --- a/poetry.lock +++ b/poetry.lock @@ -573,11 +573,11 @@ python-versions = "*" [[package]] name = "mypy" -version = "0.981" +version = "0.950" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.7" +python-versions = ">=3.6" [package.dependencies] mypy-extensions = ">=0.4.3" @@ -600,14 +600,14 @@ python-versions = "*" [[package]] name = "mypy-zope" -version = "0.3.11" +version = "0.3.7" description = "Plugin for mypy to support zope interfaces" category = "dev" optional = false python-versions = "*" [package.dependencies] -mypy = "0.981" +mypy = "0.950" "zope.interface" = "*" "zope.schema" = "*" @@ -2162,38 +2162,37 @@ msgpack = [ {file = "msgpack-1.0.3.tar.gz", hash = "sha256:51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e"}, ] mypy = [ - {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, - {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, - {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"}, - {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"}, - {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"}, - {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"}, - {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"}, - {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"}, - {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"}, - {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"}, - {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"}, - {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"}, - {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"}, - {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"}, - {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"}, - {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"}, - {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"}, - {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"}, - {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"}, - {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"}, - {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"}, - {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"}, - {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"}, - {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"}, + {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, + {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, + {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, + {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, + {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, + {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, + {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, + {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, + {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, + {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, + {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, + {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, + {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, + {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, + {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, + {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = "sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, + {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, + {file = "mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, + {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, + {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, + {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, + {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, + {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, ] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] mypy-zope = [ - {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"}, - {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"}, + {file = "mypy-zope-0.3.7.tar.gz", hash = "sha256:9da171e78e8ef7ac8922c86af1a62f1b7f3244f121020bd94a2246bc3f33c605"}, + {file = "mypy_zope-0.3.7-py3-none-any.whl", hash = "sha256:9c7637d066e4d1bafa0651abc091c752009769098043b236446e6725be2bc9c2"}, ] netaddr = [ {file = "netaddr-0.8.0-py2.py3-none-any.whl", hash = "sha256:9666d0232c32d2656e5e5f8d735f58fd6c7457ce52fc21c98d45f2af78f990ac"}, diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 9f2b7ded5b..d0fb811bdb 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -88,9 +88,10 @@ def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - if "strict" not in kwargs: + # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 + if "strict" not in kwargs: # type: ignore[attr-defined] raise MissingStrictInConstrainedTypeException(factory.__name__) - if not kwargs["strict"]: + if not kwargs["strict"]: # type: ignore[index] raise MissingStrictInConstrainedTypeException(factory.__name__) return factory(*args, **kwargs) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 000912e86e..9a24bed0a0 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -98,7 +98,9 @@ def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - _sighup_callbacks.append((func, args, kwargs)) + # This type-ignore should be redundant once we use a mypy release with + # https://github.com/python/mypy/pull/12668. + _sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type] def start_worker_reactor( diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 6a08ffed64..fd9cb97920 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -586,7 +586,7 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - record.request = self._default_request + record.request = self._default_request # type: ignore # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for @@ -594,21 +594,21 @@ class LoggingContextFilter(logging.Filter): if context is not None: # Logging is interested in the request ID. Note that for backwards # compatibility this is stored as the "request" on the record. - record.request = str(context) + record.request = str(context) # type: ignore # Add some data from the HTTP request. request = context.request if request is None: return True - record.ip_address = request.ip_address - record.site_tag = request.site_tag - record.requester = request.requester - record.authenticated_entity = request.authenticated_entity - record.method = request.method - record.url = request.url - record.protocol = request.protocol - record.user_agent = request.user_agent + record.ip_address = request.ip_address # type: ignore + record.site_tag = request.site_tag # type: ignore + record.requester = request.requester # type: ignore + record.authenticated_entity = request.authenticated_entity # type: ignore + record.method = request.method # type: ignore + record.url = request.url # type: ignore + record.protocol = request.protocol # type: ignore + record.user_agent = request.user_agent # type: ignore return True diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8ce5a2a338..ca2735dd6d 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -992,9 +992,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: # FIXME: We could update this to handle any type of function by ignoring the # first argument only if it's named `self` or `cls`. This isn't fool-proof # but handles the idiomatic cases. - for i, arg in enumerate(args[1:], start=1): + for i, arg in enumerate(args[1:], start=1): # type: ignore[index] set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg)) - set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) + set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index] set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a252f8eaa0..bb28ded1b5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -290,7 +290,8 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.after_callbacks is not None - self.after_callbacks.append((callback, args, kwargs)) + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] def async_call_after( self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs @@ -311,7 +312,8 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.async_after_callbacks is not None - self.async_after_callbacks.append((callback, args, kwargs)) + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] def call_on_exception( self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs @@ -329,7 +331,8 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.exception_callbacks is not None - self.exception_callbacks.append((callback, args, kwargs)) + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] def fetchone(self) -> Optional[Tuple]: return self.txn.fetchone() @@ -418,7 +421,10 @@ class LoggingTransaction: sql = self.database_engine.convert_param_style(sql) if args: try: - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) + # The type-ignore should be redundant once mypy releases a version with + # https://github.com/python/mypy/pull/12668. (`args` might be empty, + # (but we'll catch the index error if so.) + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index] except Exception: # Don't let logging failures stop SQL from working pass @@ -649,7 +655,9 @@ class DatabasePool: # For now, we just log an error, and hope that it works on the first attempt. # TODO: raise an exception. - for i, arg in enumerate(args): + # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see + # https://github.com/python/mypy/pull/12668 + for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated] if inspect.isgenerator(arg): logger.error( "Programming error: generator passed to new_transaction as " @@ -657,7 +665,9 @@ class DatabasePool: i, func, ) - for name, val in kwargs.items(): + # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see + # https://github.com/python/mypy/pull/12668 + for name, val in kwargs.items(): # type: ignore[attr-defined] if inspect.isgenerator(val): logger.error( "Programming error: generator passed to new_transaction as " diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 1b79acf955..f6e24b68d2 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -641,7 +641,7 @@ class SearchStore(SearchBackgroundUpdateStore): raise Exception("Unrecognized database engine") # mypy expects to append only a `str`, not an `int` - args.append(limit) + args.append(limit) # type: ignore[arg-type] results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 3da8221109..e8b4a5644b 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -96,12 +96,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Test each of the registered users is marked as active timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) - # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater. - # Check that timestamp really is an int. - assert timestamp is not None self.assertGreater(timestamp, 0) timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) - assert timestamp is not None self.assertGreater(timestamp, 0) # Test that users with reserved 3pids are not removed from the MAU table @@ -170,11 +166,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(user_id2)) result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) - assert result is not None self.assertGreater(result, 0) result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) - assert result is not None self.assertNotEqual(result, 0) @override_config({"max_mau_value": 5}) diff --git a/tests/utils.py b/tests/utils.py index 045a8b5fa7..65db437697 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -270,7 +270,9 @@ class MockClock: *args: P.args, **kwargs: P.kwargs, ) -> None: - self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) + # This type-ignore should be redundant once we use a mypy release with + # https://github.com/python/mypy/pull/12668. + self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type] def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: if timer.expired: -- cgit 1.5.1 From 285d72556bb3c36f075b336b2bdd6acb08391ad5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 17:36:28 +0100 Subject: Update mypy and mypy-zope, attempt 3 (#13993) Co-authored-by: Patrick Cloke --- changelog.d/13925.misc | 1 + changelog.d/13993.misc | 1 + poetry.lock | 59 +++++++++++++++--------------- scripts-dev/check_pydantic_models.py | 5 +-- synapse/app/_base.py | 4 +- synapse/logging/context.py | 20 +++++----- synapse/logging/opentracing.py | 4 +- synapse/storage/database.py | 22 +++-------- synapse/storage/databases/main/search.py | 2 +- tests/storage/test_monthly_active_users.py | 7 +++- tests/utils.py | 4 +- 11 files changed, 61 insertions(+), 68 deletions(-) create mode 100644 changelog.d/13925.misc create mode 100644 changelog.d/13993.misc (limited to 'synapse') diff --git a/changelog.d/13925.misc b/changelog.d/13925.misc new file mode 100644 index 0000000000..f490ab122e --- /dev/null +++ b/changelog.d/13925.misc @@ -0,0 +1 @@ +Update mypy (0.950 -> 0.981) and mypy-zope (0.3.7 -> 0.3.11). diff --git a/changelog.d/13993.misc b/changelog.d/13993.misc new file mode 100644 index 0000000000..f490ab122e --- /dev/null +++ b/changelog.d/13993.misc @@ -0,0 +1 @@ +Update mypy (0.950 -> 0.981) and mypy-zope (0.3.7 -> 0.3.11). diff --git a/poetry.lock b/poetry.lock index 0f6d1cfa69..63ef8573a0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -573,11 +573,11 @@ python-versions = "*" [[package]] name = "mypy" -version = "0.950" +version = "0.981" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] mypy-extensions = ">=0.4.3" @@ -600,14 +600,14 @@ python-versions = "*" [[package]] name = "mypy-zope" -version = "0.3.7" +version = "0.3.11" description = "Plugin for mypy to support zope interfaces" category = "dev" optional = false python-versions = "*" [package.dependencies] -mypy = "0.950" +mypy = "0.981" "zope.interface" = "*" "zope.schema" = "*" @@ -2162,37 +2162,38 @@ msgpack = [ {file = "msgpack-1.0.3.tar.gz", hash = "sha256:51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e"}, ] mypy = [ - {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, - {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, - {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, - {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, - {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, - {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, - {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, - {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, - {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, - {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, - {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, - {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, - {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, - {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = "sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, - {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, - {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, - {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, - {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, - {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, + {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, + {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, + {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"}, + {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"}, + {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"}, + {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"}, + {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"}, + {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"}, + {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"}, + {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"}, + {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"}, + {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"}, + {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"}, + {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"}, + {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"}, + {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"}, + {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"}, + {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"}, + {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"}, + {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"}, + {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"}, + {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"}, + {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"}, + {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"}, ] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] mypy-zope = [ - {file = "mypy-zope-0.3.7.tar.gz", hash = "sha256:9da171e78e8ef7ac8922c86af1a62f1b7f3244f121020bd94a2246bc3f33c605"}, - {file = "mypy_zope-0.3.7-py3-none-any.whl", hash = "sha256:9c7637d066e4d1bafa0651abc091c752009769098043b236446e6725be2bc9c2"}, + {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"}, + {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"}, ] netaddr = [ {file = "netaddr-0.8.0-py2.py3-none-any.whl", hash = "sha256:9666d0232c32d2656e5e5f8d735f58fd6c7457ce52fc21c98d45f2af78f990ac"}, diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index d0fb811bdb..9f2b7ded5b 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -88,10 +88,9 @@ def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 - if "strict" not in kwargs: # type: ignore[attr-defined] + if "strict" not in kwargs: raise MissingStrictInConstrainedTypeException(factory.__name__) - if not kwargs["strict"]: # type: ignore[index] + if not kwargs["strict"]: raise MissingStrictInConstrainedTypeException(factory.__name__) return factory(*args, **kwargs) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9a24bed0a0..000912e86e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -98,9 +98,7 @@ def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - # This type-ignore should be redundant once we use a mypy release with - # https://github.com/python/mypy/pull/12668. - _sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type] + _sighup_callbacks.append((func, args, kwargs)) def start_worker_reactor( diff --git a/synapse/logging/context.py b/synapse/logging/context.py index fd9cb97920..6a08ffed64 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -586,7 +586,7 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - record.request = self._default_request # type: ignore + record.request = self._default_request # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for @@ -594,21 +594,21 @@ class LoggingContextFilter(logging.Filter): if context is not None: # Logging is interested in the request ID. Note that for backwards # compatibility this is stored as the "request" on the record. - record.request = str(context) # type: ignore + record.request = str(context) # Add some data from the HTTP request. request = context.request if request is None: return True - record.ip_address = request.ip_address # type: ignore - record.site_tag = request.site_tag # type: ignore - record.requester = request.requester # type: ignore - record.authenticated_entity = request.authenticated_entity # type: ignore - record.method = request.method # type: ignore - record.url = request.url # type: ignore - record.protocol = request.protocol # type: ignore - record.user_agent = request.user_agent # type: ignore + record.ip_address = request.ip_address + record.site_tag = request.site_tag + record.requester = request.requester + record.authenticated_entity = request.authenticated_entity + record.method = request.method + record.url = request.url + record.protocol = request.protocol + record.user_agent = request.user_agent return True diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index ca2735dd6d..8ce5a2a338 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -992,9 +992,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: # FIXME: We could update this to handle any type of function by ignoring the # first argument only if it's named `self` or `cls`. This isn't fool-proof # but handles the idiomatic cases. - for i, arg in enumerate(args[1:], start=1): # type: ignore[index] + for i, arg in enumerate(args[1:], start=1): set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg)) - set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index] + set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bb28ded1b5..a252f8eaa0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -290,8 +290,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.after_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.after_callbacks.append((callback, args, kwargs)) def async_call_after( self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs @@ -312,8 +311,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.async_after_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.async_after_callbacks.append((callback, args, kwargs)) def call_on_exception( self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs @@ -331,8 +329,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.exception_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.exception_callbacks.append((callback, args, kwargs)) def fetchone(self) -> Optional[Tuple]: return self.txn.fetchone() @@ -421,10 +418,7 @@ class LoggingTransaction: sql = self.database_engine.convert_param_style(sql) if args: try: - # The type-ignore should be redundant once mypy releases a version with - # https://github.com/python/mypy/pull/12668. (`args` might be empty, - # (but we'll catch the index error if so.) - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index] + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) except Exception: # Don't let logging failures stop SQL from working pass @@ -655,9 +649,7 @@ class DatabasePool: # For now, we just log an error, and hope that it works on the first attempt. # TODO: raise an exception. - # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see - # https://github.com/python/mypy/pull/12668 - for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated] + for i, arg in enumerate(args): if inspect.isgenerator(arg): logger.error( "Programming error: generator passed to new_transaction as " @@ -665,9 +657,7 @@ class DatabasePool: i, func, ) - # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see - # https://github.com/python/mypy/pull/12668 - for name, val in kwargs.items(): # type: ignore[attr-defined] + for name, val in kwargs.items(): if inspect.isgenerator(val): logger.error( "Programming error: generator passed to new_transaction as " diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f6e24b68d2..1b79acf955 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -641,7 +641,7 @@ class SearchStore(SearchBackgroundUpdateStore): raise Exception("Unrecognized database engine") # mypy expects to append only a `str`, not an `int` - args.append(limit) # type: ignore[arg-type] + args.append(limit) results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index e8b4a5644b..c55c4db970 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -96,8 +96,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Test each of the registered users is marked as active timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) + # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater. + # Check that timestamp really is an int. + assert timestamp is not None self.assertGreater(timestamp, 0) timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) + assert timestamp is not None self.assertGreater(timestamp, 0) # Test that users with reserved 3pids are not removed from the MAU table @@ -166,10 +170,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(user_id2)) result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + assert result is not None self.assertGreater(result, 0) result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) - self.assertNotEqual(result, 0) + self.assertIsNone(result) @override_config({"max_mau_value": 5}) def test_reap_monthly_active_users(self): diff --git a/tests/utils.py b/tests/utils.py index 65db437697..045a8b5fa7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -270,9 +270,7 @@ class MockClock: *args: P.args, **kwargs: P.kwargs, ) -> None: - # This type-ignore should be redundant once we use a mypy release with - # https://github.com/python/mypy/pull/12668. - self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type] + self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: if timer.expired: -- cgit 1.5.1 From 535f8c8f7d64d4058500a5988278fd3026645164 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Sep 2022 17:40:33 +0100 Subject: Skip filtering during push if there are no push actions (#13992) --- changelog.d/13992.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 5 +++++ synapse/visibility.py | 4 ++++ tests/rest/client/test_rooms.py | 4 ++-- 4 files changed, 12 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13992.misc (limited to 'synapse') diff --git a/changelog.d/13992.misc b/changelog.d/13992.misc new file mode 100644 index 0000000000..58150a2b35 --- /dev/null +++ b/changelog.d/13992.misc @@ -0,0 +1 @@ +Speed up calculating push actions in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7bfe380543..4270438918 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -332,6 +332,11 @@ class BulkPushRuleEvaluator: # Push rules say we should notify the user of this event actions_by_user[uid] = actions + # If there aren't any actions then we can skip the rest of the + # processing. + if not actions_by_user: + return + # This is a check for the case where user joins a room without being # allowed to see history, and then the server receives a delayed event # from before the user joined, which they should not be pushed for diff --git a/synapse/visibility.py b/synapse/visibility.py index c810a05907..c4048d2477 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -162,6 +162,10 @@ async def filter_event_for_clients_with_state( if event.internal_metadata.is_soft_failed(): return [] + # Fast path if we don't have any user IDs to check. + if not user_ids: + return () + # Make a set for all user IDs that haven't been filtered out by a check. allowed_user_ids = set(user_ids) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e281aef779..7f8cf4fab0 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -710,7 +710,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(35, channel.resource_usage.db_txn_count) + self.assertEqual(34, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -723,7 +723,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(38, channel.resource_usage.db_txn_count) + self.assertEqual(37, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From ad4c14e4b0c44d6a8ee42e760d7e1fe1755559a2 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 30 Sep 2022 14:40:18 -0500 Subject: Clarifications in user directory for users who share rooms tracking (#13966) Spawned while working on [`get_users_in_room` mis-uses](https://github.com/matrix-org/synapse/pull/13958#discussion_r984074897) and thinking we could use `get_local_users_in_room` here but we can't. From first glance, it seemed like this was only using local users from all of the `is_mine_id(user_id)` checks but I see that it does actually use remote users. Just making things a little more clear here what it does and mentions remote users so maybe that will be more obvious in the future. --- changelog.d/13966.misc | 1 + synapse/handlers/user_directory.py | 36 ++++++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 12 deletions(-) create mode 100644 changelog.d/13966.misc (limited to 'synapse') diff --git a/changelog.d/13966.misc b/changelog.d/13966.misc new file mode 100644 index 0000000000..b54ad5c776 --- /dev/null +++ b/changelog.d/13966.misc @@ -0,0 +1 @@ +Refactor language in user directory `_track_user_joined_room` code to make it more clear that we use both local and remote users. diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 8c3c52e1ca..3610b6bf78 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple import synapse.metrics from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership @@ -379,7 +379,7 @@ class UserDirectoryHandler(StateDeltasHandler): user_id, event.content.get("displayname"), event.content.get("avatar_url") ) - async def _track_user_joined_room(self, room_id: str, user_id: str) -> None: + async def _track_user_joined_room(self, room_id: str, joining_user_id: str) -> None: """Someone's just joined a room. Update `users_in_public_rooms` or `users_who_share_private_rooms` as appropriate. @@ -390,32 +390,44 @@ class UserDirectoryHandler(StateDeltasHandler): room_id ) if is_public: - await self.store.add_users_in_public_rooms(room_id, (user_id,)) + await self.store.add_users_in_public_rooms(room_id, (joining_user_id,)) else: users_in_room = await self.store.get_users_in_room(room_id) other_users_in_room = [ other for other in users_in_room - if other != user_id + if other != joining_user_id and ( + # We can't apply any special rules to remote users so + # they're always included not self.is_mine_id(other) + # Check the special rules whether the local user should be + # included in the user directory or await self.store.should_include_local_user_in_dir(other) ) ] - to_insert = set() + updates_to_users_who_share_rooms: Set[Tuple[str, str]] = set() - # First, if they're our user then we need to update for every user - if self.is_mine_id(user_id): + # First, if the joining user is our local user then we need an + # update for every other user in the room. + if self.is_mine_id(joining_user_id): for other_user_id in other_users_in_room: - to_insert.add((user_id, other_user_id)) + updates_to_users_who_share_rooms.add( + (joining_user_id, other_user_id) + ) - # Next we need to update for every local user in the room + # Next, we need an update for every other local user in the room + # that they now share a room with the joining user. for other_user_id in other_users_in_room: if self.is_mine_id(other_user_id): - to_insert.add((other_user_id, user_id)) + updates_to_users_who_share_rooms.add( + (other_user_id, joining_user_id) + ) - if to_insert: - await self.store.add_users_who_share_private_room(room_id, to_insert) + if updates_to_users_who_share_rooms: + await self.store.add_users_who_share_private_room( + room_id, updates_to_users_who_share_rooms + ) async def _handle_remove_user(self, room_id: str, user_id: str) -> None: """Called when when someone leaves a room. The user may be local or remote. -- cgit 1.5.1 From a52c40e2a6d3a142c9cf768479ec963354c3e360 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 30 Sep 2022 20:10:50 -0500 Subject: Fix `get_users_in_room` mis-use in `transfer_room_state_on_room_upgrade` (#13960) Spawning from looking into `get_users_in_room` while investigating https://github.com/matrix-org/synapse/issues/13942#issuecomment-1262787050. See https://github.com/matrix-org/synapse/pull/13575#discussion_r953023755 for the original exploration around finding `get_users_in_room` mis-uses. Related to the following PRs where we also cleaned up some `get_users_in_room` mis-uses: - https://github.com/matrix-org/synapse/pull/13605 - https://github.com/matrix-org/synapse/pull/13608 - https://github.com/matrix-org/synapse/pull/13606 - https://github.com/matrix-org/synapse/pull/13958 --- changelog.d/13960.misc | 1 + synapse/handlers/room_member.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13960.misc (limited to 'synapse') diff --git a/changelog.d/13960.misc b/changelog.d/13960.misc new file mode 100644 index 0000000000..a7ba532bcb --- /dev/null +++ b/changelog.d/13960.misc @@ -0,0 +1 @@ +Use dedicated `get_local_users_in_room(room_id)` function to find local users when calculating users to copy over during a room upgrade. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 88158822e0..ee669eb30f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1150,8 +1150,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): logger.info("Transferring room state from %s to %s", old_room_id, room_id) # Find all local users that were in the old room and copy over each user's state - users = await self.store.get_users_in_room(old_room_id) - await self.copy_user_state_on_room_upgrade(old_room_id, room_id, users) + local_users = await self.store.get_local_users_in_room(old_room_id) + await self.copy_user_state_on_room_upgrade(old_room_id, room_id, local_users) # Add new room to the room directory if the old room was there # Remove old room from the room directory -- cgit 1.5.1 From 2769ef4df125f91b59693457052930379582d614 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 3 Oct 2022 04:14:45 -0500 Subject: Revert the general exception recording introduced in #13814 (#13969) * Maybe not catch all errors to avoid things in the nature-of CancelledError See https://github.com/matrix-org/synapse/pull/13815#discussion_r983384698 * Remove general exception tracking * Add changelog --- changelog.d/13969.misc | 1 + synapse/handlers/federation_event.py | 10 ---------- 2 files changed, 1 insertion(+), 10 deletions(-) create mode 100644 changelog.d/13969.misc (limited to 'synapse') diff --git a/changelog.d/13969.misc b/changelog.d/13969.misc new file mode 100644 index 0000000000..5ede0069c8 --- /dev/null +++ b/changelog.d/13969.misc @@ -0,0 +1 @@ +Revert catch-all exceptions being recorded as event pull attempt failures (only handle what we know about). diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3fac256881..778d8869b3 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -866,11 +866,6 @@ class FederationEventHandler: event.room_id, event_id, str(err) ) return - except Exception as exc: - await self._store.record_event_failed_pull_attempt( - event.room_id, event_id, str(exc) - ) - raise exc try: try: @@ -913,11 +908,6 @@ class FederationEventHandler: logger.warning("Pulled event %s failed history check.", event_id) else: raise - except Exception as exc: - await self._store.record_event_failed_pull_attempt( - event.room_id, event_id, str(exc) - ) - raise exc @trace async def _compute_event_context_with_maybe_missing_prevs( -- cgit 1.5.1 From d65862c41f2992a253778753d7f378d3ef1fb996 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 3 Oct 2022 13:46:36 +0100 Subject: Refactor `_get_e2e_device_keys_txn` to split large queries (#13956) Instead of running a single large query, run a single query for user-only lookups and additional queries for batches of user device lookups. Resolves #13580. Signed-off-by: Sean Quah --- changelog.d/13956.bugfix | 1 + synapse/storage/database.py | 60 ++++++++++++++++ synapse/storage/databases/main/end_to_end_keys.py | 83 +++++++++++++++-------- 3 files changed, 115 insertions(+), 29 deletions(-) create mode 100644 changelog.d/13956.bugfix (limited to 'synapse') diff --git a/changelog.d/13956.bugfix b/changelog.d/13956.bugfix new file mode 100644 index 0000000000..5682c3e002 --- /dev/null +++ b/changelog.d/13956.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where `POST /_matrix/client/v3/keys/query` requests could result in excessively large SQL queries. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a252f8eaa0..b4469eb964 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2461,6 +2461,66 @@ def make_in_list_sql_clause( return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) +# These overloads ensure that `columns` and `iterable` values have the same length. +# Suppress "Single overload definition, multiple required" complaint. +@overload # type: ignore[misc] +def make_tuple_in_list_sql_clause( + database_engine: BaseDatabaseEngine, + columns: Tuple[str, str], + iterable: Collection[Tuple[Any, Any]], +) -> Tuple[str, list]: + ... + + +def make_tuple_in_list_sql_clause( + database_engine: BaseDatabaseEngine, + columns: Tuple[str, ...], + iterable: Collection[Tuple[Any, ...]], +) -> Tuple[str, list]: + """Returns an SQL clause that checks the given tuple of columns is in the iterable. + + Args: + database_engine + columns: Names of the columns in the tuple. + iterable: The tuples to check the columns against. + + Returns: + A tuple of SQL query and the args + """ + if len(columns) == 0: + # Should be unreachable due to mypy, as long as the overloads are set up right. + if () in iterable: + return "TRUE", [] + else: + return "FALSE", [] + + if len(columns) == 1: + # Use `= ANY(?)` on postgres. + return make_in_list_sql_clause( + database_engine, next(iter(columns)), [values[0] for values in iterable] + ) + + # There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as + # indices are not used when there are multiple columns. Instead, use an `IN` + # expression. + # + # `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas + # `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres. + # Thus, the latter is chosen. + + if len(iterable) == 0: + # A 0-length `VALUES` list is not allowed in sqlite or postgres. + # Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not + # allowed in postgres. + return "FALSE", [] + + tuple_sql = "(%s)" % (",".join("?" for _ in columns),) + return "(%s) IN (VALUES %s)" % ( + ",".join(column for column in columns), + ",".join(tuple_sql for _ in iterable), + ), [value for values in iterable for value in values] + + KV = TypeVar("KV") diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 8e9e1b0b4b..8a10ae800c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -43,6 +43,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, make_in_list_sql_clause, + make_tuple_in_list_sql_clause, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine @@ -278,7 +279,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker def _get_e2e_device_keys_txn( self, txn: LoggingTransaction, - query_list: Collection[Tuple[str, str]], + query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: bool = False, include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: @@ -288,8 +289,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker cross-signing signatures which have been added subsequently (for which, see get_e2e_device_keys_and_signatures) """ - query_clauses = [] - query_params = [] + query_clauses: List[str] = [] + query_params_list: List[List[object]] = [] if include_all_devices is False: include_deleted_devices = False @@ -297,40 +298,64 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if include_deleted_devices: deleted_devices = set(query_list) + # Split the query list into queries for users and queries for particular + # devices. + user_list = [] + user_device_list = [] for (user_id, device_id) in query_list: - query_clause = "user_id = ?" - query_params.append(user_id) - - if device_id is not None: - query_clause += " AND device_id = ?" - query_params.append(device_id) - - query_clauses.append(query_clause) - - sql = ( - "SELECT user_id, device_id, " - " d.display_name, " - " k.key_json" - " FROM devices d" - " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s AND NOT d.hidden" - ) % ( - "LEFT" if include_all_devices else "INNER", - " OR ".join("(" + q + ")" for q in query_clauses), - ) + if device_id is None: + user_list.append(user_id) + else: + user_device_list.append((user_id, device_id)) - txn.execute(sql, query_params) + if user_list: + user_id_in_list_clause, user_args = make_in_list_sql_clause( + txn.database_engine, "user_id", user_list + ) + query_clauses.append(user_id_in_list_clause) + query_params_list.append(user_args) + + if user_device_list: + # Divide the device queries into batches, to avoid excessively large + # queries. + for user_device_batch in batch_iter(user_device_list, 1024): + ( + user_device_id_in_list_clause, + user_device_args, + ) = make_tuple_in_list_sql_clause( + txn.database_engine, ("user_id", "device_id"), user_device_batch + ) + query_clauses.append(user_device_id_in_list_clause) + query_params_list.append(user_device_args) result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {} - for (user_id, device_id, display_name, key_json) in txn: - if include_deleted_devices: - deleted_devices.remove((user_id, device_id)) - result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( - display_name, db_to_json(key_json) if key_json else None + for query_clause, query_params in zip(query_clauses, query_params_list): + sql = ( + "SELECT user_id, device_id, " + " d.display_name, " + " k.key_json" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" + " WHERE %s AND NOT d.hidden" + ) % ( + "LEFT" if include_all_devices else "INNER", + query_clause, ) + txn.execute(sql, query_params) + + for (user_id, device_id, display_name, key_json) in txn: + assert device_id is not None + if include_deleted_devices: + deleted_devices.remove((user_id, device_id)) + result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( + display_name, db_to_json(key_json) if key_json else None + ) + if include_deleted_devices: for user_id, device_id in deleted_devices: + if device_id is None: + continue result.setdefault(user_id, {})[device_id] = None return result -- cgit 1.5.1 From 606b2d9009f0a3e70056dec7e9cdccd0c0d7afed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Oct 2022 14:13:11 +0100 Subject: Add cache to `get_partial_state_servers_at_join` (#14013) --- changelog.d/14013.misc | 1 + synapse/storage/databases/main/room.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changelog.d/14013.misc (limited to 'synapse') diff --git a/changelog.d/14013.misc b/changelog.d/14013.misc new file mode 100644 index 0000000000..499e488c35 --- /dev/null +++ b/changelog.d/14013.misc @@ -0,0 +1 @@ +Faster room joins: Send device list updates to most servers in rooms with partial state. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 059eef5c22..7412bce255 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1134,6 +1134,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_rooms_for_retention_period_in_range_txn, ) + @cached(iterable=True) async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]: """Gets the list of servers in a partial state room at the time we joined it. @@ -1216,6 +1217,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): keyvalues={"room_id": room_id}, ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) + self._invalidate_cache_and_stream( + txn, self.get_partial_state_servers_at_join, (room_id,) + ) # We now delete anything from `device_lists_remote_pending` with a # stream ID less than the minimum @@ -1862,6 +1866,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): values=((room_id, s) for s in servers), ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) + self._invalidate_cache_and_stream( + txn, self.get_partial_state_servers_at_join, (room_id,) + ) async def write_partial_state_rooms_join_event_id( self, -- cgit 1.5.1 From a423f452942c5b1597c29be50b235c8df4d6c93d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 3 Oct 2022 14:26:49 +0100 Subject: Fix twisted trunk mypy errors (#14012) --- changelog.d/14012.misc | 1 + synapse/handlers/cas.py | 3 +++ synapse/handlers/ui_auth/checkers.py | 3 +++ 3 files changed, 7 insertions(+) create mode 100644 changelog.d/14012.misc (limited to 'synapse') diff --git a/changelog.d/14012.misc b/changelog.d/14012.misc new file mode 100644 index 0000000000..9888dc6cc1 --- /dev/null +++ b/changelog.d/14012.misc @@ -0,0 +1 @@ +Fix type annotations to be compatible with new annotations in development versions of twisted. diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 7163af8004..fc467bc7c1 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -130,6 +130,9 @@ class CasHandler: except PartialDownloadError as pde: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data + # Assertion is for mypy's benefit. Error.response is Optional[bytes], + # but a PartialDownloadError should always have a non-None response. + assert pde.response is not None body = pde.response except HttpResponseException as e: description = ( diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index a744d68c64..332edcca24 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -119,6 +119,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): except PartialDownloadError as pde: # Twisted is silly data = pde.response + # For mypy's benefit. A general Error.response is Optional[bytes], but + # a PartialDownloadError.response should be bytes AFAICS. + assert data is not None resp_body = json_decoder.decode(data.decode("utf-8")) if "success" in resp_body: -- cgit 1.5.1 From 719488dda87b04e4650a32f0c2b0b71782e0d48b Mon Sep 17 00:00:00 2001 From: lukasdenk <63459921+lukasdenk@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:30:45 +0100 Subject: Add query parameter `ts` to allow appservices set the `origin_server_ts` for state events. (#11866) MSC3316 declares that both /rooms/{roomId}/send and /rooms/{roomId}/state should accept a ts parameter for appservices. This change expands support to /state and adds tests. --- changelog.d/11866.feature | 1 + synapse/handlers/room_member.py | 13 +++++ synapse/rest/client/room.py | 34 +++++++----- tests/rest/client/test_rooms.py | 119 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 152 insertions(+), 15 deletions(-) create mode 100644 changelog.d/11866.feature (limited to 'synapse') diff --git a/changelog.d/11866.feature b/changelog.d/11866.feature new file mode 100644 index 0000000000..0b52caf805 --- /dev/null +++ b/changelog.d/11866.feature @@ -0,0 +1 @@ +Allow application services to set the `origin_server_ts` of a state event by providing the query parameter `ts` in `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`, per [MSC3316](https://github.com/matrix-org/matrix-doc/pull/3316). Contributed by @lukasdenk. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ee669eb30f..6ad2b38b8f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -322,6 +322,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent: bool = True, outlier: bool = False, historical: bool = False, + origin_server_ts: Optional[int] = None, ) -> Tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -361,6 +362,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): historical: Indicates whether the message is being inserted back in time around some existing events. This is used to skip a few checks and mark the event as backfilled. + origin_server_ts: The origin_server_ts to use if a new event is created. Uses + the current timestamp if set to None. Returns: Tuple of event ID and stream ordering position @@ -399,6 +402,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): "state_key": user_id, # For backwards compatibility: "membership": membership, + "origin_server_ts": origin_server_ts, }, txn_id=txn_id, allow_no_prev_events=allow_no_prev_events, @@ -504,6 +508,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + origin_server_ts: Optional[int] = None, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -542,6 +547,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + origin_server_ts: The origin_server_ts to use if a new event is created. Uses + the current timestamp if set to None. Returns: A tuple of the new event ID and stream ID. @@ -583,6 +590,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_event_ids=prev_event_ids, state_event_ids=state_event_ids, depth=depth, + origin_server_ts=origin_server_ts, ) return result @@ -606,6 +614,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + origin_server_ts: Optional[int] = None, ) -> Tuple[str, int]: """Helper for update_membership. @@ -646,6 +655,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + origin_server_ts: The origin_server_ts to use if a new event is created. Uses + the current timestamp if set to None. Returns: A tuple of the new event ID and stream ID. @@ -785,6 +796,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, historical=historical, + origin_server_ts=origin_server_ts, ) latest_event_ids = await self.store.get_prev_events_for_room(room_id) @@ -1030,6 +1042,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content=content, require_consent=require_consent, outlier=outlier, + origin_server_ts=origin_server_ts, ) async def _should_perform_remote_join( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 0bca012535..b6dedbed04 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -268,15 +268,9 @@ class RoomStateEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - if state_key is not None: - event_dict["state_key"] = state_key + origin_server_ts = None + if requester.app_service: + origin_server_ts = parse_integer(request, "ts") try: if event_type == EventTypes.Member: @@ -287,8 +281,22 @@ class RoomStateEventRestServlet(TransactionRestServlet): room_id=room_id, action=membership, content=content, + origin_server_ts=origin_server_ts, ) else: + event_dict: JsonDict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + if state_key is not None: + event_dict["state_key"] = state_key + + if origin_server_ts is not None: + event_dict["origin_server_ts"] = origin_server_ts + ( event, _, @@ -333,10 +341,10 @@ class RoomSendEventRestServlet(TransactionRestServlet): "sender": requester.user.to_string(), } - # Twisted will have processed the args by now. - assert request.args is not None - if b"ts" in request.args and requester.app_service: - event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) + if requester.app_service: + origin_server_ts = parse_integer(request, "ts") + if origin_server_ts is not None: + event_dict["origin_server_ts"] = origin_server_ts try: ( diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 7f8cf4fab0..5e66b5b26c 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -20,7 +20,7 @@ import json from http import HTTPStatus from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from unittest.mock import Mock, call +from unittest.mock import Mock, call, patch from urllib import parse as urlparse from parameterized import param, parameterized @@ -39,9 +39,10 @@ from synapse.api.constants import ( RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException +from synapse.appservice import ApplicationService from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, room, sync +from synapse.rest.client import account, directory, login, profile, register, room, sync from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util import Clock @@ -1252,6 +1253,120 @@ class RoomJoinTestCase(RoomBase): ) +class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): + servlets = [ + room.register_servlets, + synapse.rest.admin.register_servlets, + register.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.appservice_user, _ = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) + + # Create a room as the appservice user. + args = { + "access_token": self.appservice.token, + "user_id": self.appservice_user, + } + channel = self.make_request( + "POST", + f"/_matrix/client/r0/createRoom?{urlparse.urlencode(args)}", + content={"visibility": "public"}, + ) + + assert channel.code == 200 + self.room = channel.json_body["room_id"] + + self.main_store = self.hs.get_datastores().main + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def test_send_event_ts(self) -> None: + """Test sending a non-state event with a custom timestamp.""" + ts = 1 + + url_params = { + "user_id": self.appservice_user, + "ts": ts, + } + channel = self.make_request( + "PUT", + path=f"/_matrix/client/r0/rooms/{self.room}/send/m.room.message/1234?" + + urlparse.urlencode(url_params), + content={"body": "test", "msgtype": "m.text"}, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + event_id = channel.json_body["event_id"] + + # Ensure the event was persisted with the correct timestamp. + res = self.get_success(self.main_store.get_event(event_id)) + self.assertEquals(ts, res.origin_server_ts) + + def test_send_state_event_ts(self) -> None: + """Test sending a state event with a custom timestamp.""" + ts = 1 + + url_params = { + "user_id": self.appservice_user, + "ts": ts, + } + channel = self.make_request( + "PUT", + path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.name?" + + urlparse.urlencode(url_params), + content={"name": "test"}, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + event_id = channel.json_body["event_id"] + + # Ensure the event was persisted with the correct timestamp. + res = self.get_success(self.main_store.get_event(event_id)) + self.assertEquals(ts, res.origin_server_ts) + + def test_send_membership_event_ts(self) -> None: + """Test sending a membership event with a custom timestamp.""" + ts = 1 + + url_params = { + "user_id": self.appservice_user, + "ts": ts, + } + channel = self.make_request( + "PUT", + path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.member/{self.appservice_user}?" + + urlparse.urlencode(url_params), + content={"membership": "join", "display_name": "test"}, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + event_id = channel.json_body["event_id"] + + # Ensure the event was persisted with the correct timestamp. + res = self.get_success(self.main_store.get_event(event_id)) + self.assertEquals(ts, res.origin_server_ts) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" -- cgit 1.5.1 From 2c237debd3476bcc45a76e360b0cb33032b23045 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Oct 2022 14:45:19 +0100 Subject: Fix bug where we didn't delete staging push actions (#14014) Introduced in #13719 --- changelog.d/14014.bugfix | 1 + synapse/storage/databases/main/events.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14014.bugfix (limited to 'synapse') diff --git a/changelog.d/14014.bugfix b/changelog.d/14014.bugfix new file mode 100644 index 0000000000..4318f4daff --- /dev/null +++ b/changelog.d/14014.bugfix @@ -0,0 +1 @@ +Send invite push notifications for invite over federation. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index bb489b8189..3e15827986 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2174,7 +2174,7 @@ class PersistEventsStore: ( (event.event_id,) for event, _ in all_events_and_contexts - if not event.internal_metadata.is_outlier() + if event.internal_metadata.is_notifiable() ), ) -- cgit 1.5.1 From b706111b7805dceb268e114b6c291c4318288cf0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 3 Oct 2022 12:47:15 -0400 Subject: Do not return unspecced original_event field when using the stable /relations endpoint. (#14025) Keep the old behavior (of including the original_event field) for any requests to the /unstable version of the endpoint, but do not include the field when the /v1 version is used. This should avoid new clients from depending on this field, but will not help with current dependencies. --- changelog.d/14025.bugfix | 1 + synapse/handlers/relations.py | 25 +++++++++++++------------ synapse/rest/client/relations.py | 6 ++++++ tests/rest/client/test_relations.py | 13 ++++++++----- 4 files changed, 28 insertions(+), 17 deletions(-) create mode 100644 changelog.d/14025.bugfix (limited to 'synapse') diff --git a/changelog.d/14025.bugfix b/changelog.d/14025.bugfix new file mode 100644 index 0000000000..391364f44d --- /dev/null +++ b/changelog.d/14025.bugfix @@ -0,0 +1 @@ +Do not return an unspecified `original_event` field when using the stable `/relations` endpoint. Introduced in Synapse v1.57.0. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 28d7093f08..63bc6a7aa5 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -78,6 +78,7 @@ class RelationsHandler: direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, + include_original_event: bool = False, ) -> JsonDict: """Get related events of a event, ordered by topological ordering. @@ -94,6 +95,7 @@ class RelationsHandler: oldest first (`"f"`). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. + include_original_event: Whether to include the parent event. Returns: The pagination chunk. @@ -138,25 +140,24 @@ class RelationsHandler: is_peeking=(member_event_id is None), ) - now = self._clock.time_msec() - # Do not bundle aggregations when retrieving the original event because - # we want the content before relations are applied to it. - original_event = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None - ) # The relations returned for the requested event do include their # bundled aggregations. aggregations = await self.get_bundled_aggregations( events, requester.user.to_string() ) - serialized_events = self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations - ) - return_value = { - "chunk": serialized_events, - "original_event": original_event, + now = self._clock.time_msec() + return_value: JsonDict = { + "chunk": self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ), } + if include_original_event: + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. + return_value["original_event"] = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None + ) if next_token: return_value["next_batch"] = await next_token.to_string(self._main_store) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 205c556f64..7a25de5c85 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -82,6 +82,11 @@ class RelationPaginationServlet(RestServlet): if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) + # The unstable version of this API returns an extra field for client + # compatibility, see https://github.com/matrix-org/synapse/issues/12930. + assert request.path is not None + include_original_event = request.path.startswith(b"/_matrix/client/unstable/") + result = await self._relations_handler.get_relations( requester=requester, event_id=parent_id, @@ -92,6 +97,7 @@ class RelationPaginationServlet(RestServlet): direction=direction, from_token=from_token, to_token=to_token, + include_original_event=include_original_event, ) return 200, result diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index fef3b72d76..988cdb746d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -654,6 +654,14 @@ class RelationsTestCase(BaseRelationsTestCase): ) # We also expect to get the original event (the id of which is self.parent_id) + # when requesting the unstable endpoint. + self.assertNotIn("original_event", channel.json_body) + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -755,11 +763,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id - ) - # Make sure next_batch has something in it that looks like it could be a # valid token. self.assertIsInstance( -- cgit 1.5.1 From b381701f8c07444fb86d80a79f561c8468a6c0b7 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 3 Oct 2022 17:16:15 +0000 Subject: Announce that legacy metric names are deprecated, will be turned off by default in Synapse v1.71.0 and removed altogether in Synapse v1.73.0. (#14024) --- changelog.d/14024.removal | 1 + docs/metrics-howto.md | 11 +++++++++- docs/upgrade.md | 28 ++++++++++++++++++++++++ docs/usage/configuration/config_documentation.md | 25 +++++++++++++++++++++ synapse/config/metrics.py | 26 ---------------------- 5 files changed, 64 insertions(+), 27 deletions(-) create mode 100644 changelog.d/14024.removal (limited to 'synapse') diff --git a/changelog.d/14024.removal b/changelog.d/14024.removal new file mode 100644 index 0000000000..9b83cb3927 --- /dev/null +++ b/changelog.d/14024.removal @@ -0,0 +1 @@ +Announce that legacy metric names are deprecated, will be turned off by default in Synapse v1.71.0 and removed altogether in Synapse v1.73.0. See the upgrade notes for more information. \ No newline at end of file diff --git a/docs/metrics-howto.md b/docs/metrics-howto.md index 279303a798..d8416b5a5f 100644 --- a/docs/metrics-howto.md +++ b/docs/metrics-howto.md @@ -135,6 +135,8 @@ Synapse 1.2 updates the Prometheus metrics to match the naming convention of the upstream `prometheus_client`. The old names are considered deprecated and will be removed in a future version of Synapse. +**The old names will be disabled by default in Synapse v1.71.0 and removed +altogether in Synapse v1.73.0.** | New Name | Old Name | | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------- | @@ -146,6 +148,13 @@ Synapse. | synapse_federation_client_events_processed_total | synapse_federation_client_events_processed | | synapse_event_processing_loop_count_total | synapse_event_processing_loop_count | | synapse_event_processing_loop_room_count_total | synapse_event_processing_loop_room_count | +| synapse_util_caches_cache_hits | synapse_util_caches_cache:hits | +| synapse_util_caches_cache_size | synapse_util_caches_cache:size | +| synapse_util_caches_cache_evicted_size | synapse_util_caches_cache:evicted_size | +| synapse_util_caches_cache | synapse_util_caches_cache:total | +| synapse_util_caches_response_cache_size | synapse_util_caches_response_cache:size | +| synapse_util_caches_response_cache_hits | synapse_util_caches_response_cache:hits | +| synapse_util_caches_response_cache_evicted_size | synapse_util_caches_response_cache:evicted_size | | synapse_util_metrics_block_count_total | synapse_util_metrics_block_count | | synapse_util_metrics_block_time_seconds_total | synapse_util_metrics_block_time_seconds | | synapse_util_metrics_block_ru_utime_seconds_total | synapse_util_metrics_block_ru_utime_seconds | @@ -261,7 +270,7 @@ Standard Metric Names As of synapse version 0.18.2, the format of the process-wide metrics has been changed to fit prometheus standard naming conventions. Additionally -the units have been changed to seconds, from miliseconds. +the units have been changed to seconds, from milliseconds. | New name | Old name | | ---------------------------------------- | --------------------------------- | diff --git a/docs/upgrade.md b/docs/upgrade.md index c4db19e23d..002ef70059 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -100,6 +100,34 @@ vice versa. Once all workers are upgraded to v1.69 (or downgraded to v1.68), receipts replication will resume as normal. + +## Deprecation of legacy Prometheus metric names + +In current versions of Synapse, some Prometheus metrics are emitted under two different names, +with one of the names being older but non-compliant with OpenMetrics and Prometheus conventions +and one of the names being newer but compliant. + +Synapse v1.71.0 will turn the old metric names off *by default*. +For administrators that still rely on them and have not had chance to update their +uses of the metrics, it's possible to specify `enable_legacy_metrics: true` in +the configuration to re-enable them temporarily. + +Synapse v1.73.0 will **remove legacy metric names altogether** and it will no longer +be possible to re-enable them. + +The Grafana dashboard, Prometheus recording rules and Prometheus Consoles included +in the `contrib` directory in the Synapse repository have been updated to no longer +rely on the legacy names. These can be used on a current version of Synapse +because current versions of Synapse emit both old and new names. + +You may need to update your alerting rules or any other rules that depend on +the names of Prometheus metrics. +If you want to test your changes before legacy names are disabled by default, +you may specify `enable_legacy_metrics: false` in your homeserver configuration. + +A list of affected metrics is available on the [Metrics How-to page](https://matrix-org.github.io/synapse/v1.69/metrics-howto.html?highlight=metrics%20deprecated#renaming-of-metrics--deprecation-of-old-names-in-12). + + # Upgrading to v1.68.0 Two changes announced in the upgrade notes for v1.67.0 have now landed in v1.68.0. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index f46b4932fd..5e40166ff5 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2436,6 +2436,31 @@ Example configuration: enable_metrics: true ``` --- +### `enable_legacy_metrics` + +Set to `true` to publish both legacy and non-legacy Prometheus metric names, +or to `false` to only publish non-legacy Prometheus metric names. +Defaults to `true`. Has no effect if `enable_metrics` is `false`. +**In Synapse v1.71.0, this will default to `false` before being removed in Synapse v1.73.0.** + +Legacy metric names include: +- metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules; +- counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard. + +These legacy metric names are unconventional and not compliant with OpenMetrics standards. +They are included for backwards compatibility. + +Example configuration: +```yaml +enable_legacy_metrics: false +``` + +See https://github.com/matrix-org/synapse/issues/11106 for context. + +*Since v1.67.0.* + +**Will be removed in v1.73.0.** +--- ### `sentry` Use this option to enable sentry integration. Provide the DSN assigned to you by sentry diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index f3134834e5..bb065f9f2f 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -43,32 +43,6 @@ class MetricsConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_metrics = config.get("enable_metrics", False) - """ - ### `enable_legacy_metrics` (experimental) - - **Experimental: this option may be removed or have its behaviour - changed at any time, with no notice.** - - Set to `true` to publish both legacy and non-legacy Prometheus metric names, - or to `false` to only publish non-legacy Prometheus metric names. - Defaults to `true`. Has no effect if `enable_metrics` is `false`. - - Legacy metric names include: - - metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules; - - counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard. - - These legacy metric names are unconventional and not compliant with OpenMetrics standards. - They are included for backwards compatibility. - - Example configuration: - ```yaml - enable_legacy_metrics: false - ``` - - See https://github.com/matrix-org/synapse/issues/11106 for context. - - *Since v1.67.0.* - """ self.enable_legacy_metrics = config.get("enable_legacy_metrics", True) self.report_stats = config.get("report_stats", None) -- cgit 1.5.1 From 5a6d02524685187b8ed212b8e8027e4d15575fd0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Oct 2022 18:44:44 +0100 Subject: Clear out old rows from `event_push_actions_staging` (#14020) On matrix.org we have ~5 million stale rows in `event_push_actions_staging`, let's add a background job to make sure we clear them out. --- changelog.d/14020.misc | 1 + .../storage/databases/main/event_push_actions.py | 58 +++++++++++++++++++++- synapse/storage/schema/__init__.py | 1 + .../main/delta/73/05old_push_actions.sql.postgres | 22 ++++++++ .../main/delta/73/05old_push_actions.sql.sqlite | 24 +++++++++ 5 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14020.misc create mode 100644 synapse/storage/schema/main/delta/73/05old_push_actions.sql.postgres create mode 100644 synapse/storage/schema/main/delta/73/05old_push_actions.sql.sqlite (limited to 'synapse') diff --git a/changelog.d/14020.misc b/changelog.d/14020.misc new file mode 100644 index 0000000000..85550b307d --- /dev/null +++ b/changelog.d/14020.misc @@ -0,0 +1 @@ +Clear out stale entries in `event_push_actions_staging` table. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3fdf128d9e..cdc9ee5a37 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -205,6 +205,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ): super().__init__(database, db_conn, hs) + # Track when the process started. + self._started_ts = self._clock.time_msec() + # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago: Optional[int] = None self.stream_ordering_day_ago: Optional[int] = None @@ -224,6 +227,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self._rotate_notifs, 30 * 1000 ) + self._clear_old_staging_loop = self._clock.looping_call( + self._clear_old_push_actions_staging, 30 * 60 * 1000 + ) + self.db_pool.updates.register_background_index_update( "event_push_summary_unique_index", index_name="event_push_summary_unique_index", @@ -791,7 +798,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # can be used to insert into the `event_push_actions_staging` table. def _gen_entry( user_id: str, actions: Collection[Union[Mapping, str]] - ) -> Tuple[str, str, str, int, int, int, str]: + ) -> Tuple[str, str, str, int, int, int, str, int]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 return ( @@ -802,6 +809,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas is_highlight, # highlight column int(count_as_unread), # unread column thread_id, # thread_id column + self._clock.time_msec(), # inserted_ts column ) await self.db_pool.simple_insert_many( @@ -814,6 +822,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "highlight", "unread", "thread_id", + "inserted_ts", ), values=[ _gen_entry(user_id, actions) @@ -1340,6 +1349,53 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas if done: break + @wrap_as_background_process("_clear_old_push_actions_staging") + async def _clear_old_push_actions_staging(self) -> None: + """Clear out any old event push actions from the staging table for + events that we failed to persist. + """ + + # We delete anything more than an hour old, on the assumption that we'll + # never take more than an hour to persist an event. + delete_before_ts = self._clock.time_msec() - 60 * 60 * 1000 + + if self._started_ts > delete_before_ts: + # We need to wait for at least an hour before we started deleting, + # so that we know it's safe to delete rows with NULL `inserted_ts`. + return + + # We don't have an index on `inserted_ts`, instead we assume that the + # number of "live" rows in `event_push_actions_staging` is small enough + # that an infrequent periodic scan won't cause a problem. + # + # Note: we also delete any columns with NULL `inserted_ts`, this is safe + # as we added a default value to new rows and so they must be at least + # an hour old. + limit = 1000 + sql = """ + DELETE FROM event_push_actions_staging WHERE event_id IN ( + SELECT event_id FROM event_push_actions_staging WHERE + inserted_ts < ? OR inserted_ts IS NULL + LIMIT ? + ) + """ + + def _clear_old_push_actions_staging_txn(txn: LoggingTransaction) -> bool: + txn.execute(sql, (delete_before_ts, limit)) + return txn.rowcount >= limit + + while True: + # Returns true if we have more stuff to delete from the table. + deleted = await self.db_pool.runInteraction( + "_clear_old_push_actions_staging", _clear_old_push_actions_staging_txn + ) + + if not deleted: + return + + # We sleep to ensure that we don't overwhelm the DB. + await self._clock.sleep(1.0) + class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index f29424d17a..4a5c947699 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -85,6 +85,7 @@ Changes in SCHEMA_VERSION = 73; events over federation. - Add indexes to various tables (`event_failed_pull_attempts`, `insertion_events`, `batch_events`) to make it easy to delete all associated rows when purging a room. + - `inserted_ts` column is added to `event_push_actions_staging` table. """ diff --git a/synapse/storage/schema/main/delta/73/05old_push_actions.sql.postgres b/synapse/storage/schema/main/delta/73/05old_push_actions.sql.postgres new file mode 100644 index 0000000000..4af1a8470b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/05old_push_actions.sql.postgres @@ -0,0 +1,22 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column so that we know when a push action was inserted, to make it +-- easier to clear out old ones. +ALTER TABLE event_push_actions_staging ADD COLUMN inserted_ts BIGINT; + +-- We now add a default for *new* rows. We don't do this above as we don't want +-- to have to update every remove with the new default. +ALTER TABLE event_push_actions_staging ALTER COLUMN inserted_ts SET DEFAULT extract(epoch from now()) * 1000; diff --git a/synapse/storage/schema/main/delta/73/05old_push_actions.sql.sqlite b/synapse/storage/schema/main/delta/73/05old_push_actions.sql.sqlite new file mode 100644 index 0000000000..7482dabba2 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/05old_push_actions.sql.sqlite @@ -0,0 +1,24 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- On SQLite we must be in monolith mode and updating the database from Synapse, +-- so its safe to assume that `event_push_actions_staging` should be empty (as +-- over restart an event must either have been fully persisted or we'll +-- recalculate the push actions) +DELETE FROM event_push_actions_staging; + +-- Add a column so that we know when a push action was inserted, to make it +-- easier to clear out old ones. +ALTER TABLE event_push_actions_staging ADD COLUMN inserted_ts BIGINT; -- cgit 1.5.1 From 70a4317692adcf7f1dacb201cda2188c8495bfa9 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 3 Oct 2022 14:53:29 -0500 Subject: Track when the pulled event signature fails (#13815) Because we're doing the recording in `_check_sigs_and_hash_for_pulled_events_and_fetch` (previously named `_check_sigs_and_hash_and_fetch`), this means we will track signature failures for `backfill`, `get_room_state`, `get_event_auth`, and `get_missing_events` (all pulled event scenarios). And we also record signature failures from `get_pdu`. Part of https://github.com/matrix-org/synapse/issues/13700 Part of https://github.com/matrix-org/synapse/issues/13676 and https://github.com/matrix-org/synapse/issues/13356 This PR will be especially important for https://github.com/matrix-org/synapse/pull/13816 so we can avoid the costly `_get_state_ids_after_missing_prev_event` down the line when `/messages` calls backfill. --- changelog.d/13815.feature | 1 + synapse/federation/federation_base.py | 25 ++++++++-- synapse/federation/federation_client.py | 50 ++++++++++++++++---- tests/federation/test_federation_client.py | 75 ++++++++++++++++++++++++++++++ tests/test_federation.py | 4 +- 5 files changed, 140 insertions(+), 15 deletions(-) create mode 100644 changelog.d/13815.feature (limited to 'synapse') diff --git a/changelog.d/13815.feature b/changelog.d/13815.feature new file mode 100644 index 0000000000..ba411f5067 --- /dev/null +++ b/changelog.d/13815.feature @@ -0,0 +1 @@ +Keep track when an event pulled over federation fails its signature check so we can intelligently back-off in the future. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index abe2c1971a..6bd4742140 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable, Callable, Optional from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -58,7 +58,12 @@ class FederationBase: @trace async def _check_sigs_and_hash( - self, room_version: RoomVersion, pdu: EventBase + self, + room_version: RoomVersion, + pdu: EventBase, + record_failure_callback: Optional[ + Callable[[EventBase, str], Awaitable[None]] + ] = None, ) -> EventBase: """Checks that event is correctly signed by the sending server. @@ -70,6 +75,11 @@ class FederationBase: Args: room_version: The room version of the PDU pdu: the event to be checked + record_failure_callback: A callback to run whenever the given event + fails signature or hash checks. This includes exceptions + that would be normally be thrown/raised but also things like + checking for event tampering where we just return the redacted + event. Returns: * the original event if the checks pass @@ -80,7 +90,12 @@ class FederationBase: InvalidEventSignatureError if the signature check failed. Nothing will be logged in this case. """ - await _check_sigs_on_pdu(self.keyring, room_version, pdu) + try: + await _check_sigs_on_pdu(self.keyring, room_version, pdu) + except InvalidEventSignatureError as exc: + if record_failure_callback: + await record_failure_callback(pdu, str(exc)) + raise exc if not check_event_content_hash(pdu): # let's try to distinguish between failures because the event was @@ -116,6 +131,10 @@ class FederationBase: "event_id": pdu.event_id, } ) + if record_failure_callback: + await record_failure_callback( + pdu, "Event content has been tampered with" + ) return redacted_event spam_check = await self.spam_checker.check_event_for_spam(pdu) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 464672a3da..4dca711cd2 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -278,7 +278,7 @@ class FederationClient(FederationBase): pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus] # Check signatures and hash of pdus, removing any from the list that fail checks - pdus[:] = await self._check_sigs_and_hash_and_fetch( + pdus[:] = await self._check_sigs_and_hash_for_pulled_events_and_fetch( dest, pdus, room_version=room_version ) @@ -328,7 +328,17 @@ class FederationClient(FederationBase): # Check signatures are correct. try: - signed_pdu = await self._check_sigs_and_hash(room_version, pdu) + + async def _record_failure_callback( + event: EventBase, cause: str + ) -> None: + await self.store.record_event_failed_pull_attempt( + event.room_id, event.event_id, cause + ) + + signed_pdu = await self._check_sigs_and_hash( + room_version, pdu, _record_failure_callback + ) except InvalidEventSignatureError as e: errmsg = f"event id {pdu.event_id}: {e}" logger.warning("%s", errmsg) @@ -547,24 +557,28 @@ class FederationClient(FederationBase): len(auth_event_map), ) - valid_auth_events = await self._check_sigs_and_hash_and_fetch( + valid_auth_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch( destination, auth_event_map.values(), room_version ) - valid_state_events = await self._check_sigs_and_hash_and_fetch( - destination, state_event_map.values(), room_version + valid_state_events = ( + await self._check_sigs_and_hash_for_pulled_events_and_fetch( + destination, state_event_map.values(), room_version + ) ) return valid_state_events, valid_auth_events @trace - async def _check_sigs_and_hash_and_fetch( + async def _check_sigs_and_hash_for_pulled_events_and_fetch( self, origin: str, pdus: Collection[EventBase], room_version: RoomVersion, ) -> List[EventBase]: - """Checks the signatures and hashes of a list of events. + """ + Checks the signatures and hashes of a list of pulled events we got from + federation and records any signature failures as failed pull attempts. If a PDU fails its signature check then we check if we have it in the database, and if not then request it from the sender's server (if that @@ -597,11 +611,17 @@ class FederationClient(FederationBase): valid_pdus: List[EventBase] = [] + async def _record_failure_callback(event: EventBase, cause: str) -> None: + await self.store.record_event_failed_pull_attempt( + event.room_id, event.event_id, cause + ) + async def _execute(pdu: EventBase) -> None: valid_pdu = await self._check_sigs_and_hash_and_fetch_one( pdu=pdu, origin=origin, room_version=room_version, + record_failure_callback=_record_failure_callback, ) if valid_pdu: @@ -618,6 +638,9 @@ class FederationClient(FederationBase): pdu: EventBase, origin: str, room_version: RoomVersion, + record_failure_callback: Optional[ + Callable[[EventBase, str], Awaitable[None]] + ] = None, ) -> Optional[EventBase]: """Takes a PDU and checks its signatures and hashes. @@ -634,6 +657,11 @@ class FederationClient(FederationBase): origin pdu room_version + record_failure_callback: A callback to run whenever the given event + fails signature or hash checks. This includes exceptions + that would be normally be thrown/raised but also things like + checking for event tampering where we just return the redacted + event. Returns: The PDU (possibly redacted) if it has valid signatures and hashes. @@ -641,7 +669,9 @@ class FederationClient(FederationBase): """ try: - return await self._check_sigs_and_hash(room_version, pdu) + return await self._check_sigs_and_hash( + room_version, pdu, record_failure_callback + ) except InvalidEventSignatureError as e: logger.warning( "Signature on retrieved event %s was invalid (%s). " @@ -694,7 +724,7 @@ class FederationClient(FederationBase): auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]] - signed_auth = await self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_for_pulled_events_and_fetch( destination, auth_chain, room_version=room_version ) @@ -1401,7 +1431,7 @@ class FederationClient(FederationBase): event_from_pdu_json(e, room_version) for e in content.get("events", []) ] - signed_events = await self._check_sigs_and_hash_and_fetch( + signed_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch( destination, events, room_version=room_version ) except HttpResponseException as e: diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 50e376f695..a538215931 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -23,14 +23,23 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions from synapse.events import EventBase +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock +from tests.test_utils import event_injection from tests.unittest import FederatingHomeserverTestCase class FederationClientTest(FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): super().prepare(reactor, clock, homeserver) @@ -231,6 +240,72 @@ class FederationClientTest(FederatingHomeserverTestCase): return remote_pdu + def test_backfill_invalid_signature_records_failed_pull_attempts( + self, + ) -> None: + """ + Test to make sure that events from /backfill with invalid signatures get + recorded as failed pull attempts. + """ + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + # We purposely don't run `add_hashes_and_signatures_from_other_server` + # over this because we want the signature check to fail. + pulled_event, _ = self.get_success( + event_injection.create_event( + self.hs, + room_id=room_id, + sender=OTHER_USER, + type="test_event_type", + content={"body": "garply"}, + ) + ) + + # We expect an outbound request to /backfill, so stub that out + self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( + _mock_response( + { + "origin": "yet.another.server", + "origin_server_ts": 900, + # Mimic the other server returning our new `pulled_event` + "pdus": [pulled_event.get_pdu_json()], + } + ) + ) + + self.get_success( + self.hs.get_federation_client().backfill( + # We use "yet.another.server" instead of + # `self.OTHER_SERVER_NAME` because we want to see the behavior + # from `_check_sigs_and_hash_and_fetch_one` where it tries to + # fetch the PDU again from the origin server if the signature + # fails. Just want to make sure that the failure is counted from + # both code paths. + dest="yet.another.server", + room_id=room_id, + limit=1, + extremities=[pulled_event.event_id], + ), + ) + + # Make sure our failed pull attempt was recorded + backfill_num_attempts = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + ) + ) + # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the + # other from "yet.another.server" + self.assertEqual(backfill_num_attempts, 2) + def _mock_response(resp: JsonDict): body = json.dumps(resp).encode("utf-8") diff --git a/tests/test_federation.py b/tests/test_federation.py index 779fad1f63..80e5c590d8 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -86,8 +86,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): federation_event_handler._check_event_auth = _check_event_auth self.client = self.homeserver.get_federation_client() - self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( - pdus + self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( + lambda dest, pdus, **k: succeed(pdus) ) # Send the join, it should return None (which is not an error) -- cgit 1.5.1 From 27fa0fa6987c691bf6a8528bb870503d2869a740 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 07:06:41 -0400 Subject: Send the appservice access token as a header. (#13996) Implements MSC2832 by sending application service access tokens in the Authorization header. The access token is also still sent as a query parameter until the application service ecosystem has fully migrated to using headers. In the future this could be made opt-in, or removed completely. --- changelog.d/13996.feature | 1 + synapse/appservice/api.py | 23 +++++++++++++++++++---- tests/appservice/test_api.py | 8 ++++++-- 3 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13996.feature (limited to 'synapse') diff --git a/changelog.d/13996.feature b/changelog.d/13996.feature new file mode 100644 index 0000000000..771f1c97a3 --- /dev/null +++ b/changelog.d/13996.feature @@ -0,0 +1 @@ +Send application service access tokens as a header (and query parameter). Implement [MSC2832](https://github.com/matrix-org/matrix-spec-proposals/pull/2832). diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 0963fb3bb4..fbac4375b0 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -120,7 +120,11 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) try: - response = await self.get_json(uri, {"access_token": service.hs_token}) + response = await self.get_json( + uri, + {"access_token": service.hs_token}, + headers={"Authorization": f"Bearer {service.hs_token}"}, + ) if response is not None: # just an empty json object return True except CodeMessageException as e: @@ -140,7 +144,11 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) try: - response = await self.get_json(uri, {"access_token": service.hs_token}) + response = await self.get_json( + uri, + {"access_token": service.hs_token}, + headers={"Authorization": f"Bearer {service.hs_token}"}, + ) if response is not None: # just an empty json object return True except CodeMessageException as e: @@ -181,7 +189,9 @@ class ApplicationServiceApi(SimpleHttpClient): **fields, b"access_token": service.hs_token, } - response = await self.get_json(uri, args=args) + response = await self.get_json( + uri, args=args, headers={"Authorization": f"Bearer {service.hs_token}"} + ) if not isinstance(response, list): logger.warning( "query_3pe to %s returned an invalid response %r", uri, response @@ -217,7 +227,11 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.parse.quote(protocol), ) try: - info = await self.get_json(uri, {"access_token": service.hs_token}) + info = await self.get_json( + uri, + {"access_token": service.hs_token}, + headers={"Authorization": f"Bearer {service.hs_token}"}, + ) if not _is_valid_3pe_metadata(info): logger.warning( @@ -313,6 +327,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri=uri, json_body=body, args={"access_token": service.hs_token}, + headers={"Authorization": f"Bearer {service.hs_token}"}, ) if logger.isEnabledFor(logging.DEBUG): logger.debug( diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 532b676365..11008ac1fb 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -69,10 +69,14 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): self.request_url = None - async def get_json(url: str, args: Mapping[Any, Any]) -> List[JsonDict]: - if not args.get(b"access_token"): + async def get_json( + url: str, args: Mapping[Any, Any], headers: Mapping[Any, Any] + ) -> List[JsonDict]: + # Ensure the access token is passed as both a header and query arg. + if not headers.get("Authorization") or not args.get(b"access_token"): raise RuntimeError("Access token not provided") + self.assertEqual(headers.get("Authorization"), f"Bearer {TOKEN}") self.assertEqual(args.get(b"access_token"), TOKEN) self.request_url = url if url == URL_USER: -- cgit 1.5.1 From e70c6b720ed537c0b7fc0cd4aa20eac195941d73 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 07:08:27 -0400 Subject: Disable pushing for server ACL events (MSC3786). (#13997) Switches to the stable identifier for MSC3786 and enables it by default. This disables pushes of m.room.server_acl events. --- changelog.d/13997.feature | 1 + rust/src/push/base_rules.rs | 2 +- rust/src/push/mod.rs | 9 --------- stubs/synapse/synapse_rust/push.pyi | 6 +----- synapse/config/experimental.py | 3 --- synapse/storage/databases/main/push_rule.py | 9 ++------- 6 files changed, 5 insertions(+), 25 deletions(-) create mode 100644 changelog.d/13997.feature (limited to 'synapse') diff --git a/changelog.d/13997.feature b/changelog.d/13997.feature new file mode 100644 index 0000000000..23f7ed106f --- /dev/null +++ b/changelog.d/13997.feature @@ -0,0 +1 @@ +Ignore server ACL changes when generating pushes. Implement [MSC3786](https://github.com/matrix-org/matrix-spec-proposals/pull/3786). diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index bb59676bde..2a09cf99ae 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -173,7 +173,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default_enabled: true, }, PushRule { - rule_id: Cow::Borrowed("global/override/.org.matrix.msc3786.rule.room.server_acl"), + rule_id: Cow::Borrowed("global/override/.m.rule.room.server_acl"), priority_class: 5, conditions: Cow::Borrowed(&[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 30fffc31ad..208b9c0d73 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -401,7 +401,6 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, - msc3786_enabled: bool, msc3772_enabled: bool, } @@ -411,13 +410,11 @@ impl FilteredPushRules { pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, - msc3786_enabled: bool, msc3772_enabled: bool, ) -> Self { Self { push_rules, enabled_map, - msc3786_enabled, msc3772_enabled, } } @@ -437,12 +434,6 @@ impl FilteredPushRules { .iter() .filter(|rule| { // Ignore disabled experimental push rules - if !self.msc3786_enabled - && rule.rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl" - { - return false; - } - if !self.msc3772_enabled && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply" { diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index fffb8419c6..5900e61450 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -26,11 +26,7 @@ class PushRules: class FilteredPushRules: def __init__( - self, - push_rules: PushRules, - enabled_map: Dict[str, bool], - msc3786_enabled: bool, - msc3772_enabled: bool, + self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3772_enabled: bool ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 31834fb27d..83695f24d9 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -95,9 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3786 (Add a default push rule to ignore m.room.server_acl events) - self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) - # MSC3771: Thread read receipts self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index ed17b2e70c..8295322b0e 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -81,15 +81,10 @@ def _load_rules( for rawrule in rawrules ] - push_rules = PushRules( - ruleslist, - ) + push_rules = PushRules(ruleslist) filtered_rules = FilteredPushRules( - push_rules, - enabled_map, - msc3786_enabled=experimental_config.msc3786_enabled, - msc3772_enabled=experimental_config.msc3772_enabled, + push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled ) return filtered_rules -- cgit 1.5.1 From b4ec4f5e71a87d5bdc840a4220dfd9a34c54c847 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 09:47:04 -0400 Subject: Track notification counts per thread (implement MSC3773). (#13776) When retrieving counts of notifications segment the results based on the thread ID, but choose whether to return them as individual threads or as a single summed field by letting the client opt-in via a sync flag. The summarization code is also updated to be per thread, instead of per room. --- changelog.d/13776.feature | 1 + synapse/api/constants.py | 3 + synapse/api/filtering.py | 10 ++ synapse/config/experimental.py | 2 + synapse/handlers/sync.py | 40 ++++- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_tools.py | 9 +- synapse/rest/client/sync.py | 4 + synapse/rest/client/versions.py | 3 +- synapse/storage/database.py | 2 +- .../storage/databases/main/event_push_actions.py | 188 +++++++++++++-------- synapse/storage/schema/__init__.py | 6 +- .../delta/73/06thread_notifications_backfill.sql | 29 ++++ .../07thread_notifications_not_null.sql.postgres | 19 +++ .../73/07thread_notifications_not_null.sql.sqlite | 101 +++++++++++ tests/replication/slave/storage/test_events.py | 17 +- tests/storage/test_event_push_actions.py | 169 +++++++++++++++++- 17 files changed, 514 insertions(+), 93 deletions(-) create mode 100644 changelog.d/13776.feature create mode 100644 synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql create mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres create mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite (limited to 'synapse') diff --git a/changelog.d/13776.feature b/changelog.d/13776.feature new file mode 100644 index 0000000000..22bce125ce --- /dev/null +++ b/changelog.d/13776.feature @@ -0,0 +1 @@ +Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c031903b1a..44c5ffc6a5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -31,6 +31,9 @@ MAX_ALIAS_LENGTH = 255 # the maximum length for a user id is 255 characters MAX_USERID_LENGTH = 255 +# Constant value used for the pseudo-thread which is the main timeline. +MAIN_TIMELINE: Final = "main" + class Membership: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index f7f46f8d80..c6e44dcf82 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + "org.matrix.msc3773.unread_thread_notifications": {"type": "boolean"}, # Include or exclude events with the provided labels. # cf https://github.com/matrix-org/matrix-doc/pull/2326 "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, @@ -240,6 +241,9 @@ class FilterCollection: def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members + def unread_thread_notifications(self) -> bool: + return self._room_timeline_filter.unread_thread_notifications + async def filter_presence( self, events: Iterable[UserPresenceState] ) -> List[UserPresenceState]: @@ -304,6 +308,12 @@ class Filter: self.include_redundant_members = filter_json.get( "include_redundant_members", False ) + if hs.config.experimental.msc3773_enabled: + self.unread_thread_notifications: bool = filter_json.get( + "org.matrix.msc3773.unread_thread_notifications", False + ) + else: + self.unread_thread_notifications = False self.types = filter_json.get("types", None) self.not_types = filter_json.get("not_types", []) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 83695f24d9..6503ce6e34 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -99,6 +99,8 @@ class ExperimentalConfig(Config): self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) + # MSC3773: Thread notifications + self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) # MSC3715: dir param on /relations. self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4abb9b6127..329e89c604 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -40,7 +40,7 @@ from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -128,6 +128,7 @@ class JoinedSyncResult: ephemeral: List[JsonDict] account_data: List[JsonDict] unread_notifications: JsonDict + unread_thread_notifications: JsonDict summary: Optional[JsonDict] unread_count: int @@ -278,6 +279,8 @@ class SyncHandler: self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + self._msc3773_enabled = hs.config.experimental.msc3773_enabled + async def wait_for_sync_for_user( self, requester: Requester, @@ -1288,7 +1291,7 @@ class SyncHandler: async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> NotifCounts: + ) -> RoomNotifCounts: with Measure(self.clock, "unread_notifs_for_room_id"): return await self.store.get_unread_event_push_actions_by_room_for_user( @@ -2353,6 +2356,7 @@ class SyncHandler: ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + unread_thread_notifications={}, summary=summary, unread_count=0, ) @@ -2360,10 +2364,36 @@ class SyncHandler: if room_sync or always_include: notifs = await self.unread_notifs_for_room_id(room_id, sync_config) - unread_notifications["notification_count"] = notifs.notify_count - unread_notifications["highlight_count"] = notifs.highlight_count + # Notifications for the main timeline. + notify_count = notifs.main_timeline.notify_count + highlight_count = notifs.main_timeline.highlight_count + unread_count = notifs.main_timeline.unread_count - room_sync.unread_count = notifs.unread_count + # Check the sync configuration. + if ( + self._msc3773_enabled + and sync_config.filter_collection.unread_thread_notifications() + ): + # And add info for each thread. + room_sync.unread_thread_notifications = { + thread_id: { + "notification_count": thread_notifs.notify_count, + "highlight_count": thread_notifs.highlight_count, + } + for thread_id, thread_notifs in notifs.threads.items() + if thread_id is not None + } + + else: + # Combine the unread counts for all threads and main timeline. + for thread_notifs in notifs.threads.values(): + notify_count += thread_notifs.notify_count + highlight_count += thread_notifs.highlight_count + unread_count += thread_notifs.unread_count + + unread_notifications["notification_count"] = notify_count + unread_notifications["highlight_count"] = highlight_count + room_sync.unread_count = unread_count sync_result_builder.joined.append(room_sync) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4270438918..61d952742d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -31,7 +31,7 @@ from typing import ( from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -280,7 +280,7 @@ class BulkPushRuleEvaluator: # If the event does not have a relation, then cannot have any mutual # relations or thread ID. relations = {} - thread_id = "main" + thread_id = MAIN_TIMELINE if relation: relations = await self._get_mutual_relations( relation.parent_id, diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 658bf373b7..edeba27a45 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -39,7 +39,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - await concurrently_execute(get_room_unread_count, joins, 10) for notifs in room_notifs: - if notifs.notify_count == 0: + # Combine the counts from all the threads. + notify_count = notifs.main_timeline.notify_count + sum( + n.notify_count for n in notifs.threads.values() + ) + + if notify_count == 0: continue if group_by_room: @@ -47,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge += 1 else: # increment the badge count by the number of unread messages in the room - badge += notifs.notify_count + badge += notify_count return badge diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index c2989765ce..f1c23d68e5 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -509,6 +509,10 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + if room.unread_thread_notifications: + result[ + "org.matrix.msc3773.unread_thread_notifications" + ] = room.unread_thread_notifications result["summary"] = room.summary if self._msc2654_enabled: result["org.matrix.msc2654.unread_count"] = room.unread_count diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c95b0d6f19..280d306483 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,8 +103,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above - # Support for thread read receipts. + # Support for thread read receipts & notification counts. "org.matrix.msc3771": self.config.experimental.msc3771_enabled, + "org.matrix.msc3773": self.config.experimental.msc3773_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b4469eb964..7bb21f8f81 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -94,7 +94,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", - "event_push_summary": "event_push_summary_unique_index", + "event_push_summary": "event_push_summary_unique_index2", "receipts_linearized": "receipts_linearized_unique_index", "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index cdc9ee5a37..3210e9cca1 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -88,7 +88,7 @@ from typing import ( import attr -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -157,7 +157,7 @@ class UserPushAction(EmailPushAction): @attr.s(slots=True, auto_attribs=True) class NotifCounts: """ - The per-user, per-room count of notifications. Used by sync and push. + The per-user, per-room, per-thread count of notifications. Used by sync and push. """ notify_count: int = 0 @@ -165,6 +165,21 @@ class NotifCounts: highlight_count: int = 0 +@attr.s(slots=True, auto_attribs=True) +class RoomNotifCounts: + """ + The per-user, per-room count of notifications. Used by sync and push. + """ + + main_timeline: NotifCounts + # Map of thread ID to the notification counts. + threads: Dict[str, NotifCounts] + + def __len__(self) -> int: + # To properly account for the amount of space in any caches. + return len(self.threads) + 1 + + def _serialize_action( actions: Collection[Union[Mapping, str]], is_highlight: bool ) -> str: @@ -338,12 +353,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return result - @cached(tree=True, max_entries=5000) + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after their latest read receipt. @@ -356,8 +371,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to retrieve the counts for. Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", @@ -371,7 +387,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: # Get the stream ordering of the user's latest receipt in the room. result = self.get_last_unthreaded_receipt_for_user_txn( txn, @@ -406,7 +422,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas room_id: str, user_id: str, receipt_stream_ordering: int, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the number of unread messages for a user/room that have happened since the given stream ordering. @@ -418,12 +434,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt in the room. If there are no receipts, the stream ordering of the user's join event. - Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + Returns: + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ - counts = NotifCounts() + main_counts = NotifCounts() + thread_counts: Dict[str, NotifCounts] = {} + + def _get_thread(thread_id: str) -> NotifCounts: + if thread_id == MAIN_TIMELINE: + return main_counts + return thread_counts.setdefault(thread_id, NotifCounts()) # First we pull the counts from the summary table. # @@ -440,52 +463,61 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # receipt). txn.execute( """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0) + SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary WHERE room_id = ? AND user_id = ? AND ( (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) OR last_receipt_stream_ordering = ? - ) + ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0) """, (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), ) - row = txn.fetchone() - - summary_stream_ordering = 0 - if row: - summary_stream_ordering = row[0] - counts.notify_count += row[1] - counts.unread_count += row[2] + max_summary_stream_ordering = 0 + for summary_stream_ordering, notif_count, unread_count, thread_id in txn: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count + + # Summaries will only be used if they have not been invalidated by + # a recent receipt; track the latest stream ordering or a valid summary. + # + # Note that since there's only one read receipt in the room per user, + # valid summaries are contiguous. + max_summary_stream_ordering = max( + summary_stream_ordering, max_summary_stream_ordering + ) # Next we need to count highlights, which aren't summarised sql = """ - SELECT COUNT(*) FROM event_push_actions + SELECT COUNT(*), thread_id FROM event_push_actions WHERE user_id = ? AND room_id = ? AND stream_ordering > ? AND highlight = 1 + GROUP BY thread_id """ txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) - row = txn.fetchone() - if row: - counts.highlight_count += row[0] + for highlight_count, thread_id in txn: + _get_thread(thread_id).highlight_count += highlight_count # Finally we need to count push actions that aren't included in the # summary returned above. This might be due to recent events that haven't # been summarised yet or the summary is out of date due to a recent read # receipt. start_unread_stream_ordering = max( - receipt_stream_ordering, summary_stream_ordering + receipt_stream_ordering, max_summary_stream_ordering ) - notify_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, start_unread_stream_ordering ) - counts.notify_count += notify_count - counts.unread_count += unread_count + for notif_count, unread_count, thread_id in unread_counts: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count - return counts + return RoomNotifCounts(main_counts, thread_counts) def _get_notif_unread_count_for_user_room( self, @@ -494,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, - ) -> Tuple[int, int]: + ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -510,13 +542,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas If this is not given, then no maximum is applied. Return: - A tuple of the notif count and unread count in the given range. + A tuple of the notif count and unread count in the given range for + each thread. """ # If there have been no events in the room since the stream ordering, # there can't be any push actions either. if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): - return 0, 0 + return [] clause = "" args = [user_id, room_id, stream_ordering] @@ -527,26 +560,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # If the max stream ordering is less than the min stream ordering, # then obviously there are zero push actions in that range. if max_stream_ordering <= stream_ordering: - return 0, 0 + return [] sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), - COUNT(CASE WHEN unread = 1 THEN 1 END) - FROM event_push_actions ea - WHERE user_id = ? + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions ea + WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? {clause} + GROUP BY thread_id """ txn.execute(sql, args) - row = txn.fetchone() - - if row: - return cast(Tuple[int, int], row) - - return 0, 0 + return cast(List[Tuple[int, int, str]], txn.fetchall()) async def get_push_action_users_in_range( self, min_stream_ordering: int, max_stream_ordering: int @@ -1099,26 +1129,34 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. - notif_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering ) - # Replace the previous summary with the new counts. - # - # TODO(threads): Upsert per-thread instead of setting them all to main. - self.db_pool.simple_upsert_txn( + # First mark the summary for all threads in the room as cleared. + self.db_pool.simple_update_txn( txn, table="event_push_summary", - keyvalues={"room_id": room_id, "user_id": user_id}, - values={ - "notif_count": notif_count, - "unread_count": unread_count, + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={ + "notif_count": 0, + "unread_count": 0, "stream_ordering": old_rotate_stream_ordering, "last_receipt_stream_ordering": stream_ordering, - "thread_id": "main", }, ) + # Then any updated threads get their notification count and unread + # count updated. + self.db_pool.simple_update_many_txn( + txn, + table="event_push_summary", + key_names=("room_id", "user_id", "thread_id"), + key_values=[(room_id, user_id, row[2]) for row in unread_counts], + value_names=("notif_count", "unread_count"), + value_values=[(row[0], row[1]) for row in unread_counts], + ) + # We always update `event_push_summary_last_receipt_stream_id` to # ensure that we don't rescan the same receipts for remote users. @@ -1204,23 +1242,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Calculate the new counts that should be upserted into event_push_summary sql = """ - SELECT user_id, room_id, + SELECT user_id, room_id, thread_id, coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering FROM ( - SELECT user_id, room_id, count(*) as cnt, + SELECT user_id, room_id, thread_id, count(*) as cnt, max(ea.stream_ordering) as stream_ordering FROM event_push_actions AS ea - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ? AND ( old.last_receipt_stream_ordering IS NULL OR old.last_receipt_stream_ordering < ea.stream_ordering ) AND %s = 1 - GROUP BY user_id, room_id + GROUP BY user_id, room_id, thread_id ) AS upd - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) """ # First get the count of unread messages. @@ -1234,11 +1272,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # object because we might not have the same amount of rows in each of them. To do # this, we use a dict indexed on the user ID and room ID to make it easier to # populate. - summaries: Dict[Tuple[str, str], _EventPushSummary] = {} + summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {} for row in txn: - summaries[(row[0], row[1])] = _EventPushSummary( - unread_count=row[2], - stream_ordering=row[3], + summaries[(row[0], row[1], row[2])] = _EventPushSummary( + unread_count=row[3], + stream_ordering=row[4], notif_count=0, ) @@ -1249,34 +1287,35 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) for row in txn: - if (row[0], row[1]) in summaries: - summaries[(row[0], row[1])].notif_count = row[2] + if (row[0], row[1], row[2]) in summaries: + summaries[(row[0], row[1], row[2])].notif_count = row[3] else: # Because the rules on notifying are different than the rules on marking # a message unread, we might end up with messages that notify but aren't # marked unread, so we might not have a summary for this (user, room) # tuple to complete. - summaries[(row[0], row[1])] = _EventPushSummary( + summaries[(row[0], row[1], row[2])] = _EventPushSummary( unread_count=0, - stream_ordering=row[3], - notif_count=row[2], + stream_ordering=row[4], + notif_count=row[3], ) logger.info("Rotating notifications, handling %d rows", len(summaries)) - # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", - key_names=("user_id", "room_id"), - key_values=[(user_id, room_id) for user_id, room_id in summaries], - value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"), + key_names=("user_id", "room_id", "thread_id"), + key_values=[ + (user_id, room_id, thread_id) + for user_id, room_id, thread_id in summaries + ], + value_names=("notif_count", "unread_count", "stream_ordering"), value_values=[ ( summary.notif_count, summary.unread_count, summary.stream_ordering, - "main", ) for summary in summaries.values() ], @@ -1288,7 +1327,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) async def _remove_old_push_actions_that_have_rotated(self) -> None: - """Clear out old push actions that have been summarised.""" + """ + Clear out old push actions that have been summarised (and are older than + 1 day ago). + """ # We want to clear out anything that is older than a day that *has* already # been rotated. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 4a5c947699..19dbf2da7f 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -90,9 +90,9 @@ Changes in SCHEMA_VERSION = 73; SCHEMA_COMPAT_VERSION = ( - # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72 - # could break. - 72 + # The threads_id column must exist for event_push_actions, event_push_summary, + # receipts_linearized, and receipts_graph. + 73 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql new file mode 100644 index 0000000000..0ffde9bbeb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql @@ -0,0 +1,29 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Forces the background updates from 06thread_notifications.sql to run in the +-- foreground as code will now require those to be "done". + +DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id'; + +-- Overwrite any null thread_id columns. +UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL; + +-- Do not run the event_push_summary_unique_index job if it is pending; the +-- thread_id field will be made required. +DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index'; +DROP INDEX IF EXISTS event_push_summary_unique_index; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres new file mode 100644 index 0000000000..33674f8c62 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres @@ -0,0 +1,19 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The columns can now be made non-nullable. +ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL; +ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL; +ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite new file mode 100644 index 0000000000..5322ad77a4 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite @@ -0,0 +1,101 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- SQLite doesn't support modifying columns to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE event_push_actions_staging_new ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL, + unread SMALLINT, + thread_id TEXT NOT NULL, + inserted_ts BIGINT +); + +CREATE TABLE event_push_actions_new ( + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + profile_tag VARCHAR(32), + actions TEXT NOT NULL, + topological_ordering BIGINT, + stream_ordering BIGINT, + notif SMALLINT, + highlight SMALLINT, + unread SMALLINT, + thread_id TEXT NOT NULL, + CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) +); + +CREATE TABLE event_push_summary_new ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + unread_count BIGINT, + last_receipt_stream_ordering BIGINT, + thread_id TEXT NOT NULL +); + +-- Swap the indexes. +DROP INDEX IF EXISTS event_push_actions_staging_id; +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging_new(event_id); + +DROP INDEX IF EXISTS event_push_actions_room_id_user_id; +DROP INDEX IF EXISTS event_push_actions_rm_tokens; +DROP INDEX IF EXISTS event_push_actions_stream_ordering; +DROP INDEX IF EXISTS event_push_actions_u_highlight; +DROP INDEX IF EXISTS event_push_actions_highlights_index; +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions_new(room_id, user_id); +CREATE INDEX event_push_actions_rm_tokens on event_push_actions_new( user_id, room_id, topological_ordering, stream_ordering ); +CREATE INDEX event_push_actions_stream_ordering on event_push_actions_new( stream_ordering, user_id ); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions_new (user_id, stream_ordering); +CREATE INDEX event_push_actions_highlights_index ON event_push_actions_new (user_id, room_id, topological_ordering, stream_ordering); + +-- Copy the data. +INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts) + SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts + FROM event_push_actions_staging; + +INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id) + SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id + FROM event_push_actions; + +INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id) + SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id + FROM event_push_summary; + +-- Drop the old tables. +DROP TABLE event_push_actions_staging; +DROP TABLE event_push_actions; +DROP TABLE event_push_summary; + +-- Rename the tables. +ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging; +ALTER TABLE event_push_actions_new RENAME TO event_push_actions; +ALTER TABLE event_push_summary_new RENAME TO event_push_summary; + +-- Re-run background updates from 72/02event_push_actions_index.sql and +-- 72/06thread_notifications.sql. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7307, 'event_push_summary_unique_index2', '{}') + ON CONFLICT (update_name) DO NOTHING; +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7307, 'event_push_actions_stream_highlight_index', '{}') + ON CONFLICT (update_name) DO NOTHING; diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index efd92793c0..d42e36cdf1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -22,7 +22,10 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import ( + NotifCounts, + RoomNotifCounts, +) from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition @@ -178,7 +181,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=0), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {} + ), ) self.persist( @@ -191,7 +196,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=1), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {} + ), ) self.persist( @@ -206,7 +213,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=1, unread_count=0, notify_count=2), + RoomNotifCounts( + NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {} + ), ) def test_get_rooms_for_user_with_stream_ordering(self): diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 473c965e19..89f986ac34 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor @@ -20,6 +20,7 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.types import JsonDict from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -133,13 +134,14 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) ) self.assertEqual( - counts, + counts.main_timeline, NotifCounts( notify_count=noitf_count, unread_count=0, highlight_count=highlight_count, ), ) + self.assertEqual(counts.threads, {}) def _create_event(highlight: bool = False) -> str: result = self.helper.send_event( @@ -186,6 +188,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(0, 0) _create_event() + _assert_counts(1, 0) _rotate() _assert_counts(1, 0) @@ -236,6 +239,168 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _rotate() _assert_counts(0, 0) + def test_count_aggregation_threads(self) -> None: + """ + This is essentially the same test as test_count_aggregation, but adds + events to the main timeline and to a thread. + """ + + user_id, token, _, other_token, room_id = self._create_users_and_room() + thread_id: str + + last_event_id: str + + def _assert_counts( + noitf_count: int, + highlight_count: int, + thread_notif_count: int, + thread_highlight_count: int, + ) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, + unread_count=0, + highlight_count=highlight_count, + ), + ) + if thread_notif_count or thread_highlight_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=thread_highlight_count, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + def _create_event( + highlight: bool = False, thread_id: Optional[str] = None + ) -> str: + content: JsonDict = { + "msgtype": "m.text", + "body": user_id if highlight else "msg", + } + if thread_id: + content["m.relates_to"] = { + "rel_type": "m.thread", + "event_id": thread_id, + } + + result = self.helper.send_event( + room_id, + type="m.room.message", + content=content, + tok=other_token, + ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id + + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + + def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=thread_id, + data={}, + ) + ) + + _assert_counts(0, 0, 0, 0) + thread_id = _create_event() + _assert_counts(1, 0, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 0) + + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + _create_event() + _assert_counts(2, 0, 1, 0) + _rotate() + _assert_counts(2, 0, 1, 0) + + event_id = _create_event(thread_id=thread_id) + _assert_counts(2, 0, 2, 0) + _rotate() + _assert_counts(2, 0, 2, 0) + + _create_event() + _create_event(thread_id=thread_id) + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 0, 0) + _rotate() + _assert_counts(1, 1, 0, 0) + + event_id = _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _rotate() + _assert_counts(1, 1, 1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 1, 1, 1) + + _create_event(thread_id=thread_id) + _rotate() + _assert_counts(2, 1, 2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + _rotate() + _assert_counts(0, 0, 0, 0) + def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: self.get_success( -- cgit 1.5.1 From d8663f5e6358f8eaeda9a3f923fae720a140ca4d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 10:21:16 -0400 Subject: Advertise supporting version 1.3 of the Matrix spec. (#14032) Now that all features / changes in 1.3 are supported in Synapse. --- changelog.d/14032.feature | 1 + synapse/rest/client/versions.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/14032.feature (limited to 'synapse') diff --git a/changelog.d/14032.feature b/changelog.d/14032.feature new file mode 100644 index 0000000000..bb221d3ca6 --- /dev/null +++ b/changelog.d/14032.feature @@ -0,0 +1 @@ +Advertise Matrix 1.3 support on `/_matrix/client/versions`. diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 280d306483..18ed313b5c 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -75,6 +75,7 @@ class VersionsRestServlet(RestServlet): "r0.6.1", "v1.1", "v1.2", + "v1.3", ], # as per MSC1497: "unstable_features": { -- cgit 1.5.1 From a7ba457b2b967ca098792d742bc304604b1824b7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 10:46:42 -0400 Subject: Mark events as read using threaded read receipts from MSC3771. (#13877) Applies the proper logic for unthreaded and threaded receipts to either apply to all events in the room or only events in the same thread, respectively. --- changelog.d/13877.feature | 1 + .../storage/databases/main/event_push_actions.py | 277 ++++++++++++++++----- .../73/08thread_receipts_non_null.sql.postgres | 23 ++ .../delta/73/08thread_receipts_non_null.sql.sqlite | 76 ++++++ tests/storage/test_event_push_actions.py | 189 +++++++++++++- 5 files changed, 504 insertions(+), 62 deletions(-) create mode 100644 changelog.d/13877.feature create mode 100644 synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres create mode 100644 synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite (limited to 'synapse') diff --git a/changelog.d/13877.feature b/changelog.d/13877.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13877.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3210e9cca1..7469cd336c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -421,7 +421,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - receipt_stream_ordering: int, + unthreaded_receipt_stream_ordering: int, ) -> RoomNotifCounts: """Get the number of unread messages for a user/room that have happened since the given stream ordering. @@ -430,9 +430,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: The database transaction. room_id: The room ID to get unread counts for. user_id: The user ID to get unread counts for. - receipt_stream_ordering: The stream ordering of the user's latest - receipt in the room. If there are no receipts, the stream ordering - of the user's join event. + unthreaded_receipt_stream_ordering: The stream ordering of the user's latest + unthreaded receipt in the room. If there are no unthreaded receipts, + the stream ordering of the user's join event. Returns: A RoomNotifCounts object containing the notification count, the @@ -448,71 +448,181 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return main_counts return thread_counts.setdefault(thread_id, NotifCounts()) + receipt_types_clause, receipts_args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + ) + # First we pull the counts from the summary table. # - # We check that `last_receipt_stream_ordering` matches the stream - # ordering given. If it doesn't match then a new read receipt has arrived and - # we haven't yet updated the counts in `event_push_summary` to reflect - # that; in that case we simply ignore `event_push_summary` counts - # and do a manual count of all of the rows in the `event_push_actions` table - # for this user/room. + # We check that `last_receipt_stream_ordering` matches the stream ordering of the + # latest receipt for the thread (which may be either the unthreaded read receipt + # or the threaded read receipt). # - # If `last_receipt_stream_ordering` is null then that means it's up to - # date (as the row was written by an older version of Synapse that + # If it doesn't match then a new read receipt has arrived and we haven't yet + # updated the counts in `event_push_summary` to reflect that; in that case we + # simply ignore `event_push_summary` counts. + # + # We then do a manual count of all the rows in the `event_push_actions` table + # for any user/room/thread which did not have a valid summary found. + # + # If `last_receipt_stream_ordering` is null then that means it's up-to-date + # (as the row was written by an older version of Synapse that # updated `event_push_summary` synchronously when persisting a new read # receipt). txn.execute( - """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id + f""" + SELECT notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) WHERE room_id = ? AND user_id = ? AND ( - (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) - OR last_receipt_stream_ordering = ? + (last_receipt_stream_ordering IS NULL AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)) + OR last_receipt_stream_ordering = COALESCE(threaded_receipt_stream_ordering, ?) ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0) """, - (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + room_id, + user_id, + unthreaded_receipt_stream_ordering, + unthreaded_receipt_stream_ordering, + ), ) - max_summary_stream_ordering = 0 - for summary_stream_ordering, notif_count, unread_count, thread_id in txn: + summarised_threads = set() + for notif_count, unread_count, thread_id in txn: + summarised_threads.add(thread_id) counts = _get_thread(thread_id) counts.notify_count += notif_count counts.unread_count += unread_count - # Summaries will only be used if they have not been invalidated by - # a recent receipt; track the latest stream ordering or a valid summary. - # - # Note that since there's only one read receipt in the room per user, - # valid summaries are contiguous. - max_summary_stream_ordering = max( - summary_stream_ordering, max_summary_stream_ordering - ) - # Next we need to count highlights, which aren't summarised - sql = """ + sql = f""" SELECT COUNT(*), thread_id FROM event_push_actions + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) WHERE user_id = ? AND room_id = ? - AND stream_ordering > ? + AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?) AND highlight = 1 GROUP BY thread_id """ - txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) + txn.execute( + sql, + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + user_id, + room_id, + unthreaded_receipt_stream_ordering, + ), + ) for highlight_count, thread_id in txn: _get_thread(thread_id).highlight_count += highlight_count + # For threads which were summarised we need to count actions since the last + # rotation. + thread_id_clause, thread_id_args = make_in_list_sql_clause( + self.database_engine, "thread_id", summarised_threads + ) + + # The (inclusive) event stream ordering that was previously summarised. + rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + unread_counts = self._get_notif_unread_count_for_user_room( + txn, room_id, user_id, rotated_upto_stream_ordering + ) + for notif_count, unread_count, thread_id in unread_counts: + if thread_id not in summarised_threads: + continue + + if thread_id == MAIN_TIMELINE: + counts.notify_count += notif_count + counts.unread_count += unread_count + elif thread_id in thread_counts: + thread_counts[thread_id].notify_count += notif_count + thread_counts[thread_id].unread_count += unread_count + else: + # Previous thread summaries of 0 are discarded above. + # + # TODO If empty summaries are deleted this can be removed. + thread_counts[thread_id] = NotifCounts( + notify_count=notif_count, + unread_count=unread_count, + highlight_count=0, + ) + # Finally we need to count push actions that aren't included in the # summary returned above. This might be due to recent events that haven't # been summarised yet or the summary is out of date due to a recent read # receipt. - start_unread_stream_ordering = max( - receipt_stream_ordering, max_summary_stream_ordering - ) - unread_counts = self._get_notif_unread_count_for_user_room( - txn, room_id, user_id, start_unread_stream_ordering + sql = f""" + SELECT + COUNT(CASE WHEN notif = 1 THEN 1 END), + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) + WHERE user_id = ? + AND room_id = ? + AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?) + AND NOT {thread_id_clause} + GROUP BY thread_id + """ + txn.execute( + sql, + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *thread_id_args, + ), ) - - for notif_count, unread_count, thread_id in unread_counts: + for notif_count, unread_count, thread_id in txn: counts = _get_thread(thread_id) counts.notify_count += notif_count counts.unread_count += unread_count @@ -526,6 +636,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, + thread_id: Optional[str] = None, ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -540,6 +651,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas stream_ordering: The (exclusive) minimum stream ordering to consider. max_stream_ordering: The (inclusive) maximum stream ordering to consider. If this is not given, then no maximum is applied. + thread_id: The thread ID to fetch unread counts for. If this is not provided + then the results for *all* threads is returned. + + Note that if this is provided the resulting list will only have 0 or + 1 tuples in it. Return: A tuple of the notif count and unread count in the given range for @@ -551,10 +667,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): return [] - clause = "" + stream_ordering_clause = "" args = [user_id, room_id, stream_ordering] if max_stream_ordering is not None: - clause = "AND ea.stream_ordering <= ?" + stream_ordering_clause = "AND ea.stream_ordering <= ?" args.append(max_stream_ordering) # If the max stream ordering is less than the min stream ordering, @@ -562,6 +678,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas if max_stream_ordering <= stream_ordering: return [] + # Either limit the results to a specific thread or fetch all threads. + thread_id_clause = "" + if thread_id is not None: + thread_id_clause = "AND thread_id = ?" + args.append(thread_id) + sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), @@ -571,7 +693,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? - {clause} + {stream_ordering_clause} + {thread_id_clause} GROUP BY thread_id """ @@ -1086,7 +1209,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) sql = """ - SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering + SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering FROM receipts_linearized AS r INNER JOIN events AS e USING (event_id) WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? @@ -1107,45 +1230,69 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas limit, ), ) - rows = cast(List[Tuple[int, str, str, int]], txn.fetchall()) + rows = cast(List[Tuple[int, str, str, Optional[str], int]], txn.fetchall()) # For each new read receipt we delete push actions from before it and # recalculate the summary. - for _, room_id, user_id, stream_ordering in rows: + # + # Care must be taken of whether it is a threaded or unthreaded receipt. + for _, room_id, user_id, thread_id, stream_ordering in rows: # Only handle our own read receipts. if not self.hs.is_mine_id(user_id): continue + thread_clause = "" + thread_args: Tuple = () + if thread_id is not None: + thread_clause = "AND thread_id = ?" + thread_args = (thread_id,) + + # For each new read receipt we delete push actions from before it and + # recalculate the summary. txn.execute( - """ + f""" DELETE FROM event_push_actions WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? AND highlight = 0 + {thread_clause} """, - (room_id, user_id, stream_ordering), + (room_id, user_id, stream_ordering, *thread_args), ) # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. unread_counts = self._get_notif_unread_count_for_user_room( - txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering - ) - - # First mark the summary for all threads in the room as cleared. - self.db_pool.simple_update_txn( txn, - table="event_push_summary", - keyvalues={"user_id": user_id, "room_id": room_id}, - updatevalues={ - "notif_count": 0, - "unread_count": 0, - "stream_ordering": old_rotate_stream_ordering, - "last_receipt_stream_ordering": stream_ordering, - }, + room_id, + user_id, + stream_ordering, + old_rotate_stream_ordering, + thread_id, ) + # For an unthreaded receipt, mark the summary for all threads in the room + # as cleared. + if thread_id is None: + self.db_pool.simple_update_txn( + txn, + table="event_push_summary", + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={ + "notif_count": 0, + "unread_count": 0, + "stream_ordering": old_rotate_stream_ordering, + "last_receipt_stream_ordering": stream_ordering, + }, + ) + + # For a threaded receipt, we *always* want to update that receipt, + # event if there are no new notifications in that thread. This ensures + # the stream_ordering & last_receipt_stream_ordering are updated. + elif not unread_counts: + unread_counts = [(0, 0, thread_id)] + # Then any updated threads get their notification count and unread # count updated. self.db_pool.simple_update_many_txn( @@ -1153,8 +1300,16 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas table="event_push_summary", key_names=("room_id", "user_id", "thread_id"), key_values=[(room_id, user_id, row[2]) for row in unread_counts], - value_names=("notif_count", "unread_count"), - value_values=[(row[0], row[1]) for row in unread_counts], + value_names=( + "notif_count", + "unread_count", + "stream_ordering", + "last_receipt_stream_ordering", + ), + value_values=[ + (row[0], row[1], old_rotate_stream_ordering, stream_ordering) + for row in unread_counts + ], ) # We always update `event_push_summary_last_receipt_stream_id` to diff --git a/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres new file mode 100644 index 0000000000..3e0bc9e5eb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Drop constraint on (room_id, receipt_type, user_id). + +-- Rebuild the unique constraint with the thread_id. +ALTER TABLE receipts_linearized + DROP CONSTRAINT receipts_linearized_uniqueness; + +ALTER TABLE receipts_graph + DROP CONSTRAINT receipts_graph_uniqueness; diff --git a/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite new file mode 100644 index 0000000000..e664889fbc --- /dev/null +++ b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite @@ -0,0 +1,76 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Drop constraint on (room_id, receipt_type, user_id). +-- +-- SQLite doesn't support modifying constraints to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE receipts_linearized_new ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + thread_id TEXT, + event_stream_ordering BIGINT, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +CREATE TABLE receipts_graph_new ( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + thread_id TEXT, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +-- Drop the old indexes. +DROP INDEX IF EXISTS receipts_linearized_id; +DROP INDEX IF EXISTS receipts_linearized_room_stream; +DROP INDEX IF EXISTS receipts_linearized_user; + +-- Copy the data. +INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, data) + SELECT stream_id, room_id, receipt_type, user_id, event_id, data + FROM receipts_linearized; +INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data) + SELECT room_id, receipt_type, user_id, event_ids, data + FROM receipts_graph; + +-- Drop the old tables. +DROP TABLE receipts_linearized; +DROP TABLE receipts_graph; + +-- Rename the tables. +ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized; +ALTER TABLE receipts_graph_new RENAME TO receipts_graph; + +-- Create the indices. +CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); +CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); + +-- Re-run background updates from 72/08thread_receipts.sql. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7308, 'receipts_linearized_unique_index', '{}') + ON CONFLICT (update_name) DO NOTHING; +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7308, 'receipts_graph_unique_index', '{}') + ON CONFLICT (update_name) DO NOTHING; diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 89f986ac34..6fa0cafb75 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -16,6 +16,7 @@ from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import MAIN_TIMELINE from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -312,7 +313,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): def _rotate() -> None: self.get_success(self.store._rotate_notifs()) - def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + def _mark_read(event_id: str, thread_id: str = MAIN_TIMELINE) -> None: self.get_success( self.store.insert_receipt( room_id, @@ -348,9 +349,12 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _create_event() _create_event(thread_id=thread_id) _mark_read(event_id) + _assert_counts(1, 0, 3, 0) + _mark_read(event_id, thread_id) _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event() @@ -364,6 +368,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event(True) @@ -389,8 +394,190 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): # Check that sending read receipts at different points results in the # right counts. _mark_read(event_id) + _assert_counts(1, 0, 2, 1) + _mark_read(event_id, thread_id) _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _assert_counts(0, 0, 1, 0) + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + _rotate() + _assert_counts(0, 0, 0, 0) + + def test_count_aggregation_mixed(self) -> None: + """ + This is essentially the same test as test_count_aggregation_threads, but + sends both unthreaded and threaded receipts. + """ + + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) + thread_id: str + + last_event_id: str + + def _assert_counts( + noitf_count: int, + highlight_count: int, + thread_notif_count: int, + thread_highlight_count: int, + ) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, + unread_count=0, + highlight_count=highlight_count, + ), + ) + if thread_notif_count or thread_highlight_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=thread_highlight_count, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + def _create_event( + highlight: bool = False, thread_id: Optional[str] = None + ) -> str: + content: JsonDict = { + "msgtype": "m.text", + "body": user_id if highlight else "msg", + } + if thread_id: + content["m.relates_to"] = { + "rel_type": "m.thread", + "event_id": thread_id, + } + + result = self.helper.send_event( + room_id, + type="m.room.message", + content=content, + tok=other_token, + ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id + + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + + def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=thread_id, + data={}, + ) + ) + + _assert_counts(0, 0, 0, 0) + thread_id = _create_event() + _assert_counts(1, 0, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 0) + + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + _create_event() + _assert_counts(2, 0, 1, 0) + _rotate() + _assert_counts(2, 0, 1, 0) + + event_id = _create_event(thread_id=thread_id) + _assert_counts(2, 0, 2, 0) + _rotate() + _assert_counts(2, 0, 2, 0) + + _create_event() + _create_event(thread_id=thread_id) + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id, MAIN_TIMELINE) + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 0, 0) + _rotate() + _assert_counts(1, 1, 0, 0) + + event_id = _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _rotate() + _assert_counts(1, 1, 1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 1, 1, 1) + + _create_event(thread_id=thread_id) + _rotate() + _assert_counts(2, 1, 2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + _mark_read(event_id, MAIN_TIMELINE) + _assert_counts(1, 0, 1, 0) + _mark_read(last_event_id, MAIN_TIMELINE) + _assert_counts(0, 0, 1, 0) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event(True) -- cgit 1.5.1 From 2b6d41ebd685fb546e52acdbcb0024dfcf5a5db1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 11:36:16 -0400 Subject: Recursively fetch the thread for receipts & notifications. (#13824) Consider an event to be part of a thread if you can follow a chain of relations up to a thread root. Part of MSC3773 & MSC3771. --- changelog.d/13824.feature | 1 + synapse/push/bulk_push_rule_evaluator.py | 5 ++ synapse/rest/client/receipts.py | 22 +++++- synapse/storage/databases/main/relations.py | 36 ++++++++++ tests/storage/test_event_push_actions.py | 100 ++++++++++++++++++++++++++++ 5 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13824.feature (limited to 'synapse') diff --git a/changelog.d/13824.feature b/changelog.d/13824.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13824.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 61d952742d..f8c4dd74f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -286,8 +286,13 @@ class BulkPushRuleEvaluator: relation.parent_id, itertools.chain(*(r.rules() for r in rules_by_user.values())), ) + # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id + else: + # Since the event has not yet been persisted we check whether + # the parent is part of a thread. + thread_id = await self.store.get_thread_id(relation.parent_id) or "main" evaluator = PushRuleEvaluator( _flatten_dict(event), diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index f3ff156abe..287dfdd69e 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReceiptTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() + self._main_store = hs.get_datastores().main self._known_receipt_types = { ReceiptTypes.READ, @@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet): thread_id = body.get("thread_id") if not thread_id or not isinstance(thread_id, str): raise SynapseError( - 400, "thread_id field must be a non-empty string" + 400, + "thread_id field must be a non-empty string", + Codes.INVALID_PARAM, + ) + + if receipt_type == ReceiptTypes.FULLY_READ: + raise SynapseError( + 400, + f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.", + Codes.INVALID_PARAM, + ) + + # Ensure the event ID roughly correlates to the thread ID. + if thread_id != await self._main_store.get_thread_id(event_id): + raise SynapseError( + 400, + f"event_id {event_id} is not related to thread {thread_id}", + Codes.INVALID_PARAM, ) await self.presence_handler.bump_presence_active_time(requester.user) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 898947af95..154385b1e8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_event_relations", _get_event_relations ) + @cached() + async def get_thread_id(self, event_id: str) -> Optional[str]: + """ + Get the thread ID for an event. This considers multi-level relations, + e.g. an annotation to an event which is part of a thread. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. None, otherwise. + """ + # Since event relations form a tree, we should only ever find 0 or 1 + # results from the below query. + sql = """ + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + """ + + def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + txn.execute(sql, (event_id,)) + # TODO Should we ensure there's only a single result here? + row = txn.fetchone() + if row: + return row[0] + return None + + return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 6fa0cafb75..886585e9f2 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -588,6 +588,106 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _rotate() _assert_counts(0, 0, 0, 0) + def test_recursive_thread(self) -> None: + """ + Events related to events in a thread should still be considered part of + that thread. + """ + + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) + + # Update the user's push rules to care about reaction events. + self.get_success( + self.store.add_push_rule( + user_id, + "related_events", + priority_class=5, + conditions=[ + {"kind": "event_match", "key": "type", "pattern": "m.reaction"} + ], + actions=["notify"], + ) + ) + + def _create_event(type: str, content: JsonDict) -> str: + result = self.helper.send_event( + room_id, type=type, content=content, tok=other_token + ) + return result["event_id"] + + def _assert_counts(noitf_count: int, thread_notif_count: int) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, unread_count=0, highlight_count=0 + ), + ) + if thread_notif_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=0, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + # Create a root event. + thread_id = _create_event( + "m.room.message", {"msgtype": "m.text", "body": "msg"} + ) + _assert_counts(1, 0) + + # Reply, creating a thread. + reply_id = _create_event( + "m.room.message", + { + "msgtype": "m.text", + "body": "msg", + "m.relates_to": { + "rel_type": "m.thread", + "event_id": thread_id, + }, + }, + ) + _assert_counts(1, 1) + + # Create an event related to a thread event, this should still appear in + # the thread. + _create_event( + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": "m.annotation", + "event_id": reply_id, + "key": "A", + } + }, + ) + _assert_counts(1, 2) + def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: self.get_success( -- cgit 1.5.1 From 0506bb100e0348ab6e6e213c6624677a83ef9303 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 4 Oct 2022 16:42:59 +0100 Subject: Remove get rooms for user with stream ordering (#13991) By getting the joined rooms before the current token we avoid any reading history to confirm a user *was* in a room. We can then use any membership change events, which we already fetch during sync, to determine the final list of joined room IDs. --- changelog.d/13991.misc | 1 + synapse/handlers/sync.py | 149 ++++++++++++++++++++++------------------------- 2 files changed, 70 insertions(+), 80 deletions(-) create mode 100644 changelog.d/13991.misc (limited to 'synapse') diff --git a/changelog.d/13991.misc b/changelog.d/13991.misc new file mode 100644 index 0000000000..f425fb17b2 --- /dev/null +++ b/changelog.d/13991.misc @@ -0,0 +1 @@ +Optimise queries used to get a users rooms during sync. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 329e89c604..0f684857ca 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1317,6 +1317,19 @@ class SyncHandler: At the end, we transfer data from the `sync_result_builder` to a new `SyncResult` instance to signify that the sync calculation is complete. """ + + user_id = sync_config.user.to_string() + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service: + # We no longer support AS users using /sync directly. + # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() + + # Note: we get the users room list *before* we get the current token, this + # avoids checking back in history if rooms are joined after the token is fetched. + token_before_rooms = self.event_sources.get_current_token() + mutable_joined_room_ids = set(await self.store.get_rooms_for_user(user_id)) + # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -1324,6 +1337,57 @@ class SyncHandler: now_token = self.event_sources.get_current_token() log_kv({"now_token": now_token}) + # Since we fetched the users room list before the token, there's a small window + # during which membership events may have been persisted, so we fetch these now + # and modify the joined room list for any changes between the get_rooms_for_user + # call and the get_current_token call. + membership_change_events = [] + if since_token: + membership_change_events = await self.store.get_membership_changes_for_user( + user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude + ) + + mem_last_change_by_room_id: Dict[str, EventBase] = {} + for event in membership_change_events: + mem_last_change_by_room_id[event.room_id] = event + + # For the latest membership event in each room found, add/remove the room ID + # from the joined room list accordingly. In this case we only care if the + # latest change is JOIN. + + for room_id, event in mem_last_change_by_room_id.items(): + assert event.internal_metadata.stream_ordering + if ( + event.internal_metadata.stream_ordering + < token_before_rooms.room_key.stream + ): + continue + + logger.info( + "User membership change between getting rooms and current token: %s %s %s", + user_id, + event.membership, + room_id, + ) + # User joined a room - we have to then check the room state to ensure we + # respect any bans if there's a race between the join and ban events. + if event.membership == Membership.JOIN: + user_ids_in_room = await self.store.get_users_in_room(room_id) + if user_id in user_ids_in_room: + mutable_joined_room_ids.add(room_id) + # The user left the room, or left and was re-invited but not joined yet + else: + mutable_joined_room_ids.discard(room_id) + + # Now we have our list of joined room IDs, exclude as configured and freeze + joined_room_ids = frozenset( + ( + room_id + for room_id in mutable_joined_room_ids + if room_id not in self.rooms_to_exclude + ) + ) + logger.debug( "Calculating sync response for %r between %s and %s", sync_config.user, @@ -1331,22 +1395,13 @@ class SyncHandler: now_token, ) - user_id = sync_config.user.to_string() - app_service = self.store.get_app_service_by_user_id(user_id) - if app_service: - # We no longer support AS users using /sync directly. - # See https://github.com/matrix-org/matrix-doc/issues/1144 - raise NotImplementedError() - else: - joined_room_ids = await self.get_rooms_for_user_at( - user_id, now_token.room_key - ) sync_result_builder = SyncResultBuilder( sync_config, full_state, since_token=since_token, now_token=now_token, joined_room_ids=joined_room_ids, + membership_change_events=membership_change_events, ) logger.debug("Fetching account data") @@ -1827,19 +1882,12 @@ class SyncHandler: Does not modify the `sync_result_builder`. """ - user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token - now_token = sync_result_builder.now_token + membership_change_events = sync_result_builder.membership_change_events assert since_token - # Get a list of membership change events that have happened to the user - # requesting the sync. - membership_changes = await self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key - ) - - if membership_changes: + if membership_change_events: return True stream_id = since_token.room_key.stream @@ -1878,16 +1926,10 @@ class SyncHandler: since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config + membership_change_events = sync_result_builder.membership_change_events assert since_token - # TODO: we've already called this function and ran this query in - # _have_rooms_changed. We could keep the results in memory to avoid a - # second query, at the cost of more complicated source code. - membership_change_events = await self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude - ) - mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} for event in membership_change_events: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) @@ -2415,60 +2457,6 @@ class SyncHandler: else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) - async def get_rooms_for_user_at( - self, - user_id: str, - room_key: RoomStreamToken, - ) -> FrozenSet[str]: - """Get set of joined rooms for a user at the given stream ordering. - - The stream ordering *must* be recent, otherwise this may throw an - exception if older than a month. (This function is called with the - current token, which should be perfectly fine). - - Args: - user_id - stream_ordering - - ReturnValue: - Set of room_ids the user is in at given stream_ordering. - """ - joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id) - - joined_room_ids = set() - - # We need to check that the stream ordering of the join for each room - # is before the stream_ordering asked for. This might not be the case - # if the user joins a room between us getting the current token and - # calling `get_rooms_for_user_with_stream_ordering`. - # If the membership's stream ordering is after the given stream - # ordering, we need to go and work out if the user was in the room - # before. - # We also need to check whether the room should be excluded from sync - # responses as per the homeserver config. - for joined_room in joined_rooms: - if joined_room.room_id in self.rooms_to_exclude: - continue - - if not joined_room.event_pos.persisted_after(room_key): - joined_room_ids.add(joined_room.room_id) - continue - - logger.info("User joined room after current token: %s", joined_room.room_id) - - extrems = ( - await self.store.get_forward_extremities_for_room_at_stream_ordering( - joined_room.room_id, joined_room.event_pos.stream - ) - ) - user_ids_in_room = await self.state.get_current_user_ids_in_room( - joined_room.room_id, extrems - ) - if user_id in user_ids_in_room: - joined_room_ids.add(joined_room.room_id) - - return frozenset(joined_room_ids) - def _action_has_highlight(actions: List[JsonDict]) -> bool: for action in actions: @@ -2565,6 +2553,7 @@ class SyncResultBuilder: since_token: Optional[StreamToken] now_token: StreamToken joined_room_ids: FrozenSet[str] + membership_change_events: List[EventBase] presence: List[UserPresenceState] = attr.Factory(list) account_data: List[JsonDict] = attr.Factory(list) -- cgit 1.5.1 From dcced5a8d76b94e372aefa7d1f05ec0dbc22ea0d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 12:07:02 -0400 Subject: Use threaded receipts when fetching events for push. (#13878) Update the HTTP and email pushers to consider threaded read receipts when fetching unread events. --- changelog.d/13878.feature | 1 + .../storage/databases/main/event_push_actions.py | 80 +++++++++++++++------- tests/storage/test_event_push_actions.py | 57 ++++++++++----- 3 files changed, 97 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13878.feature (limited to 'synapse') diff --git a/changelog.d/13878.feature b/changelog.d/13878.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13878.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7469cd336c..332e13d1c9 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -119,6 +119,32 @@ DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [ ] +@attr.s(slots=True, auto_attribs=True) +class _RoomReceipt: + """ + HttpPushAction instances include the information used to generate HTTP + requests to a push gateway. + """ + + unthreaded_stream_ordering: int = 0 + # threaded_stream_ordering includes the main pseudo-thread. + threaded_stream_ordering: Dict[str, int] = attr.Factory(dict) + + def is_unread(self, thread_id: str, stream_ordering: int) -> bool: + """Returns True if the stream ordering is unread according to the receipt information.""" + + # Only include push actions with a stream ordering after both the unthreaded + # and threaded receipt. Properly handles a user without any receipts present. + return ( + self.unthreaded_stream_ordering < stream_ordering + and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering + ) + + +# A _RoomReceipt with no receipts in it. +MISSING_ROOM_RECEIPT = _RoomReceipt() + + @attr.s(slots=True, frozen=True, auto_attribs=True) class HttpPushAction: """ @@ -716,7 +742,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_receipts_by_room_txn( self, txn: LoggingTransaction, user_id: str - ) -> Dict[str, int]: + ) -> Dict[str, _RoomReceipt]: """ Generate a map of room ID to the latest stream ordering that has been read by the given user. @@ -726,7 +752,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to fetch receipts for. Returns: - A map of room ID to stream ordering for all rooms the user has a receipt in. + A map including all rooms the user is in with a receipt. It maps + room IDs to _RoomReceipt instances """ receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, @@ -735,20 +762,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) sql = f""" - SELECT room_id, MAX(stream_ordering) + SELECT room_id, thread_id, MAX(stream_ordering) FROM receipts_linearized INNER JOIN events USING (room_id, event_id) WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id + GROUP BY room_id, thread_id """ args.extend((user_id,)) txn.execute(sql, args) - return { - room_id: latest_stream_ordering - for room_id, latest_stream_ordering in txn.fetchall() - } + + result: Dict[str, _RoomReceipt] = {} + for room_id, thread_id, stream_ordering in txn: + room_receipt = result.setdefault(room_id, _RoomReceipt()) + if thread_id is None: + room_receipt.unthreaded_stream_ordering = stream_ordering + else: + room_receipt.threaded_stream_ordering[thread_id] = stream_ordering + + return result async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -781,9 +814,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_push_actions_txn( txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool]]: + ) -> List[Tuple[str, str, str, int, str, bool]]: sql = """ - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight + SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering, + ep.actions, ep.highlight FROM event_push_actions AS ep WHERE ep.user_id = ? @@ -793,7 +827,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ORDER BY ep.stream_ordering ASC LIMIT ? """ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) - return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) + return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall()) push_actions = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn @@ -806,10 +840,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas stream_ordering=stream_ordering, actions=_deserialize_action(actions, highlight), ) - for event_id, room_id, stream_ordering, actions, highlight in push_actions - # Only include push actions with a stream ordering after any receipt, or without any - # receipt present (invited to but never read rooms). - if stream_ordering > receipts_by_room.get(room_id, 0) + for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions + if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread( + thread_id, stream_ordering + ) ] # Now sort it so it's ordered correctly, since currently it will @@ -853,10 +887,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_push_actions_txn( txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool, int]]: + ) -> List[Tuple[str, str, str, int, str, bool, int]]: sql = """ - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight, e.received_ts + SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering, + ep.actions, ep.highlight, e.received_ts FROM event_push_actions AS ep INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -867,7 +901,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ORDER BY ep.stream_ordering DESC LIMIT ? """ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) - return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) + return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall()) push_actions = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn @@ -882,10 +916,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas actions=_deserialize_action(actions, highlight), received_ts=received_ts, ) - for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions - # Only include push actions with a stream ordering after any receipt, or without any - # receipt present (invited to but never read rooms). - if stream_ordering > receipts_by_room.get(room_id, 0) + for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions + if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread( + thread_id, stream_ordering + ) ] # Now sort it so it's ordered correctly, since currently it will diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 886585e9f2..ee48920f84 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -16,7 +16,7 @@ from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import MAIN_TIMELINE +from synapse.api.constants import MAIN_TIMELINE, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -66,16 +66,23 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): user_id, token, _, other_token, room_id = self._create_users_and_room() # Create two events, one of which is a highlight. - self.helper.send_event( + first_event_id = self.helper.send_event( room_id, type="m.room.message", content={"msgtype": "m.text", "body": "msg"}, tok=other_token, - ) - event_id = self.helper.send_event( + )["event_id"] + second_event_id = self.helper.send_event( room_id, type="m.room.message", - content={"msgtype": "m.text", "body": user_id}, + content={ + "msgtype": "m.text", + "body": user_id, + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": first_event_id, + }, + }, tok=other_token, )["event_id"] @@ -95,13 +102,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) self.assertEqual(2, len(email_actions)) - # Send a receipt, which should clear any actions. + # Send a receipt, which should clear the first action. self.get_success( self.store.insert_receipt( room_id, "m.read", user_id=user_id, - event_ids=[event_id], + event_ids=[first_event_id], thread_id=None, data={}, ) @@ -111,6 +118,30 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): user_id, 0, 1000, 20 ) ) + self.assertEqual(1, len(http_actions)) + email_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_email( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual(1, len(email_actions)) + + # Send a thread receipt to clear the thread action. + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[second_event_id], + thread_id=first_event_id, + data={}, + ) + ) + http_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_http( + user_id, 0, 1000, 20 + ) + ) self.assertEqual([], http_actions) email_actions = self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_email( @@ -417,17 +448,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): sends both unthreaded and threaded receipts. """ - # Create a user to receive notifications and send receipts. - user_id = self.register_user("user1235", "pass") - token = self.login("user1235", "pass") - - # And another users to send events. - other_id = self.register_user("other", "pass") - other_token = self.login("other", "pass") - - # Create a room and put both users in it. - room_id = self.helper.create_room_as(user_id, tok=token) - self.helper.join(room_id, other_id, tok=other_token) + user_id, token, _, other_token, room_id = self._create_users_and_room() thread_id: str last_event_id: str -- cgit 1.5.1 From e3d475545467fe587d906d755d8471acbad11266 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 5 Oct 2022 07:56:05 -0400 Subject: Fix backwards compatibility with upcoming threads schema changes. (#14045) Ensure that the upsert will work properly by first updating any existing rows (in the same way that the background update to backfill data works). --- changelog.d/14045.misc | 1 + .../storage/databases/main/event_push_actions.py | 34 +++++++++++++++------- 2 files changed, 24 insertions(+), 11 deletions(-) create mode 100644 changelog.d/14045.misc (limited to 'synapse') diff --git a/changelog.d/14045.misc b/changelog.d/14045.misc new file mode 100644 index 0000000000..0b0dd8f47a --- /dev/null +++ b/changelog.d/14045.misc @@ -0,0 +1 @@ +Ensure Synapse v1.69 works with upcoming database changes in v1.70. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index cdc9ee5a37..c9724d7345 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -1103,19 +1103,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering ) + # First ensure that the existing rows have an updated thread_id field. + self.db_pool.simple_update_txn( + txn, + table="event_push_summary", + keyvalues={"room_id": room_id, "user_id": user_id, "thread_id": None}, + updatevalues={"thread_id": "main"}, + ) + # Replace the previous summary with the new counts. # # TODO(threads): Upsert per-thread instead of setting them all to main. self.db_pool.simple_upsert_txn( txn, table="event_push_summary", - keyvalues={"room_id": room_id, "user_id": user_id}, + keyvalues={"room_id": room_id, "user_id": user_id, "thread_id": "main"}, values={ "notif_count": notif_count, "unread_count": unread_count, "stream_ordering": old_rotate_stream_ordering, "last_receipt_stream_ordering": stream_ordering, - "thread_id": "main", }, ) @@ -1264,20 +1271,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) + # Ensure that any updated threads have an updated thread_id. + self.db_pool.simple_update_many_txn( + txn, + table="event_push_summary", + key_names=("user_id", "room_id", "thread_id"), + key_values=[(user_id, room_id, None) for user_id, room_id in summaries], + value_names=("thread_id",), + value_values=[("main",) for _ in summaries], + ) + # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", - key_names=("user_id", "room_id"), - key_values=[(user_id, room_id) for user_id, room_id in summaries], - value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"), + key_names=("user_id", "room_id", "thread_id"), + key_values=[(user_id, room_id, "main") for user_id, room_id in summaries], + value_names=("notif_count", "unread_count", "stream_ordering"), value_values=[ - ( - summary.notif_count, - summary.unread_count, - summary.stream_ordering, - "main", - ) + (summary.notif_count, summary.unread_count, summary.stream_ordering) for summary in summaries.values() ], ) -- cgit 1.5.1 From 0b037d6c918cb04f86b1fccae9610552de9386d7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 5 Oct 2022 08:49:52 -0400 Subject: Fix handling of public rooms filter with a network tuple. (#14053) Fixes two related bugs: * The handling of `[null]` for a `room_types` filter was incorrect. * The ordering of arguments when providing both a network tuple and room type field was incorrect. --- changelog.d/14053.bugfix | 1 + synapse/storage/databases/main/room.py | 43 ++++++++++++++++++++-------------- tests/rest/client/test_rooms.py | 41 ++++++++++++++++++++++++-------- 3 files changed, 58 insertions(+), 27 deletions(-) create mode 100644 changelog.d/14053.bugfix (limited to 'synapse') diff --git a/changelog.d/14053.bugfix b/changelog.d/14053.bugfix new file mode 100644 index 0000000000..07769f51d0 --- /dev/null +++ b/changelog.d/14053.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.53.0 when querying `/publicRooms` with both a `room_type` filter and a `third_party_instance_id`. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7412bce255..e41c99027a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -207,21 +207,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _construct_room_type_where_clause( self, room_types: Union[List[Union[str, None]], None] - ) -> Tuple[Union[str, None], List[str]]: + ) -> Tuple[Union[str, None], list]: if not room_types: return None, [] - else: - # We use None when we want get rooms without a type - is_null_clause = "" - if None in room_types: - is_null_clause = "OR room_type IS NULL" - room_types = [value for value in room_types if value is not None] + # Since None is used to represent a room without a type, care needs to + # be taken into account when constructing the where clause. + clauses = [] + args: list = [] + + room_types_set = set(room_types) + + # We use None to represent a room without a type. + if None in room_types_set: + clauses.append("room_type IS NULL") + room_types_set.remove(None) + + # If there are other room types, generate the proper clause. + if room_types: list_clause, args = make_in_list_sql_clause( - self.database_engine, "room_type", room_types + self.database_engine, "room_type", room_types_set ) + clauses.append(list_clause) - return f"({list_clause} {is_null_clause})", args + return f"({' OR '.join(clauses)})", args async def count_public_rooms( self, @@ -241,14 +250,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] - room_type_clause, args = self._construct_room_type_where_clause( - search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) - if search_filter - else None - ) - room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" - query_args += args - if network_tuple: if network_tuple.appservice_id: published_sql = """ @@ -268,6 +269,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): UNION SELECT room_id from appservice_room_list """ + room_type_clause, args = self._construct_room_type_where_clause( + search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) + if search_filter + else None + ) + room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" + query_args += args + sql = f""" SELECT COUNT(*) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 5e66b5b26c..3612ebe7b9 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2213,14 +2213,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): ) def make_public_rooms_request( - self, room_types: Union[List[Union[str, None]], None] + self, + room_types: Optional[List[Union[str, None]]], + instance_id: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: - channel = self.make_request( - "POST", - self.url, - {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}, - self.token, - ) + body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}} + if instance_id: + body["third_party_instance_id"] = "test|test" + + channel = self.make_request("POST", self.url, body, self.token) + self.assertEqual(channel.code, 200) + chunk = channel.json_body["chunk"] count = channel.json_body["total_room_count_estimate"] @@ -2230,31 +2233,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: chunk, count = self.make_public_rooms_request(None) - self.assertEqual(count, 2) + # Also check if there's no filter property at all in the body. + channel = self.make_request("POST", self.url, {}, self.token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["chunk"]), 2) + self.assertEqual(channel.json_body["total_room_count_estimate"], 2) + + chunk, count = self.make_public_rooms_request(None, "test|test") + self.assertEqual(count, 0) + def test_returns_only_rooms_based_on_filter(self) -> None: chunk, count = self.make_public_rooms_request([None]) self.assertEqual(count, 1) self.assertEqual(chunk[0].get("room_type", None), None) + chunk, count = self.make_public_rooms_request([None], "test|test") + self.assertEqual(count, 0) + def test_returns_only_space_based_on_filter(self) -> None: chunk, count = self.make_public_rooms_request(["m.space"]) self.assertEqual(count, 1) self.assertEqual(chunk[0].get("room_type", None), "m.space") + chunk, count = self.make_public_rooms_request(["m.space"], "test|test") + self.assertEqual(count, 0) + def test_returns_both_rooms_and_space_based_on_filter(self) -> None: chunk, count = self.make_public_rooms_request(["m.space", None]) - self.assertEqual(count, 2) + chunk, count = self.make_public_rooms_request(["m.space", None], "test|test") + self.assertEqual(count, 0) + def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: chunk, count = self.make_public_rooms_request([]) - self.assertEqual(count, 2) + chunk, count = self.make_public_rooms_request([], "test|test") + self.assertEqual(count, 0) + class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): """Test that we correctly fallback to local filtering if a remote server -- cgit 1.5.1 From 7b7478e8b65cceb9e7362c6c1cb932b569a6f383 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 5 Oct 2022 10:12:48 -0700 Subject: Batch up notifications after event persistence (#14033) --- changelog.d/14033.misc | 1 + synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 25 ++++++------ synapse/notifier.py | 75 ++++++++++++++++++++---------------- synapse/replication/tcp/client.py | 19 ++++----- 5 files changed, 66 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14033.misc (limited to 'synapse') diff --git a/changelog.d/14033.misc b/changelog.d/14033.misc new file mode 100644 index 0000000000..fe42852aa5 --- /dev/null +++ b/changelog.d/14033.misc @@ -0,0 +1 @@ +Don't repeatedly wake up the same users for batched events. \ No newline at end of file diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 778d8869b3..da319943cc 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2240,8 +2240,8 @@ class FederationEventHandler: event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) - await self._notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users + await self._notifier.on_new_room_events( + [(event, event_pos)], max_stream_token, extra_users=extra_users ) if event.type == EventTypes.Member and event.membership == Membership.JOIN: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 00e7645ba5..da1acea275 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1872,6 +1872,7 @@ class EventCreationHandler: events_and_context, backfilled=backfilled ) + events_and_pos = [] for event in persisted_events: if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. @@ -1880,25 +1881,23 @@ class EventCreationHandler: stream_ordering = event.internal_metadata.stream_ordering assert stream_ordering is not None pos = PersistedEventPosition(self._instance_name, stream_ordering) - - async def _notify() -> None: - try: - await self.notifier.on_new_room_event( - event, pos, max_stream_token, extra_users=extra_users - ) - except Exception: - logger.exception( - "Error notifying about new room event %s", - event.event_id, - ) - - run_in_background(_notify) + events_and_pos.append((event, pos)) if event.type == EventTypes.Message: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) + async def _notify() -> None: + try: + await self.notifier.on_new_room_events( + events_and_pos, max_stream_token, extra_users=extra_users + ) + except Exception: + logger.exception("Error notifying about new room events") + + run_in_background(_notify) + return persisted_events[-1] async def _maybe_kick_guest_users( diff --git a/synapse/notifier.py b/synapse/notifier.py index c42bb8266a..26b97cf766 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -294,35 +294,31 @@ class Notifier: """ self._new_join_in_room_callbacks.append(cb) - async def on_new_room_event( + async def on_new_room_events( self, - event: EventBase, - event_pos: PersistedEventPosition, + events_and_pos: List[Tuple[EventBase, PersistedEventPosition]], max_room_stream_token: RoomStreamToken, extra_users: Optional[Collection[UserID]] = None, ) -> None: - """Unwraps event and calls `on_new_room_event_args`.""" - await self.on_new_room_event_args( - event_pos=event_pos, - room_id=event.room_id, - event_id=event.event_id, - event_type=event.type, - state_key=event.get("state_key"), - membership=event.content.get("membership"), - max_room_stream_token=max_room_stream_token, - extra_users=extra_users or [], - ) + """Creates a _PendingRoomEventEntry for each of the listed events and calls + notify_new_room_events with the results.""" + event_entries = [] + for event, pos in events_and_pos: + entry = self.create_pending_room_event_entry( + pos, + extra_users, + event.room_id, + event.type, + event.get("state_key"), + event.content.get("membership"), + ) + event_entries.append((entry, event.event_id)) + await self.notify_new_room_events(event_entries, max_room_stream_token) - async def on_new_room_event_args( + async def notify_new_room_events( self, - room_id: str, - event_id: str, - event_type: str, - state_key: Optional[str], - membership: Optional[str], - event_pos: PersistedEventPosition, + event_entries: List[Tuple[_PendingRoomEventEntry, str]], max_room_stream_token: RoomStreamToken, - extra_users: Optional[Collection[UserID]] = None, ) -> None: """Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -338,22 +334,33 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append( - _PendingRoomEventEntry( - event_pos=event_pos, - extra_users=extra_users or [], - room_id=room_id, - type=event_type, - state_key=state_key, - membership=membership, - ) - ) - self._notify_pending_new_room_events(max_room_stream_token) + for event_entry, event_id in event_entries: + self.pending_new_room_events.append(event_entry) + await self._third_party_rules.on_new_event(event_id) - await self._third_party_rules.on_new_event(event_id) + self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() + def create_pending_room_event_entry( + self, + event_pos: PersistedEventPosition, + extra_users: Optional[Collection[UserID]], + room_id: str, + event_type: str, + state_key: Optional[str], + membership: Optional[str], + ) -> _PendingRoomEventEntry: + """Creates and returns a _PendingRoomEventEntry""" + return _PendingRoomEventEntry( + event_pos=event_pos, + extra_users=extra_users or [], + room_id=room_id, + type=event_type, + state_key=state_key, + membership=membership, + ) + def _notify_pending_new_room_events( self, max_room_stream_token: RoomStreamToken ) -> None: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b2522f98ca..18252a2958 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -210,15 +210,16 @@ class ReplicationDataHandler: max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) - await self.notifier.on_new_room_event_args( - event_pos=event_pos, - max_room_stream_token=max_token, - extra_users=extra_users, - room_id=row.data.room_id, - event_id=row.data.event_id, - event_type=row.data.type, - state_key=row.data.state_key, - membership=row.data.membership, + event_entry = self.notifier.create_pending_room_event_entry( + event_pos, + extra_users, + row.data.room_id, + row.data.type, + row.data.state_key, + row.data.membership, + ) + await self.notifier.notify_new_room_events( + [(event_entry, row.data.event_id)], max_token ) # If this event is a join, make a note of it so we have an accurate -- cgit 1.5.1 From 79c592cec68d66278e3233e2c9472f975942cfec Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 6 Oct 2022 12:22:36 +0200 Subject: Deprecate the `generate_short_term_login_token` method in favor of an async `create_login_token` method in the Module API. (#13842) Signed-off-by: Quentin Gliech Co-authored-by: Brendan Abolivier --- changelog.d/13842.removal | 1 + docs/upgrade.md | 33 +++++++++++++++++++++++++++++++++ synapse/module_api/__init__.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) create mode 100644 changelog.d/13842.removal (limited to 'synapse') diff --git a/changelog.d/13842.removal b/changelog.d/13842.removal new file mode 100644 index 0000000000..cbcff38e91 --- /dev/null +++ b/changelog.d/13842.removal @@ -0,0 +1 @@ +Deprecate the `generate_short_term_login_token` method in favor of an async `create_login_token` method in the Module API. diff --git a/docs/upgrade.md b/docs/upgrade.md index 002ef70059..b81385b191 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -128,6 +128,39 @@ you may specify `enable_legacy_metrics: false` in your homeserver configuration. A list of affected metrics is available on the [Metrics How-to page](https://matrix-org.github.io/synapse/v1.69/metrics-howto.html?highlight=metrics%20deprecated#renaming-of-metrics--deprecation-of-old-names-in-12). +## Deprecation of the `generate_short_term_login_token` module API method + +The following method of the module API has been deprecated, and is scheduled to +be remove in v1.71.0: + +```python +def generate_short_term_login_token( + self, + user_id: str, + duration_in_ms: int = (2 * 60 * 1000), + auth_provider_id: str = "", + auth_provider_session_id: Optional[str] = None, +) -> str: + ... +``` + +It has been replaced by an asynchronous equivalent: + +```python +async def create_login_token( + self, + user_id: str, + duration_in_ms: int = (2 * 60 * 1000), + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, +) -> str: + ... +``` + +Synapse will log a warning when a module uses the deprecated method, to help +administrators find modules using it. + + # Upgrading to v1.68.0 Two changes announced in the upgrade notes for v1.67.0 have now landed in v1.68.0. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b7b2d3b8c5..6a6ae208d1 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -748,6 +748,40 @@ class ModuleApi: ) ) + async def create_login_token( + self, + user_id: str, + duration_in_ms: int = (2 * 60 * 1000), + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, + ) -> str: + """Create a login token suitable for m.login.token authentication + + Added in Synapse v1.69.0. + + Args: + user_id: gives the ID of the user that the token is for + + duration_in_ms: the time that the token will be valid for + + auth_provider_id: the ID of the SSO IdP that the user used to authenticate + to get this token, if any. This is encoded in the token so that + /login can report stats on number of successful logins by IdP. + + auth_provider_session_id: The session ID got during login from the SSO IdP, + if any. + """ + # The deprecated `generate_short_term_login_token` method defaulted to an empty + # string for the `auth_provider_id` because of how the underlying macaroon was + # generated. This will change to a proper NULL-able field when the tokens get + # moved to the database. + return self._hs.get_macaroon_generator().generate_short_term_login_token( + user_id, + auth_provider_id or "", + auth_provider_session_id, + duration_in_ms, + ) + def generate_short_term_login_token( self, user_id: str, @@ -759,6 +793,9 @@ class ModuleApi: Added in Synapse v1.9.0. + This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will + be removed in Synapse 1.71.0. + Args: user_id: gives the ID of the user that the token is for @@ -768,6 +805,11 @@ class ModuleApi: to get this token, if any. This is encoded in the token so that /login can report stats on number of successful logins by IdP. """ + logger.warn( + "A module configured on this server uses ModuleApi.generate_short_term_login_token(), " + "which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in " + "Synapse 1.71.0", + ) return self._hs.get_macaroon_generator().generate_short_term_login_token( user_id, auth_provider_id, -- cgit 1.5.1 From e9a0419c8d28b8e153088073d6b76df6d7ed4ddf Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 6 Oct 2022 14:00:03 +0100 Subject: Fix sending events into rooms with non-integer power levels (#14073) --- changelog.d/14073.misc | 1 + mypy.ini | 3 ++ synapse/push/bulk_push_rule_evaluator.py | 9 +++- tests/push/test_bulk_push_rule_evaluator.py | 74 +++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14073.misc create mode 100644 tests/push/test_bulk_push_rule_evaluator.py (limited to 'synapse') diff --git a/changelog.d/14073.misc b/changelog.d/14073.misc new file mode 100644 index 0000000000..7775500194 --- /dev/null +++ b/changelog.d/14073.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.68.0 where messages could not be sent in rooms with non-integer `notifications` power level. diff --git a/mypy.ini b/mypy.ini index 64f9097206..34b4523e00 100644 --- a/mypy.ini +++ b/mypy.ini @@ -106,6 +106,9 @@ disallow_untyped_defs = False [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True +[mypy-tests.push.test_bulk_push_rule_evaluator] +disallow_untyped_defs = True + [mypy-tests.test_server] disallow_untyped_defs = True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4270438918..998354648f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -289,11 +289,18 @@ class BulkPushRuleEvaluator: if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id + # It's possible that old room versions have non-integer power levels (floats or + # strings). Workaround this by explicitly converting to int. + notification_levels = power_levels.get("notifications", {}) + if not event.room_version.msc3667_int_only_power_levels: + for user_id, level in notification_levels.items(): + notification_levels[user_id] = int(level) + evaluator = PushRuleEvaluator( _flatten_dict(event), room_member_count, sender_power_level, - power_levels.get("notifications", {}), + notification_levels, relations, self._relations_match_enabled, ) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py new file mode 100644 index 0000000000..675d7df2ac --- /dev/null +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -0,0 +1,74 @@ +from unittest.mock import patch + +from synapse.api.room_versions import RoomVersions +from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.types import create_requester + +from tests import unittest + + +class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + register.register_servlets, + ] + + def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None: + """We should convert floats and strings to integers before passing to Rust. + + Reproduces #14060. + + A lack of validation: the gift that keeps on giving. + """ + # Create a new user and room. + alice = self.register_user("alice", "pass") + token = self.login(alice, "pass") + + room_id = self.helper.create_room_as( + alice, room_version=RoomVersions.V9.identifier, tok=token + ) + + # Alter the power levels in that room to include stringy and floaty levels. + # We need to suppress the validation logic or else it will reject these dodgy + # values. (Presumably this validation was not always present.) + event_creation_handler = self.hs.get_event_creation_handler() + requester = create_requester(alice) + with patch("synapse.events.validator.validate_canonicaljson"), patch( + "synapse.events.validator.jsonschema.validate" + ): + self.helper.send_state( + room_id, + "m.room.power_levels", + { + "users": {alice: "100"}, # stringy + "notifications": {"room": 100.0}, # float + }, + token, + state_key="", + ) + + # Create a new message event, and try to evaluate it under the dodgy + # power level event. + event, context = self.get_success( + event_creation_handler.create_event( + requester, + { + "type": "m.room.message", + "room_id": room_id, + "content": { + "msgtype": "m.text", + "body": "helo", + }, + "sender": alice, + }, + ) + ) + + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + # should not raise + self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) -- cgit 1.5.1 From cb20b885cb4bd1648581dd043a184d86fc8c7a00 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 6 Oct 2022 19:17:50 +0100 Subject: Always close _all_ `ijson` coroutines, even if doing so raises Exceptions (#14065) --- changelog.d/14065.misc | 1 + synapse/federation/transport/client.py | 29 ++++++++++++++++++++---- synapse/util/__init__.py | 14 +++++++++++- tests/federation/transport/test_client.py | 37 +++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 5 deletions(-) create mode 100644 changelog.d/14065.misc (limited to 'synapse') diff --git a/changelog.d/14065.misc b/changelog.d/14065.misc new file mode 100644 index 0000000000..98998b0015 --- /dev/null +++ b/changelog.d/14065.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.35.0 where errors parsing a `/send_join` or `/state` response would produce excessive, low-quality Sentry events. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 32074b8ca6..cd39d4d111 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -45,6 +45,7 @@ from synapse.federation.units import Transaction from synapse.http.matrixfederationclient import ByteParser from synapse.http.types import QueryParams from synapse.types import JsonDict +from synapse.util import ExceptionBundle logger = logging.getLogger(__name__) @@ -926,8 +927,7 @@ class SendJoinParser(ByteParser[SendJoinResponse]): return len(data) def finish(self) -> SendJoinResponse: - for c in self._coros: - c.close() + _close_coros(self._coros) if self._response.event_dict: self._response.event = make_event_from_dict( @@ -970,6 +970,27 @@ class _StateParser(ByteParser[StateRequestResponse]): return len(data) def finish(self) -> StateRequestResponse: - for c in self._coros: - c.close() + _close_coros(self._coros) return self._response + + +def _close_coros(coros: Iterable[Generator[None, bytes, None]]) -> None: + """Close each of the given coroutines. + + Always calls .close() on each coroutine, even if doing so raises an exception. + Any exceptions raised are aggregated into an ExceptionBundle. + + :raises ExceptionBundle: if at least one coroutine fails to close. + """ + exceptions = [] + for c in coros: + try: + c.close() + except Exception as e: + exceptions.append(e) + + if exceptions: + # raise from the first exception so that the traceback has slightly more context + raise ExceptionBundle( + f"There were {len(exceptions)} errors closing coroutines", exceptions + ) from exceptions[0] diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index a90f08dd4c..7be9d5f113 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,7 +15,7 @@ import json import logging import typing -from typing import Any, Callable, Dict, Generator, Optional +from typing import Any, Callable, Dict, Generator, Optional, Sequence import attr from frozendict import frozendict @@ -193,3 +193,15 @@ def log_failure( # Version string with git info. Computed here once so that we don't invoke git multiple # times. SYNAPSE_VERSION = get_distribution_version_string("matrix-synapse", __file__) + + +class ExceptionBundle(Exception): + # A poor stand-in for something like Python 3.11's ExceptionGroup. + # (A backport called `exceptiongroup` exists but seems overkill: we just want a + # container type here.) + def __init__(self, message: str, exceptions: Sequence[Exception]): + parts = [message] + for e in exceptions: + parts.append(str(e)) + super().__init__("\n - ".join(parts)) + self.exceptions = exceptions diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index c2320ce133..0926e0583d 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -13,6 +13,7 @@ # limitations under the License. import json +from unittest.mock import Mock from synapse.api.room_versions import RoomVersions from synapse.federation.transport.client import SendJoinParser @@ -94,3 +95,39 @@ class SendJoinParserTestCase(TestCase): # Retrieve and check the parsed SendJoinResponse parsed_response = parser.finish() self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) + + def test_errors_closing_coroutines(self) -> None: + """Check we close all coroutines, even if closing the first raises an Exception. + + We also check that an Exception of some kind is raised, but we don't make any + assertions about its attributes or type. + """ + parser = SendJoinParser(RoomVersions.V1, False) + response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]} + serialisation = json.dumps(response).encode() + + # Mock the coroutines managed by this parser. + # The first one will error when we try to close it. + coro_1 = Mock() + coro_1.close = Mock(side_effect=RuntimeError("Couldn't close coro 1")) + + coro_2 = Mock() + + coro_3 = Mock() + coro_3.close = Mock(side_effect=RuntimeError("Couldn't close coro 3")) + + parser._coros = [coro_1, coro_2, coro_3] + + # Send half of the data to the parser + parser.write(serialisation[: len(serialisation) // 2]) + + # Close the parser. There should be _some_ kind of exception, but it need not + # be that RuntimeError directly. E.g. we might want to raise a wrapper + # encompassing multiple errors from multiple coroutines. + with self.assertRaises(Exception): + parser.finish() + + # In any case, we should have tried to close both coros. + coro_1.close.assert_called() + coro_2.close.assert_called() + coro_3.close.assert_called() -- cgit 1.5.1 From 1fa2e58772620199075a36c237dd83cd989c0e91 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 7 Oct 2022 13:35:44 +0100 Subject: Catch BrokenPipeError from metrics server, and log as a warning (#14072) --- changelog.d/14072.misc | 1 + synapse/metrics/_legacy_exposition.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14072.misc (limited to 'synapse') diff --git a/changelog.d/14072.misc b/changelog.d/14072.misc new file mode 100644 index 0000000000..3070c756d5 --- /dev/null +++ b/changelog.d/14072.misc @@ -0,0 +1 @@ +Don't create noisy Sentry events when a requester drops connection to the metrics server mid-request. diff --git a/synapse/metrics/_legacy_exposition.py b/synapse/metrics/_legacy_exposition.py index 563d8cc2c6..1459f9d224 100644 --- a/synapse/metrics/_legacy_exposition.py +++ b/synapse/metrics/_legacy_exposition.py @@ -20,7 +20,7 @@ Due to the renaming of metrics in prometheus_client 0.4.0, this customised vendoring of the code will emit both the old versions that Synapse dashboards expect, and the newer "best practice" version of the up-to-date official client. """ - +import logging import math import threading from http.server import BaseHTTPRequestHandler, HTTPServer @@ -34,6 +34,7 @@ from prometheus_client.core import Sample from twisted.web.resource import Resource from twisted.web.server import Request +logger = logging.getLogger(__name__) CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" @@ -219,11 +220,16 @@ class MetricsHandler(BaseHTTPRequestHandler): except Exception: self.send_error(500, "error generating metric output") raise - self.send_response(200) - self.send_header("Content-Type", CONTENT_TYPE_LATEST) - self.send_header("Content-Length", str(len(output))) - self.end_headers() - self.wfile.write(output) + try: + self.send_response(200) + self.send_header("Content-Type", CONTENT_TYPE_LATEST) + self.send_header("Content-Length", str(len(output))) + self.end_headers() + self.wfile.write(output) + except BrokenPipeError as e: + logger.warning( + "BrokenPipeError when serving metrics (%s). Did Prometheus restart?", e + ) def log_message(self, format: str, *args: Any) -> None: """Log nothing.""" -- cgit 1.5.1 From 2295095c97f3b4707f30ae8cb4562ebb799f7ac1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 7 Oct 2022 13:54:07 +0100 Subject: Use Pydantic to validate /devices endpoints (#14054) --- changelog.d/14054.feature | 1 + synapse/rest/client/devices.py | 98 ++++++++++++++++++++++-------------------- 2 files changed, 53 insertions(+), 46 deletions(-) create mode 100644 changelog.d/14054.feature (limited to 'synapse') diff --git a/changelog.d/14054.feature b/changelog.d/14054.feature new file mode 100644 index 0000000000..9cf3f7a557 --- /dev/null +++ b/changelog.d/14054.feature @@ -0,0 +1 @@ +Improve validation of request bodies for the [Device Management](https://spec.matrix.org/v1.4/client-server-api/#device-management) and [MSC2697 Device Dehyrdation](https://github.com/matrix-org/matrix-spec-proposals/pull/2697) client-server API endpoints. diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index ed6ce78d47..90828c95c4 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -14,18 +14,21 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple + +from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, - assert_params_in_dict, - parse_json_object_from_request, + parse_and_validate_json_object_from_request, ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns, interactive_auth_handler +from synapse.rest.client.models import AuthenticationData +from synapse.rest.models import RequestBodyModel from synapse.types import JsonDict if TYPE_CHECKING: @@ -80,27 +83,29 @@ class DeleteDevicesRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + class PostBody(RequestBodyModel): + auth: Optional[AuthenticationData] + devices: List[StrictStr] + @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) try: - body = parse_json_object_from_request(request) + body = parse_and_validate_json_object_from_request(request, self.PostBody) except errors.SynapseError as e: if e.errcode == errors.Codes.NOT_JSON: - # DELETE + # TODO: Can/should we remove this fallback now? # deal with older clients which didn't pass a JSON dict # the same as those that pass an empty dict - body = {} + body = self.PostBody.parse_obj({}) else: raise e - assert_params_in_dict(body, ["devices"]) - await self.auth_handler.validate_user_via_ui_auth( requester, request, - body, + body.dict(exclude_unset=True), "remove device(s) from your account", # Users might call this multiple times in a row while cleaning up # devices, allow a single UI auth session to be re-used. @@ -108,7 +113,7 @@ class DeleteDevicesRestServlet(RestServlet): ) await self.device_handler.delete_devices( - requester.user.to_string(), body["devices"] + requester.user.to_string(), body.devices ) return 200, {} @@ -147,6 +152,9 @@ class DeviceRestServlet(RestServlet): return 200, device + class DeleteBody(RequestBodyModel): + auth: Optional[AuthenticationData] + @interactive_auth_handler async def on_DELETE( self, request: SynapseRequest, device_id: str @@ -154,20 +162,21 @@ class DeviceRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) try: - body = parse_json_object_from_request(request) + body = parse_and_validate_json_object_from_request(request, self.DeleteBody) except errors.SynapseError as e: if e.errcode == errors.Codes.NOT_JSON: + # TODO: can/should we remove this fallback now? # deal with older clients which didn't pass a JSON dict # the same as those that pass an empty dict - body = {} + body = self.DeleteBody.parse_obj({}) else: raise await self.auth_handler.validate_user_via_ui_auth( requester, request, - body, + body.dict(exclude_unset=True), "remove a device from your account", # Users might call this multiple times in a row while cleaning up # devices, allow a single UI auth session to be re-used. @@ -179,18 +188,33 @@ class DeviceRestServlet(RestServlet): ) return 200, {} + class PutBody(RequestBodyModel): + display_name: Optional[StrictStr] + async def on_PUT( self, request: SynapseRequest, device_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - body = parse_json_object_from_request(request) + body = parse_and_validate_json_object_from_request(request, self.PutBody) await self.device_handler.update_device( - requester.user.to_string(), device_id, body + requester.user.to_string(), device_id, body.dict() ) return 200, {} +class DehydratedDeviceDataModel(RequestBodyModel): + """JSON blob describing a dehydrated device to be stored. + + Expects other freeform fields. Use .dict() to access them. + """ + + class Config: + extra = Extra.allow + + algorithm: StrictStr + + class DehydratedDeviceServlet(RestServlet): """Retrieve or store a dehydrated device. @@ -246,27 +270,19 @@ class DehydratedDeviceServlet(RestServlet): else: raise errors.NotFoundError("No dehydrated device available") + class PutBody(RequestBodyModel): + device_id: StrictStr + device_data: DehydratedDeviceDataModel + initial_device_display_name: Optional[StrictStr] + async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - submission = parse_json_object_from_request(request) + submission = parse_and_validate_json_object_from_request(request, self.PutBody) requester = await self.auth.get_user_by_req(request) - if "device_data" not in submission: - raise errors.SynapseError( - 400, - "device_data missing", - errcode=errors.Codes.MISSING_PARAM, - ) - elif not isinstance(submission["device_data"], dict): - raise errors.SynapseError( - 400, - "device_data must be an object", - errcode=errors.Codes.INVALID_PARAM, - ) - device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), - submission["device_data"], - submission.get("initial_device_display_name", None), + submission.device_data, + submission.initial_device_display_name, ) return 200, {"device_id": device_id} @@ -300,28 +316,18 @@ class ClaimDehydratedDeviceServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + class PostBody(RequestBodyModel): + device_id: StrictStr + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - submission = parse_json_object_from_request(request) - - if "device_id" not in submission: - raise errors.SynapseError( - 400, - "device_id missing", - errcode=errors.Codes.MISSING_PARAM, - ) - elif not isinstance(submission["device_id"], str): - raise errors.SynapseError( - 400, - "device_id must be a string", - errcode=errors.Codes.INVALID_PARAM, - ) + submission = parse_and_validate_json_object_from_request(request, self.PostBody) result = await self.device_handler.rehydrate_device( requester.user.to_string(), self.auth.get_access_token_from_request(request), - submission["device_id"], + submission.device_id, ) return 200, result -- cgit 1.5.1 From 66a785733458d0b5801097caff53624e202a91b4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Oct 2022 09:26:40 -0400 Subject: Use stable identifiers for MSC3771 & MSC3773. (#14050) These are both part of Matrix 1.4 which has now been released. For now, support both the unstable and stable identifiers. --- changelog.d/13776.feature | 2 +- changelog.d/13824.feature | 2 +- changelog.d/13877.feature | 2 +- changelog.d/13878.feature | 2 +- changelog.d/14050.feature | 1 + synapse/api/filtering.py | 13 +++++++---- synapse/config/experimental.py | 2 -- synapse/handlers/receipts.py | 11 ++++------ synapse/handlers/sync.py | 7 +----- synapse/rest/client/receipts.py | 48 ++++++++++++++++++++--------------------- synapse/rest/client/sync.py | 9 +++++--- synapse/rest/client/versions.py | 2 +- 12 files changed, 49 insertions(+), 52 deletions(-) create mode 100644 changelog.d/14050.feature (limited to 'synapse') diff --git a/changelog.d/13776.feature b/changelog.d/13776.feature index 22bce125ce..5d0ae16e13 100644 --- a/changelog.d/13776.feature +++ b/changelog.d/13776.feature @@ -1 +1 @@ -Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/changelog.d/13824.feature b/changelog.d/13824.feature index d0cb902dff..5d0ae16e13 100644 --- a/changelog.d/13824.feature +++ b/changelog.d/13824.feature @@ -1 +1 @@ -Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/changelog.d/13877.feature b/changelog.d/13877.feature index d0cb902dff..5d0ae16e13 100644 --- a/changelog.d/13877.feature +++ b/changelog.d/13877.feature @@ -1 +1 @@ -Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/changelog.d/13878.feature b/changelog.d/13878.feature index d0cb902dff..5d0ae16e13 100644 --- a/changelog.d/13878.feature +++ b/changelog.d/13878.feature @@ -1 +1 @@ -Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/changelog.d/14050.feature b/changelog.d/14050.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14050.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index c6e44dcf82..cc31cf8cc7 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + "unread_thread_notifications": {"type": "boolean"}, "org.matrix.msc3773.unread_thread_notifications": {"type": "boolean"}, # Include or exclude events with the provided labels. # cf https://github.com/matrix-org/matrix-doc/pull/2326 @@ -308,12 +309,16 @@ class Filter: self.include_redundant_members = filter_json.get( "include_redundant_members", False ) - if hs.config.experimental.msc3773_enabled: - self.unread_thread_notifications: bool = filter_json.get( + self.unread_thread_notifications: bool = filter_json.get( + "unread_thread_notifications", False + ) + if ( + not self.unread_thread_notifications + and hs.config.experimental.msc3773_enabled + ): + self.unread_thread_notifications = filter_json.get( "org.matrix.msc3773.unread_thread_notifications", False ) - else: - self.unread_thread_notifications = False self.types = filter_json.get("types", None) self.not_types = filter_json.get("not_types", []) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 6503ce6e34..c35301207a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3771: Thread read receipts - self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4768a34c07..4a7ec9e426 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -63,8 +63,6 @@ class ReceiptsHandler: self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - self._msc3771_enabled = hs.config.experimental.msc3771_enabled - async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] @@ -96,11 +94,10 @@ class ReceiptsHandler: # Check if these receipts apply to a thread. thread_id = None data = user_values.get("data", {}) - if self._msc3771_enabled and isinstance(data, dict): - thread_id = data.get("thread_id") - # If the thread ID is invalid, consider it missing. - if not isinstance(thread_id, str): - thread_id = None + thread_id = data.get("thread_id") + # If the thread ID is invalid, consider it missing. + if not isinstance(thread_id, str): + thread_id = None receipts.append( ReadReceipt( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0f684857ca..1db5d68021 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -279,8 +279,6 @@ class SyncHandler: self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync - self._msc3773_enabled = hs.config.experimental.msc3773_enabled - async def wait_for_sync_for_user( self, requester: Requester, @@ -2412,10 +2410,7 @@ class SyncHandler: unread_count = notifs.main_timeline.unread_count # Check the sync configuration. - if ( - self._msc3773_enabled - and sync_config.filter_collection.unread_thread_notifications() - ): + if sync_config.filter_collection.unread_thread_notifications(): # And add info for each thread. room_sync.unread_thread_notifications = { thread_id: { diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 287dfdd69e..14dec7ac4e 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -50,7 +50,6 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } - self._msc3771_enabled = hs.config.experimental.msc3771_enabled async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str @@ -67,30 +66,29 @@ class ReceiptRestServlet(RestServlet): # Pull the thread ID, if one exists. thread_id = None - if self._msc3771_enabled: - if "thread_id" in body: - thread_id = body.get("thread_id") - if not thread_id or not isinstance(thread_id, str): - raise SynapseError( - 400, - "thread_id field must be a non-empty string", - Codes.INVALID_PARAM, - ) - - if receipt_type == ReceiptTypes.FULLY_READ: - raise SynapseError( - 400, - f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.", - Codes.INVALID_PARAM, - ) - - # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): - raise SynapseError( - 400, - f"event_id {event_id} is not related to thread {thread_id}", - Codes.INVALID_PARAM, - ) + if "thread_id" in body: + thread_id = body.get("thread_id") + if not thread_id or not isinstance(thread_id, str): + raise SynapseError( + 400, + "thread_id field must be a non-empty string", + Codes.INVALID_PARAM, + ) + + if receipt_type == ReceiptTypes.FULLY_READ: + raise SynapseError( + 400, + f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.", + Codes.INVALID_PARAM, + ) + + # Ensure the event ID roughly correlates to the thread ID. + if thread_id != await self._main_store.get_thread_id(event_id): + raise SynapseError( + 400, + f"event_id {event_id} is not related to thread {thread_id}", + Codes.INVALID_PARAM, + ) await self.presence_handler.bump_presence_active_time(requester.user) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f1c23d68e5..8a16459105 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -100,6 +100,7 @@ class SyncRestServlet(RestServlet): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() self._msc2654_enabled = hs.config.experimental.msc2654_enabled + self._msc3773_enabled = hs.config.experimental.msc3773_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: # This will always be set by the time Twisted calls us. @@ -510,9 +511,11 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications if room.unread_thread_notifications: - result[ - "org.matrix.msc3773.unread_thread_notifications" - ] = room.unread_thread_notifications + result["unread_thread_notifications"] = room.unread_thread_notifications + if self._msc3773_enabled: + result[ + "org.matrix.msc3773.unread_thread_notifications" + ] = room.unread_thread_notifications result["summary"] = room.summary if self._msc2654_enabled: result["org.matrix.msc2654.unread_count"] = room.unread_count diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 18ed313b5c..d1d2e5f7e3 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -105,7 +105,7 @@ class VersionsRestServlet(RestServlet): # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Support for thread read receipts & notification counts. - "org.matrix.msc3771": self.config.experimental.msc3771_enabled, + "org.matrix.msc3771": True, "org.matrix.msc3773": self.config.experimental.msc3773_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, -- cgit 1.5.1 From 00c93d2e7ef5642c9cf900f3fdcfa229e70f843d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Oct 2022 09:29:43 -0400 Subject: Be more lenient in the oEmbed response parsing. (#14089) Attempt to parse any valid information from an oEmbed response (instead of bailing at the first unexpected data). This should allow for more partial oEmbed data to be returned, resulting in better / more URL previews, even if those URL previews are only partial. --- changelog.d/14089.bugfix | 1 + synapse/rest/media/v1/oembed.py | 107 ++++++++++++++++++++----------------- tests/rest/media/v1/test_oembed.py | 103 ++++++++++++++++++++++++++++++++++- 3 files changed, 160 insertions(+), 51 deletions(-) create mode 100644 changelog.d/14089.bugfix (limited to 'synapse') diff --git a/changelog.d/14089.bugfix b/changelog.d/14089.bugfix new file mode 100644 index 0000000000..4a398921bb --- /dev/null +++ b/changelog.d/14089.bugfix @@ -0,0 +1 @@ +Fix a bug where invalid oEmbed fields would cause the entire response to be discarded. Introduced in Synapse 1.18.0. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 2177b46c9e..827afd868d 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -139,65 +139,72 @@ class OEmbedProvider: try: # oEmbed responses *must* be UTF-8 according to the spec. oembed = json_decoder.decode(raw_body.decode("utf-8")) + except ValueError: + return OEmbedResult({}, None, None) - # The version is a required string field, but not always provided, - # or sometimes provided as a float. Be lenient. - oembed_version = oembed.get("version", "1.0") - if oembed_version != "1.0" and oembed_version != 1: - raise RuntimeError(f"Invalid oEmbed version: {oembed_version}") + # The version is a required string field, but not always provided, + # or sometimes provided as a float. Be lenient. + oembed_version = oembed.get("version", "1.0") + if oembed_version != "1.0" and oembed_version != 1: + return OEmbedResult({}, None, None) - # Ensure the cache age is None or an int. - cache_age = oembed.get("cache_age") - if cache_age: - cache_age = int(cache_age) * 1000 - - # The results. - open_graph_response = { - "og:url": url, - } - - title = oembed.get("title") - if title: - open_graph_response["og:title"] = title - - author_name = oembed.get("author_name") + # Attempt to parse the cache age, if possible. + try: + cache_age = int(oembed.get("cache_age")) * 1000 + except (TypeError, ValueError): + # If the cache age cannot be parsed (e.g. wrong type or invalid + # string), ignore it. + cache_age = None - # Use the provider name and as the site. - provider_name = oembed.get("provider_name") - if provider_name: - open_graph_response["og:site_name"] = provider_name + # The oEmbed response converted to Open Graph. + open_graph_response: JsonDict = {"og:url": url} - # If a thumbnail exists, use it. Note that dimensions will be calculated later. - if "thumbnail_url" in oembed: - open_graph_response["og:image"] = oembed["thumbnail_url"] + title = oembed.get("title") + if title and isinstance(title, str): + open_graph_response["og:title"] = title - # Process each type separately. - oembed_type = oembed["type"] - if oembed_type == "rich": - calc_description_and_urls(open_graph_response, oembed["html"]) - - elif oembed_type == "photo": - # If this is a photo, use the full image, not the thumbnail. - open_graph_response["og:image"] = oembed["url"] + author_name = oembed.get("author_name") + if not isinstance(author_name, str): + author_name = None - elif oembed_type == "video": - open_graph_response["og:type"] = "video.other" + # Use the provider name and as the site. + provider_name = oembed.get("provider_name") + if provider_name and isinstance(provider_name, str): + open_graph_response["og:site_name"] = provider_name + + # If a thumbnail exists, use it. Note that dimensions will be calculated later. + thumbnail_url = oembed.get("thumbnail_url") + if thumbnail_url and isinstance(thumbnail_url, str): + open_graph_response["og:image"] = thumbnail_url + + # Process each type separately. + oembed_type = oembed.get("type") + if oembed_type == "rich": + html = oembed.get("html") + if isinstance(html, str): + calc_description_and_urls(open_graph_response, html) + + elif oembed_type == "photo": + # If this is a photo, use the full image, not the thumbnail. + url = oembed.get("url") + if url and isinstance(url, str): + open_graph_response["og:image"] = url + + elif oembed_type == "video": + open_graph_response["og:type"] = "video.other" + html = oembed.get("html") + if html and isinstance(html, str): calc_description_and_urls(open_graph_response, oembed["html"]) - open_graph_response["og:video:width"] = oembed["width"] - open_graph_response["og:video:height"] = oembed["height"] - - elif oembed_type == "link": - open_graph_response["og:type"] = "website" + for size in ("width", "height"): + val = oembed.get(size) + if val is not None and isinstance(val, int): + open_graph_response[f"og:video:{size}"] = val - else: - raise RuntimeError(f"Unknown oEmbed type: {oembed_type}") + elif oembed_type == "link": + open_graph_response["og:type"] = "website" - except Exception as e: - # Trap any exception and let the code follow as usual. - logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) - open_graph_response = {} - author_name = None - cache_age = None + else: + logger.warning("Unknown oEmbed type: %s", oembed_type) return OEmbedResult(open_graph_response, author_name, cache_age) diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py index f38d7225f8..319ae8b1cc 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/rest/media/v1/test_oembed.py @@ -14,6 +14,8 @@ import json +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult @@ -23,8 +25,16 @@ from synapse.util import Clock from tests.unittest import HomeserverTestCase +try: + import lxml +except ImportError: + lxml = None + class OEmbedTests(HomeserverTestCase): + if not lxml: + skip = "url preview feature requires lxml" + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.oembed = OEmbedProvider(hs) @@ -36,7 +46,7 @@ class OEmbedTests(HomeserverTestCase): def test_version(self) -> None: """Accept versions that are similar to 1.0 as a string or int (or missing).""" for version in ("1.0", 1.0, 1): - result = self.parse_response({"version": version, "type": "link"}) + result = self.parse_response({"version": version}) # An empty Open Graph response is an error, ensure the URL is included. self.assertIn("og:url", result.open_graph_result) @@ -49,3 +59,94 @@ class OEmbedTests(HomeserverTestCase): result = self.parse_response({"version": version, "type": "link"}) # An empty Open Graph response is an error, ensure the URL is included. self.assertEqual({}, result.open_graph_result) + + def test_cache_age(self) -> None: + """Ensure a cache-age is parsed properly.""" + # Correct-ish cache ages are allowed. + for cache_age in ("1", 1.0, 1): + result = self.parse_response({"cache_age": cache_age}) + self.assertEqual(result.cache_age, 1000) + + # Invalid cache ages are ignored. + for cache_age in ("invalid", {}): + result = self.parse_response({"cache_age": cache_age}) + self.assertIsNone(result.cache_age) + + # Cache age is optional. + result = self.parse_response({}) + self.assertIsNone(result.cache_age) + + @parameterized.expand( + [ + ("title", "title"), + ("provider_name", "site_name"), + ("thumbnail_url", "image"), + ], + name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}", + ) + def test_property(self, oembed_property: str, open_graph_property: str) -> None: + """Test properties which must be strings.""" + result = self.parse_response({oembed_property: "test"}) + self.assertIn(f"og:{open_graph_property}", result.open_graph_result) + self.assertEqual(result.open_graph_result[f"og:{open_graph_property}"], "test") + + result = self.parse_response({oembed_property: 1}) + self.assertNotIn(f"og:{open_graph_property}", result.open_graph_result) + + def test_author_name(self) -> None: + """Test the author_name property.""" + result = self.parse_response({"author_name": "test"}) + self.assertEqual(result.author_name, "test") + + result = self.parse_response({"author_name": 1}) + self.assertIsNone(result.author_name) + + def test_rich(self) -> None: + """Test a type of rich.""" + result = self.parse_response({"html": "test", "type": "rich"}) + self.assertIn("og:description", result.open_graph_result) + self.assertIn("og:image", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:description"], "test") + self.assertEqual(result.open_graph_result["og:image"], "foo") + + result = self.parse_response({"type": "rich"}) + self.assertNotIn("og:description", result.open_graph_result) + + result = self.parse_response({"html": 1, "type": "rich"}) + self.assertNotIn("og:description", result.open_graph_result) + + def test_photo(self) -> None: + """Test a type of photo.""" + result = self.parse_response({"url": "test", "type": "photo"}) + self.assertIn("og:image", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:image"], "test") + + result = self.parse_response({"type": "photo"}) + self.assertNotIn("og:image", result.open_graph_result) + + result = self.parse_response({"url": 1, "type": "photo"}) + self.assertNotIn("og:image", result.open_graph_result) + + def test_video(self) -> None: + """Test a type of video.""" + result = self.parse_response({"html": "test", "type": "video"}) + self.assertIn("og:type", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:type"], "video.other") + self.assertIn("og:description", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:description"], "test") + + result = self.parse_response({"type": "video"}) + self.assertIn("og:type", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:type"], "video.other") + self.assertNotIn("og:description", result.open_graph_result) + + result = self.parse_response({"url": 1, "type": "video"}) + self.assertIn("og:type", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:type"], "video.other") + self.assertNotIn("og:description", result.open_graph_result) + + def test_link(self) -> None: + """Test type of link.""" + result = self.parse_response({"type": "link"}) + self.assertIn("og:type", result.open_graph_result) + self.assertEqual(result.open_graph_result["og:type"], "website") -- cgit 1.5.1 From f1673866ed8a39d49e2caaa6f4255a3f696bc3b4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 7 Oct 2022 15:15:10 +0100 Subject: Unpin build-system requirements, but impose an upper-bound (#14085) * Revert to prior build-system requirements This reverts #14080. * Use normalised extra name, which poetry-core 1.3 will generate anyway * Changelog * Upper bound build-system requirements * Remove upgrade note; expand changelog entry a little. * Fix typo in build-system comment Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/14085.misc | 1 + pyproject.toml | 11 ++++++++--- synapse/config/repository.py | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14085.misc (limited to 'synapse') diff --git a/changelog.d/14085.misc b/changelog.d/14085.misc new file mode 100644 index 0000000000..2d2df70a64 --- /dev/null +++ b/changelog.d/14085.misc @@ -0,0 +1 @@ +Rename the `url_preview` extra to `url-preview`, for compatability with poetry-core 1.3.0 and [PEP 685](https://peps.python.org/pep-0685/). From-source installations using this extra will need to install using the new name. diff --git a/pyproject.toml b/pyproject.toml index 622d6a9e89..81b2659eb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,7 +219,7 @@ oidc = ["authlib"] # `systemd.journal.JournalHandler`, as is documented in # `contrib/systemd/log_config.yaml`. systemd = ["systemd-python"] -url_preview = ["lxml"] +url-preview = ["lxml"] sentry = ["sentry-sdk"] opentracing = ["jaeger-client", "opentracing"] jwt = ["authlib"] @@ -250,7 +250,7 @@ all = [ "pysaml2", # oidc and jwt "authlib", - # url_preview + # url-preview "lxml", # sentry "sentry-sdk", @@ -307,7 +307,12 @@ twine = "*" towncrier = ">=18.6.0rc1" [build-system] -requires = ["poetry-core==1.2.0", "setuptools_rust==1.5.2"] +# The upper bounds here are defensive, intended to prevent situations like +# #13849 and #14079 where we see buildtime or runtime errors caused by build +# system changes. +# We are happy to raise these upper bounds upon request, +# provided we check that it's safe to do so (i.e. that CI passes). +requires = ["poetry-core>=1.0.0,<=1.3.1", "setuptools_rust>=1.3,<=1.5.2"] build-backend = "poetry.core.masonry.api" diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 1033496bb4..e4759711ed 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -205,7 +205,7 @@ class ContentRepositoryConfig(Config): ) self.url_preview_enabled = config.get("url_preview_enabled", False) if self.url_preview_enabled: - check_requirements("url_preview") + check_requirements("url-preview") proxy_env = getproxies_environment() if "url_preview_ip_range_blacklist" not in config: -- cgit 1.5.1 From dc37b68a25754240243cdca6f521919abfe71db0 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 7 Oct 2022 16:19:59 +0200 Subject: Parse SYNAPSE_ASYNC_IO_REACTOR env variable & log the reactor on startup (#14092) --- changelog.d/14092.misc | 1 + synapse/__init__.py | 26 ++++++++++++-------------- synapse/config/logger.py | 3 +++ 3 files changed, 16 insertions(+), 14 deletions(-) create mode 100644 changelog.d/14092.misc (limited to 'synapse') diff --git a/changelog.d/14092.misc b/changelog.d/14092.misc new file mode 100644 index 0000000000..c48f40cd38 --- /dev/null +++ b/changelog.d/14092.misc @@ -0,0 +1 @@ +Run the integration test suites with the asyncio reactor enabled in CI. diff --git a/synapse/__init__.py b/synapse/__init__.py index 1bed6393bd..fbfd506a43 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -21,6 +21,7 @@ import os import sys from synapse.util.rust import check_rust_lib_up_to_date +from synapse.util.stringutils import strtobool # Check that we're not running on an unsupported Python version. if sys.version_info < (3, 7): @@ -28,25 +29,22 @@ if sys.version_info < (3, 7): sys.exit(1) # Allow using the asyncio reactor via env var. -if bool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", False)): - try: - from incremental import Version +if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")): + from incremental import Version - import twisted + import twisted - # We need a bugfix that is included in Twisted 21.2.0: - # https://twistedmatrix.com/trac/ticket/9787 - if twisted.version < Version("Twisted", 21, 2, 0): - print("Using asyncio reactor requires Twisted>=21.2.0") - sys.exit(1) + # We need a bugfix that is included in Twisted 21.2.0: + # https://twistedmatrix.com/trac/ticket/9787 + if twisted.version < Version("Twisted", 21, 2, 0): + print("Using asyncio reactor requires Twisted>=21.2.0") + sys.exit(1) - import asyncio + import asyncio - from twisted.internet import asyncioreactor + from twisted.internet import asyncioreactor - asyncioreactor.install(asyncio.get_event_loop()) - except ImportError: - pass + asyncioreactor.install(asyncio.get_event_loop()) # Twisted and canonicaljson will fail to import when this file is executed to # get the __version__ during a fresh install. That's OK and subsequent calls to diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 6c1f78f8df..b62b3b9205 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -326,6 +326,8 @@ def setup_logging( logBeginner: The Twisted logBeginner to use. """ + from twisted.internet import reactor + log_config_path = ( config.worker.worker_log_config if use_worker_options @@ -348,3 +350,4 @@ def setup_logging( ) logging.info("Server hostname: %s", config.server.server_name) logging.info("Instance name: %s", hs.get_instance_name()) + logging.info("Twisted reactor: %s", type(reactor).__name__) -- cgit 1.5.1 From ab8047b4bf581d0c343c1e900e8740745668d941 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Oct 2022 11:27:50 -0400 Subject: Apply & bundle edits for non-message events. (#14034) Fixes two related bugs: * No edit information was bundled for events which aren't `m.room.message`. * `m.new_content` was not applied for those events. --- changelog.d/14034.bugfix | 1 + synapse/storage/databases/main/relations.py | 11 ++++------- 2 files changed, 5 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14034.bugfix (limited to 'synapse') diff --git a/changelog.d/14034.bugfix b/changelog.d/14034.bugfix new file mode 100644 index 0000000000..e437ef3a01 --- /dev/null +++ b/changelog.d/14034.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where edits of non-`m.room.message` events would not be correctly bundled or have their new content applied. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 154385b1e8..116abef9de 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -384,12 +384,11 @@ class RelationsWorkerStore(SQLBaseStore): the event will map to None. """ - # We only allow edits for `m.room.message` events that have the same sender - # and event type. We can't assert these things during regular event auth so - # we have to do the checks post hoc. + # We only allow edits for events that have the same sender and event type. + # We can't assert these things during regular event auth so we have to do + # the checks post hoc. - # Fetches latest edit that has the same type and sender as the - # original, and is an `m.room.message`. + # Fetches latest edit that has the same type and sender as the original. if isinstance(self.database_engine, PostgresEngine): # The `DISTINCT ON` clause will pick the *first* row it encounters, # so ordering by origin server ts + event ID desc will ensure we get @@ -405,7 +404,6 @@ class RelationsWorkerStore(SQLBaseStore): WHERE %s AND relation_type = ? - AND edit.type = 'm.room.message' ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC """ else: @@ -424,7 +422,6 @@ class RelationsWorkerStore(SQLBaseStore): WHERE %s AND relation_type = ? - AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts, edit.event_id """ -- cgit 1.5.1 From e03d7c5fd0577df5b62cd34559925c6cfe3e0360 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Oct 2022 12:38:46 -0400 Subject: Remove support for the unstable dir flag on relations. (#14106) From MSC3715, this was unused by clients (and there was no way for clients to know it was supported). Matrix 1.4 defines the stable field. --- changelog.d/14106.removal | 1 + synapse/config/experimental.py | 3 --- synapse/handlers/relations.py | 33 ++++++++++++++--------------- synapse/rest/client/relations.py | 45 +++++++++------------------------------- synapse/streams/config.py | 6 ++++-- 5 files changed, 31 insertions(+), 57 deletions(-) create mode 100644 changelog.d/14106.removal (limited to 'synapse') diff --git a/changelog.d/14106.removal b/changelog.d/14106.removal new file mode 100644 index 0000000000..08fa752897 --- /dev/null +++ b/changelog.d/14106.removal @@ -0,0 +1 @@ +Remove the unstable identifier for [MSC3715](https://github.com/matrix-org/matrix-doc/pull/3715). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c35301207a..e00cb7096c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -100,9 +100,6 @@ class ExperimentalConfig(Config): # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) - # MSC3715: dir param on /relations. - self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False) - # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 63bc6a7aa5..cc5e45c241 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -21,6 +21,7 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -72,13 +73,10 @@ class RelationsHandler: requester: Requester, event_id: str, room_id: str, + pagin_config: PaginationConfig, + include_original_event: bool, relation_type: Optional[str] = None, event_type: Optional[str] = None, - limit: int = 5, - direction: str = "b", - from_token: Optional[StreamToken] = None, - to_token: Optional[StreamToken] = None, - include_original_event: bool = False, ) -> JsonDict: """Get related events of a event, ordered by topological ordering. @@ -88,14 +86,10 @@ class RelationsHandler: requester: The user requesting the relations. event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. + pagin_config: The pagination config rules to apply, if any. + include_original_event: Whether to include the parent event. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. - limit: Only fetch the most recent `limit` events. - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`). - from_token: Fetch rows from the given token, or from the start if None. - to_token: Fetch rows up to the given token, or up to the end if None. - include_original_event: Whether to include the parent event. Returns: The pagination chunk. @@ -114,6 +108,9 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") + # TODO Update pagination config to not allow None limits. + assert pagin_config.limit is not None + # Note that ignored users are not passed into get_relations_for_event # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). @@ -123,10 +120,10 @@ class RelationsHandler: room_id=room_id, relation_type=relation_type, event_type=event_type, - limit=limit, - direction=direction, - from_token=from_token, - to_token=to_token, + limit=pagin_config.limit, + direction=pagin_config.direction, + from_token=pagin_config.from_token, + to_token=pagin_config.to_token, ) events = await self._main_store.get_events_as_list( @@ -162,8 +159,10 @@ class RelationsHandler: if next_token: return_value["next_batch"] = await next_token.to_string(self._main_store) - if from_token: - return_value["prev_batch"] = await from_token.to_string(self._main_store) + if pagin_config.from_token: + return_value["prev_batch"] = await pagin_config.from_token.to_string( + self._main_store + ) return return_value diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 7a25de5c85..b31ce5a0d3 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -16,10 +16,11 @@ import logging from typing import TYPE_CHECKING, Optional, Tuple from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, StreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -41,9 +42,8 @@ class RelationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._store = hs.get_datastores().main self._relations_handler = hs.get_relations_handler() - self._msc3715_enabled = hs.config.experimental.msc3715_enabled async def on_GET( self, @@ -55,49 +55,24 @@ class RelationPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - limit = parse_integer(request, "limit", default=5) - # Fetch the direction parameter, if provided. - # - # TODO Use PaginationConfig.from_request when the unstable parameter is - # no longer needed. - direction = parse_string(request, "dir", allowed_values=["f", "b"]) - if direction is None: - if self._msc3715_enabled: - direction = parse_string( - request, - "org.matrix.msc3715.dir", - default="b", - allowed_values=["f", "b"], - ) - else: - direction = "b" - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - # Return the relations - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) + pagination_config = await PaginationConfig.from_request( + self._store, request, default_limit=5, default_dir="b" + ) # The unstable version of this API returns an extra field for client # compatibility, see https://github.com/matrix-org/synapse/issues/12930. assert request.path is not None include_original_event = request.path.startswith(b"/_matrix/client/unstable/") + # Return the relations result = await self._relations_handler.get_relations( requester=requester, event_id=parent_id, room_id=room_id, + pagin_config=pagination_config, + include_original_event=include_original_event, relation_type=relation_type, event_type=event_type, - limit=limit, - direction=direction, - from_token=from_token, - to_token=to_token, - include_original_event=include_original_event, ) return 200, result diff --git a/synapse/streams/config.py b/synapse/streams/config.py index b52723e2b8..f6f7bf3d8b 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -42,10 +42,12 @@ class PaginationConfig: cls, store: "DataStore", request: SynapseRequest, - raise_invalid_params: bool = True, default_limit: Optional[int] = None, + default_dir: str = "f", ) -> "PaginationConfig": - direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + direction = parse_string( + request, "dir", default=default_dir, allowed_values=["f", "b"] + ) from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") -- cgit 1.5.1 From 1bf2832714abdfc5e10395e8e76aecc591ad265f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 7 Oct 2022 11:39:45 -0500 Subject: Indicate what endpoint came back with a JSON response we were unable to parse (#14097) **Before:** ``` WARNING - POST-11 - Unable to parse JSON: Expecting value: line 1 column 1 (char 0) (b'') ``` **After:** ``` WARNING - POST-11 - Unable to parse JSON from POST /_matrix/client/v3/join/%21ZlmJtelqFroDRJYZaq:hs1?server_name=hs1 response: Expecting value: line 1 column 1 (char 0) (b'') ``` --- It's possible to figure out which endpoint these warnings were coming from before but you had to follow the request ID `POST-11` to the log line that says `Completed request [...]`. Including this key information next to the JSON parsing error makes it much easier to reason whether it matters or not. ``` 2022-09-29T08:23:25.7875506Z synapse_main | 2022-09-29 08:21:10,336 - synapse.http.matrixfederationclient - 299 - INFO - POST-11 - {GET-O-13} [hs1] Completed request: 200 OK in 0.53 secs, got 450 bytes - GET matrix://hs1/_matrix/federation/v1/make_join/%21ohtKoQiXlPePSycXwp%3Ahs1/%40charlie%3Ahs2?ver=1&ver=2&ver=3&ver=4&ver=5&ver=6&ver=org.matrix.msc2176&ver=7&ver=8&ver=9&ver=org.matrix.msc3787&ver=10&ver=org.matrix.msc2716v4 ``` --- As a note, having no `body` is normal for the `/join` endpoint and it can handle it. https://github.com/matrix-org/synapse/blob/0c853e09709d52783efd37060ed9e8f55a4fc704/synapse/rest/client/room.py#L398-L403 Alternatively we could remove these extra logs but they are probably more usually helpful to figure out what went wrong. --- changelog.d/14097.misc | 1 + synapse/http/servlet.py | 9 ++++++++- tests/http/test_servlet.py | 4 +++- 3 files changed, 12 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14097.misc (limited to 'synapse') diff --git a/changelog.d/14097.misc b/changelog.d/14097.misc new file mode 100644 index 0000000000..8392448c4d --- /dev/null +++ b/changelog.d/14097.misc @@ -0,0 +1 @@ +Indicate what endpoint came back with a JSON response we were unable to parse. diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 80acbdcf3c..dead02cd5c 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -35,6 +35,7 @@ from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.http import redact_uri from synapse.http.server import HttpServer from synapse.types import JsonDict, RoomAlias, RoomID from synapse.util import json_decoder @@ -664,7 +665,13 @@ def parse_json_value_from_request( try: content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: - logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes) + logger.warning( + "Unable to parse JSON from %s %s response: %s (%s)", + request.method.decode("ascii", errors="replace"), + redact_uri(request.uri.decode("ascii", errors="replace")), + e, + content_bytes, + ) raise SynapseError( HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON ) diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index 3cbca0f5a3..46166292fe 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -35,11 +35,13 @@ from tests.http.server._base import test_disconnect def make_request(content): """Make an object that acts enough like a request.""" - request = Mock(spec=["content"]) + request = Mock(spec=["method", "uri", "content"]) if isinstance(content, dict): content = json.dumps(content).encode("utf8") + request.method = bytes("STUB_METHOD", "ascii") + request.uri = bytes("/test_stub_uri", "ascii") request.content = BytesIO(content) return request -- cgit 1.5.1 From 422cff7df6df3ac3691829fbce3fbd486f399869 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 11 Oct 2022 14:41:06 +0200 Subject: Fallback if 'approved' isn't included in a registration replication request (#14135) --- changelog.d/14135.bugfix | 1 + synapse/replication/http/register.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14135.bugfix (limited to 'synapse') diff --git a/changelog.d/14135.bugfix b/changelog.d/14135.bugfix new file mode 100644 index 0000000000..6d1d7816e8 --- /dev/null +++ b/changelog.d/14135.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.69.0rc1 which would cause registration replication requests to fail if the worker sending the request is not running Synapse 1.69. diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 61abb529c8..976c283360 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -39,6 +39,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint): self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() + # Default value if the worker that sent the replication request did not include + # an 'approved' property. + if ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ): + self._approval_default = False + else: + self._approval_default = True + @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, @@ -92,6 +102,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint): await self.registration_handler.check_registration_ratelimit(content["address"]) + # Always default admin users to approved (since it means they were created by + # an admin). + approved_default = self._approval_default + if content["admin"]: + approved_default = True + await self.registration_handler.register_with_store( user_id=user_id, password_hash=content["password_hash"], @@ -103,7 +119,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type=content["user_type"], address=content["address"], shadow_banned=content["shadow_banned"], - approved=content["approved"], + approved=content.get("approved", approved_default), ) return 200, {} -- cgit 1.5.1 From a9934d48c193bc963e3d232ed83c5cbfa3e5152d Mon Sep 17 00:00:00 2001 From: Abdullah Osama Date: Tue, 11 Oct 2022 14:42:11 +0200 Subject: Making parse_server_name more consistent (#14007) Fixes #12122 --- changelog.d/14007.misc | 1 + synapse/util/stringutils.py | 4 ++-- tests/http/test_endpoint.py | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14007.misc (limited to 'synapse') diff --git a/changelog.d/14007.misc b/changelog.d/14007.misc new file mode 100644 index 0000000000..3f0f3afe1c --- /dev/null +++ b/changelog.d/14007.misc @@ -0,0 +1 @@ +Make `parse_server_name` consistent in handling invalid server names. \ No newline at end of file diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 27a363d7e5..4961fe9313 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -86,7 +86,7 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: ValueError if the server name could not be parsed. """ try: - if server_name[-1] == "]": + if server_name and server_name[-1] == "]": # ipv6 literal, hopefully return server_name, None @@ -123,7 +123,7 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int] # that nobody is sneaking IP literals in that look like hostnames, etc. # look for ipv6 literals - if host[0] == "[": + if host and host[0] == "[": if host[-1] != "]": raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index c8cc21cadd..a801f002a0 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -25,6 +25,8 @@ class ServerNameTestCase(unittest.TestCase): "[0abc:1def::1234]": ("[0abc:1def::1234]", None), "1.2.3.4:1": ("1.2.3.4", 1), "[0abc:1def::1234]:8080": ("[0abc:1def::1234]", 8080), + ":80": ("", 80), + "": ("", None), } for i, o in test_data.items(): @@ -42,6 +44,7 @@ class ServerNameTestCase(unittest.TestCase): "newline.com\n", ".empty-label.com", "1234:5678:80", # too many colons + ":80", ] for i in test_data: try: -- cgit 1.5.1 From 02086e1da0e3fa3d5002bf2eb7560c043ad47187 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Oct 2022 16:13:32 +0100 Subject: Fix rotating existing notifications in push summary (#14138) Broke by #14045. Fixes #14120. Introduced in v1.69.0rc2. --- changelog.d/14138.bugfix | 1 + synapse/storage/databases/main/event_push_actions.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) create mode 100644 changelog.d/14138.bugfix (limited to 'synapse') diff --git a/changelog.d/14138.bugfix b/changelog.d/14138.bugfix new file mode 100644 index 0000000000..e2a2f3509e --- /dev/null +++ b/changelog.d/14138.bugfix @@ -0,0 +1 @@ +Fix error in background update when rotating existing notifications. Introduced in v1.69.0rc2. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index c9724d7345..87d07f7d9b 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -1104,11 +1104,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) # First ensure that the existing rows have an updated thread_id field. - self.db_pool.simple_update_txn( - txn, - table="event_push_summary", - keyvalues={"room_id": room_id, "user_id": user_id, "thread_id": None}, - updatevalues={"thread_id": "main"}, + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + ("main", room_id, user_id), ) # Replace the previous summary with the new counts. @@ -1272,6 +1274,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) # Ensure that any updated threads have an updated thread_id. + txn.execute_batch( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + [("main", room_id, user_id) for user_id, room_id in summaries], + ) self.db_pool.simple_update_many_txn( txn, table="event_push_summary", -- cgit 1.5.1 From 6136768e766b4b545d1e0e8ee6e18862292509f3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 11 Oct 2022 13:14:57 -0400 Subject: Remove the groups config code. (#14142) This has been unused for a long time, but missed removal in #11584. --- changelog.d/14142.misc | 1 + synapse/config/groups.py | 27 --------------------------- 2 files changed, 1 insertion(+), 27 deletions(-) create mode 100644 changelog.d/14142.misc delete mode 100644 synapse/config/groups.py (limited to 'synapse') diff --git a/changelog.d/14142.misc b/changelog.d/14142.misc new file mode 100644 index 0000000000..3649317013 --- /dev/null +++ b/changelog.d/14142.misc @@ -0,0 +1 @@ +Remove unused configuration code. diff --git a/synapse/config/groups.py b/synapse/config/groups.py deleted file mode 100644 index baa051fdd4..0000000000 --- a/synapse/config/groups.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -from synapse.types import JsonDict - -from ._base import Config - - -class GroupsConfig(Config): - section = "groups" - - def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.enable_group_creation = config.get("enable_group_creation", False) - self.group_creation_prefix = config.get("group_creation_prefix", "") -- cgit 1.5.1 From a86b2f6837f0a067b0a014fbf5140e8773b8da2e Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 11 Oct 2022 11:18:45 -0700 Subject: Fix a bug where redactions were not being sent over federation if we did not have the original event. (#13813) --- changelog.d/13813.bugfix | 1 + synapse/federation/sender/__init__.py | 29 +++++++++++++++++-------- synapse/handlers/appservice.py | 9 +++++--- synapse/storage/databases/main/events_worker.py | 15 +++++++++---- synapse/storage/databases/main/stream.py | 28 +++++++++++------------- tests/handlers/test_appservice.py | 18 +++++++++------ 6 files changed, 62 insertions(+), 38 deletions(-) create mode 100644 changelog.d/13813.bugfix (limited to 'synapse') diff --git a/changelog.d/13813.bugfix b/changelog.d/13813.bugfix new file mode 100644 index 0000000000..23388788ff --- /dev/null +++ b/changelog.d/13813.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where redactions were not being sent over federation if we did not have the original event. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index a6cb3ba58f..774ecd81b6 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -353,21 +353,25 @@ class FederationSender(AbstractFederationSender): last_token = await self.store.get_federation_out_pos("events") ( next_token, - events, event_to_received_ts, - ) = await self.store.get_all_new_events_stream( + ) = await self.store.get_all_new_event_ids_stream( last_token, self._last_poked_id, limit=100 ) + event_ids = event_to_received_ts.keys() + event_entries = await self.store.get_unredacted_events_from_cache_or_db( + event_ids + ) + logger.debug( "Handling %i -> %i: %i events to send (current id %i)", last_token, next_token, - len(events), + len(event_entries), self._last_poked_id, ) - if not events and next_token >= self._last_poked_id: + if not event_entries and next_token >= self._last_poked_id: logger.debug("All events processed") break @@ -508,8 +512,14 @@ class FederationSender(AbstractFederationSender): await handle_event(event) events_by_room: Dict[str, List[EventBase]] = {} - for event in events: - events_by_room.setdefault(event.room_id, []).append(event) + + for event_id in event_ids: + # `event_entries` is unsorted, so we have to iterate over `event_ids` + # to ensure the events are in the right order + event_cache = event_entries.get(event_id) + if event_cache: + event = event_cache.event + events_by_room.setdefault(event.room_id, []).append(event) await make_deferred_yieldable( defer.gatherResults( @@ -524,9 +534,10 @@ class FederationSender(AbstractFederationSender): logger.debug("Successfully handled up to %i", next_token) await self.store.update_federation_out_pos("events", next_token) - if events: + if event_entries: now = self.clock.time_msec() - ts = event_to_received_ts[events[-1].event_id] + last_id = next(reversed(event_ids)) + ts = event_to_received_ts[last_id] assert ts is not None synapse.metrics.event_processing_lag.labels( @@ -536,7 +547,7 @@ class FederationSender(AbstractFederationSender): "federation_sender" ).set(ts) - events_processed_counter.inc(len(events)) + events_processed_counter.inc(len(event_entries)) event_processing_loop_room_count.labels("federation_sender").inc( len(events_by_room) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 203b62e015..66f5b8d108 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -109,10 +109,13 @@ class ApplicationServicesHandler: last_token = await self.store.get_appservice_last_pos() ( upper_bound, - events, event_to_received_ts, - ) = await self.store.get_all_new_events_stream( - last_token, self.current_max, limit=100, get_prev_content=True + ) = await self.store.get_all_new_event_ids_stream( + last_token, self.current_max, limit=100 + ) + + events = await self.store.get_events_as_list( + event_to_received_ts.keys(), get_prev_content=True ) events_by_room: Dict[str, List[EventBase]] = {} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7cdc9fe98f..d4104462b5 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = await self._get_events_from_cache_or_db( + event_entry_map = await self.get_unredacted_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = await self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self.get_unredacted_events_from_cache_or_db( + [redacted_event_id] + ) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore): return events @cancellable - async def _get_events_from_cache_or_db( - self, event_ids: Iterable[str], allow_rejected: bool = False + async def get_unredacted_events_from_cache_or_db( + self, + event_ids: Iterable[str], + allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. + Note that the events pulled by this function will not have any redactions + applied, and no guarantee is made about the ordering of the events returned. + If events are pulled from the database, they will be cached for future lookups. Unknown events are omitted from the response. diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 530f04e149..ffeb2b3683 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1024,28 +1024,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - async def get_all_new_events_stream( - self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False - ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: + async def get_all_new_event_ids_stream( + self, + from_id: int, + current_id: int, + limit: int, + ) -> Tuple[int, Dict[str, Optional[int]]]: """Get all new events - Returns all events with from_id < stream_ordering <= current_id. + Returns all event ids with from_id < stream_ordering <= current_id. Args: from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return - get_prev_content: whether to fetch previous event content Returns: - A tuple of (next_id, events, event_to_received_ts), where `next_id` + A tuple of (next_id, event_to_received_ts), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, the `current_id`). The `event_to_received_ts` is - a dictionary mapping event ID to the event `received_ts`. + a dictionary mapping event ID to the event `received_ts`, sorted by ascending + stream_ordering. """ - def get_all_new_events_stream_txn( + def get_all_new_event_ids_stream_txn( txn: LoggingTransaction, ) -> Tuple[int, Dict[str, Optional[int]]]: sql = ( @@ -1070,15 +1073,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, event_to_received_ts upper_bound, event_to_received_ts = await self.db_pool.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn - ) - - events = await self.get_events_as_list( - event_to_received_ts.keys(), - get_prev_content=get_prev_content, + "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn ) - return upper_bound, events, event_to_received_ts + return upper_bound, event_to_received_ts async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index af24c4984d..7e4570f990 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_all_new_events_stream.side_effect = [ - make_awaitable((0, [], {})), - make_awaitable((1, [event], {event.event_id: 0})), + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, {})), + make_awaitable((1, {event.event_id: 0})), + ] + self.mock_store.get_events_as_list.side_effect = [ + make_awaitable([]), + make_awaitable([event]), ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) @@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_events_stream.side_effect = [ - make_awaitable((0, [event], {event.event_id: 0})), + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, {event.event_id: 0})), ] - + self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])] self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_events_stream.side_effect = [ + self.mock_store.get_all_new_event_ids_stream.side_effect = [ make_awaitable((0, [event], {event.event_id: 0})), ] -- cgit 1.5.1 From 09be8ab5f9d54fa1a577d8b0028abf8acc28f30d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 06:26:39 -0400 Subject: Remove the experimental implementation of MSC3772. (#14094) MSC3772 has been abandoned. --- changelog.d/14094.removal | 1 + rust/src/push/base_rules.rs | 13 ---- rust/src/push/evaluator.rs | 105 +--------------------------- rust/src/push/mod.rs | 44 +++--------- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 2 - synapse/push/bulk_push_rule_evaluator.py | 64 +---------------- synapse/storage/databases/main/cache.py | 3 - synapse/storage/databases/main/events.py | 5 -- synapse/storage/databases/main/push_rule.py | 15 ++-- synapse/storage/databases/main/relations.py | 53 -------------- tests/push/test_push_rule_evaluator.py | 76 +------------------- 12 files changed, 22 insertions(+), 365 deletions(-) create mode 100644 changelog.d/14094.removal (limited to 'synapse') diff --git a/changelog.d/14094.removal b/changelog.d/14094.removal new file mode 100644 index 0000000000..6ef03b1a0f --- /dev/null +++ b/changelog.d/14094.removal @@ -0,0 +1 @@ +Remove the experimental implementation of [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772). diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 2a09cf99ae..63240cacfc 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -257,19 +257,6 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, - PushRule { - rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3772.thread_reply"), - priority_class: 1, - conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { - rel_type: Cow::Borrowed("m.thread"), - event_type_pattern: None, - sender: None, - sender_type: Some(Cow::Borrowed("user_id")), - })]), - actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), - default: true, - default_enabled: true, - }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.message"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index efe88ec76e..0365dd01dc 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - borrow::Cow, - collections::{BTreeMap, BTreeSet}, -}; +use std::collections::BTreeMap; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -49,13 +46,6 @@ pub struct PushRuleEvaluator { /// The `notifications` section of the current power levels in the room. notification_power_levels: BTreeMap, - /// The relations related to the event as a mapping from relation type to - /// set of sender/event type 2-tuples. - relations: BTreeMap>, - - /// Is running "relation" conditions enabled? - relation_match_enabled: bool, - /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, @@ -70,8 +60,6 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - relations: BTreeMap>, - relation_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -83,8 +71,6 @@ impl PushRuleEvaluator { body, room_member_count, notification_power_levels, - relations, - relation_match_enabled, sender_power_level, }) } @@ -203,89 +189,11 @@ impl PushRuleEvaluator { false } } - KnownCondition::RelationMatch { - rel_type, - event_type_pattern, - sender, - sender_type, - } => { - self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)? - } }; Ok(result) } - /// Evaluates a relation condition. - fn match_relations( - &self, - rel_type: &str, - sender: &Option>, - sender_type: &Option>, - user_id: Option<&str>, - event_type_pattern: &Option>, - ) -> Result { - // First check if relation matching is enabled... - if !self.relation_match_enabled { - return Ok(false); - } - - // ... and if there are any relations to match against. - let relations = if let Some(relations) = self.relations.get(rel_type) { - relations - } else { - return Ok(false); - }; - - // Extract the sender pattern from the condition - let sender_pattern = if let Some(sender) = sender { - Some(sender.as_ref()) - } else if let Some(sender_type) = sender_type { - if sender_type == "user_id" { - if let Some(user_id) = user_id { - Some(user_id) - } else { - return Ok(false); - } - } else { - warn!("Unrecognized sender_type: {sender_type}"); - return Ok(false); - } - } else { - None - }; - - let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - for (relation_sender, event_type) in relations { - if let Some(pattern) = &mut sender_compiled_pattern { - if !pattern.is_match(relation_sender)? { - continue; - } - } - - if let Some(pattern) = &mut type_compiled_pattern { - if !pattern.is_match(event_type)? { - continue; - } - } - - return Ok(true); - } - - Ok(false) - } - /// Evaluates a `event_match` condition. fn match_event_match( &self, @@ -359,15 +267,8 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = PushRuleEvaluator::py_new( - flattened_keys, - 10, - Some(0), - BTreeMap::new(), - BTreeMap::new(), - true, - ) - .unwrap(); + let evaluator = + PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 208b9c0d73..0dabfab8b8 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -275,16 +275,6 @@ pub enum KnownCondition { SenderNotificationPermission { key: Cow<'static, str>, }, - #[serde(rename = "org.matrix.msc3772.relation_match")] - RelationMatch { - rel_type: Cow<'static, str>, - #[serde(skip_serializing_if = "Option::is_none", rename = "type")] - event_type_pattern: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender_type: Option>, - }, } impl IntoPy for Condition { @@ -401,21 +391,15 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, - msc3772_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new( - push_rules: PushRules, - enabled_map: BTreeMap, - msc3772_enabled: bool, - ) -> Self { + pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { Self { push_rules, enabled_map, - msc3772_enabled, } } @@ -430,25 +414,13 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules - .iter() - .filter(|rule| { - // Ignore disabled experimental push rules - if !self.msc3772_enabled - && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - { - return false; - } - - true - }) - .map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules.iter().map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 5900e61450..f2a61df660 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,9 +25,7 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__( - self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3772_enabled: bool - ): ... + def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -39,8 +37,6 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - relations: Mapping[str, Set[Tuple[str, str]]], - relation_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e00cb7096c..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3772: A push rule for mutual relations. - self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index eced182fd5..8d94aeaa32 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, - Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -38,7 +35,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator +from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -117,9 +114,6 @@ class BulkPushRuleEvaluator: resizable=False, ) - # Whether to support MSC3772 is supported. - self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled - async def _get_rules_for_event( self, event: EventBase, @@ -200,51 +194,6 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _get_mutual_relations( - self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - parent_id: The event ID which is targeted by relations. - rules: The push rules which will be processed for this event. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - - # If the experimental feature is not enabled, skip fetching relations. - if not self._relations_match_enabled: - return {} - - # Pre-filter to figure out which relation types are interesting. - rel_types = set() - for rule, enabled in rules: - if not enabled: - continue - - for condition in rule.conditions: - if condition["kind"] != "org.matrix.msc3772.relation_match": - continue - - # rel_type is required. - rel_type = condition.get("rel_type") - if rel_type: - rel_types.add(rel_type) - - # If no valid rules were found, no mutual relations. - if not rel_types: - return {} - - # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations(parent_id, rel_types) - @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -276,16 +225,11 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + # Find the event's thread ID. relation = relation_from_event(event) - # If the event does not have a relation, then cannot have any mutual - # relations or thread ID. - relations = {} + # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE if relation: - relations = await self._get_mutual_relations( - relation.parent_id, - itertools.chain(*(r.rules() for r in rules_by_user.values())), - ) # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id @@ -306,8 +250,6 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, - relations, - self._relations_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 3b8ed1f7ee..a9f25a5904 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,9 +259,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3e15827986..060fe71454 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2024,11 +2024,6 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) - self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), - ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8295322b0e..51416b2236 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,7 +29,6 @@ from typing import ( ) from synapse.api.errors import StoreError -from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -63,9 +62,7 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], - enabled_map: Dict[str, bool], - experimental_config: ExperimentalConfig, + rawrules: List[JsonDict], enabled_map: Dict[str, bool] ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -83,9 +80,7 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled - ) + filtered_rules = FilteredPushRules(push_rules, enabled_map) return filtered_rules @@ -165,7 +160,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules(rows, enabled_map) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -224,9 +219,7 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 116abef9de..6b7eec4bf2 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -776,59 +776,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) - - sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} - """ - - def _get_event_relations( - txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result - - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations - ) - @cached() async def get_thread_id(self, event_id: str) -> Optional[str]: """ diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 8804f0e0d3..decf619466 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Set, Tuple, Union +from typing import Dict, Optional, Union import frozendict @@ -38,12 +38,7 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator( - self, - content: JsonDict, - relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, - relations_match_enabled: bool = False, - ) -> PushRuleEvaluator: + def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -63,8 +58,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), - relations or {}, - relations_match_enabled, ) def test_display_name(self) -> None: @@ -299,71 +292,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) - def test_relation_match(self) -> None: - """Test the relation_match push rule kind.""" - - # Check if the experimental feature is disabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}} - ) - - # A push rule evaluator with the experimental rule enabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}}, True - ) - - # Check just relation type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and sender. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@user:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@other:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and event type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "type": "m.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check just sender, this fails since rel_type is required. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "sender": "@user:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check sender glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@*:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check event type glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "event_type": "*.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1 From f9bc5428c46e73ca471b6976865d5ba4168f938d Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 12 Oct 2022 11:36:22 +0100 Subject: Batch up calls to `get_rooms_for_users` (#14109) --- changelog.d/14109.misc | 1 + synapse/storage/databases/main/roommember.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14109.misc (limited to 'synapse') diff --git a/changelog.d/14109.misc b/changelog.d/14109.misc new file mode 100644 index 0000000000..7987c2050f --- /dev/null +++ b/changelog.d/14109.misc @@ -0,0 +1 @@ +Break up calls to fetch rooms for many users. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2337289d88..2ed6ad754f 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -666,7 +666,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cached_method_name="get_rooms_for_user", list_name="user_ids", ) - async def get_rooms_for_users( + async def _get_rooms_for_users( self, user_ids: Collection[str] ) -> Dict[str, FrozenSet[str]]: """A batched version of `get_rooms_for_user`. @@ -697,6 +697,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {key: frozenset(rooms) for key, rooms in user_rooms.items()} + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched wrapper around `_get_rooms_for_users`, to prevent locking + other calls to `get_rooms_for_user` for large user lists. + """ + all_user_rooms: Dict[str, FrozenSet[str]] = {} + + # 250 users is pretty arbitrary but the data can be quite large if users + # are in many rooms. + for user_ids in batch_iter(user_ids, 250): + all_user_rooms.update(await self._get_rooms_for_users(user_ids)) + + return all_user_rooms + @cached(max_entries=10000) async def does_pair_of_users_share_a_room( self, user_id: str, other_user_id: str -- cgit 1.5.1 From c604d2c218a80f169876cf3063817e038063f7b9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 06:46:13 -0400 Subject: Mark /relations endpoint as usable on workers. (#14028) Co-authored-by: Eric Eastwood --- changelog.d/14028.feature | 1 + docker/complement/conf/start_for_complement.sh | 1 + docker/configure_workers_and_start.py | 27 ++++++++++++++++++++++++++ docs/workers.md | 1 + scripts-dev/complement.sh | 7 +++++-- synapse/app/generic_worker.py | 2 ++ 6 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14028.feature (limited to 'synapse') diff --git a/changelog.d/14028.feature b/changelog.d/14028.feature new file mode 100644 index 0000000000..6f5663a0ef --- /dev/null +++ b/changelog.d/14028.feature @@ -0,0 +1 @@ +The `/relations` endpoint can now be used on workers. diff --git a/docker/complement/conf/start_for_complement.sh b/docker/complement/conf/start_for_complement.sh index cc6482f763..bb85d9fed7 100755 --- a/docker/complement/conf/start_for_complement.sh +++ b/docker/complement/conf/start_for_complement.sh @@ -57,6 +57,7 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then federation_reader, \ federation_sender, \ synchrotron, \ + client_reader, \ appservice, \ pusher" diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 51583dc13d..8e7f605b24 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -107,6 +107,33 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "shared_extra_conf": {}, "worker_extra_conf": "", }, + "client_reader": { + "app": "synapse.app.generic_worker", + "listener_resources": ["client"], + "endpoint_patterns": [ + "^/_matrix/client/(api/v1|r0|v3|unstable)/publicRooms$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/joined_members$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$", + "^/_matrix/client/v1/rooms/.*/hierarchy$", + "^/_matrix/client/(v1|unstable)/rooms/.*/relations/", + "^/_matrix/client/(api/v1|r0|v3|unstable)/login$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/account/whoami$", + "^/_matrix/client/versions$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$", + "^/_matrix/client/(r0|v3|unstable)/register$", + "^/_matrix/client/(r0|v3|unstable)/auth/.*/fallback/web$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/messages$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event", + "^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms", + "^/_matrix/client/(api/v1|r0|v3|unstable/.*)/rooms/.*/aliases", + "^/_matrix/client/(api/v1|r0|v3|unstable)/search", + ], + "shared_extra_conf": {}, + "worker_extra_conf": "", + }, "federation_reader": { "app": "synapse.app.generic_worker", "listener_resources": ["federation"], diff --git a/docs/workers.md b/docs/workers.md index 27041ea57c..e8d6cbaf8b 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -203,6 +203,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/v1/rooms/.*/hierarchy$ + ^/_matrix/client/(v1|unstable)/rooms/.*/relations/ ^/_matrix/client/unstable/org.matrix.msc2716/rooms/.*/batch_send$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index eab23f18f1..a7b1e1e3a8 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -126,7 +126,7 @@ export COMPLEMENT_BASE_IMAGE=complement-synapse extra_test_args=() -test_tags="synapse_blacklist,msc2716,msc3030,msc3787" +test_tags="synapse_blacklist,msc3787" # All environment variables starting with PASS_ will be shared. # (The prefix is stripped off before reaching the container.) @@ -158,7 +158,10 @@ else # We only test faster room joins on monoliths, because they are purposefully # being developed without worker support to start with. - test_tags="$test_tags,faster_joins" + # + # The tests for importing historical messages (MSC2716) and jump to date (MSC3030) + # also only pass with monoliths, currently. + test_tags="$test_tags,faster_joins,msc2716,msc3030" fi diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 5e3825fca6..dc49840f73 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -65,6 +65,7 @@ from synapse.rest.client import ( push_rule, read_marker, receipts, + relations, room, room_batch, room_keys, @@ -308,6 +309,7 @@ class GenericWorkerServer(HomeServer): sync.register_servlets(self, resource) events.register_servlets(self, resource) room.register_servlets(self, resource, is_worker=True) + relations.register_servlets(self, resource) room.register_deprecated_servlets(self, resource) initial_sync.register_servlets(self, resource) room_batch.register_servlets(self, resource) -- cgit 1.5.1 From 9c23442ac909afe3d827534b00d52ee182d2f423 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 12 Oct 2022 14:37:20 +0100 Subject: Correct field name for stripped state events when knocking. `knock_state_events` -> `knock_room_state` (#14102) --- changelog.d/14102.bugfix | 1 + synapse/federation/federation_client.py | 2 +- synapse/federation/federation_server.py | 9 ++++++++- synapse/handlers/federation.py | 20 ++++++++++++++++---- 4 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14102.bugfix (limited to 'synapse') diff --git a/changelog.d/14102.bugfix b/changelog.d/14102.bugfix new file mode 100644 index 0000000000..d71e108f7c --- /dev/null +++ b/changelog.d/14102.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.37.0 in which an incorrect key name was used for sending and receiving room metadata when knocking on a room. \ No newline at end of file diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4dca711cd2..b220ab43fc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1294,7 +1294,7 @@ class FederationClient(FederationBase): return resp[1] async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict: - """Attempts to send a knock event to given a list of servers. Iterates + """Attempts to send a knock event to a given list of servers. Iterates through the list until one attempt succeeds. Doing so will cause the remote server to add the event to the graph, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 907940e19e..28097664b4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -824,7 +824,14 @@ class FederationServer(FederationBase): context, self._room_prejoin_state_types ) ) - return {"knock_state_events": stripped_room_state} + return { + "knock_room_state": stripped_room_state, + # Since v1.37, Synapse incorrectly used "knock_state_events" for this field. + # Thus, we also populate a 'knock_state_events' with the same content to + # support old instances. + # See https://github.com/matrix-org/synapse/issues/14088. + "knock_state_events": stripped_room_state, + } async def _on_send_membership_event( self, origin: str, content: JsonDict, membership_type: str, room_id: str diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 986ffed3d5..44e70c6c3c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -781,15 +781,27 @@ class FederationHandler: # Send the signed event back to the room, and potentially receive some # further information about the room in the form of partial state events - stripped_room_state = await self.federation_client.send_knock( - target_hosts, event - ) + knock_response = await self.federation_client.send_knock(target_hosts, event) # Store any stripped room state events in the "unsigned" key of the event. # This is a bit of a hack and is cribbing off of invites. Basically we # store the room state here and retrieve it again when this event appears # in the invitee's sync stream. It is stripped out for all other local users. - event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] + stripped_room_state = ( + knock_response.get("knock_room_state") + # Since v1.37, Synapse incorrectly used "knock_state_events" for this field. + # Thus, we also check for a 'knock_state_events' to support old instances. + # See https://github.com/matrix-org/synapse/issues/14088. + or knock_response.get("knock_state_events") + ) + + if stripped_room_state is None: + raise KeyError( + "Missing 'knock_room_state' (or legacy 'knock_state_events') field in " + "send_knock response" + ) + + event.unsigned["knock_room_state"] = stripped_room_state context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( -- cgit 1.5.1 From 87099b6ea5cb48b03d2007c46af80bc3f0767519 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 12:15:52 -0400 Subject: Return the main timeline for events which are not part of a thread. (#14140) Fixes a bug where threaded receipts could not be sent for the main timeline. --- changelog.d/14140.feature | 1 + synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/storage/databases/main/relations.py | 12 +++++++----- 3 files changed, 9 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14140.feature (limited to 'synapse') diff --git a/changelog.d/14140.feature b/changelog.d/14140.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14140.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8d94aeaa32..a75386f6a0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -236,7 +236,7 @@ class BulkPushRuleEvaluator: else: # Since the event has not yet been persisted we check whether # the parent is part of a thread. - thread_id = await self.store.get_thread_id(relation.parent_id) or "main" + thread_id = await self.store.get_thread_id(relation.parent_id) # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 6b7eec4bf2..e7fbf950e6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -28,7 +28,7 @@ from typing import ( import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import MAIN_TIMELINE, RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause @@ -777,7 +777,7 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_thread_id(self, event_id: str) -> Optional[str]: + async def get_thread_id(self, event_id: str) -> str: """ Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. @@ -787,7 +787,7 @@ class RelationsWorkerStore(SQLBaseStore): Returns: The event ID of the root event in the thread, if this event is part - of a thread. None, otherwise. + of a thread. "main", otherwise. """ # Since event relations form a tree, we should only ever find 0 or 1 # results from the below query. @@ -802,13 +802,15 @@ class RelationsWorkerStore(SQLBaseStore): ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; """ - def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] - return None + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) -- cgit 1.5.1 From e6e876b9b158f47811b6dfedd8783f658ce960a4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 12:18:34 -0400 Subject: Return the thread ID properly down sync. (#14159) A receipt's thread ID, if one exists, should be added to the body of a receipt. --- changelog.d/14159.feature | 1 + synapse/storage/databases/main/receipts.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/14159.feature (limited to 'synapse') diff --git a/changelog.d/14159.feature b/changelog.d/14159.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14159.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 246f78ac1f..b04026c21b 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -416,6 +416,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) + if row["thread_id"]: + receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] receipt_type[row["user_id"]] = db_to_json(row["data"]) -- cgit 1.5.1 From b6baa46db078c3ef9e6c5751bccb8d2e1c5c5402 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 12 Oct 2022 11:01:00 -0700 Subject: Fix a bug where the joined hosts for a given event were not being properly cached (#14125) --- changelog.d/14125.bugfix | 1 + synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 91 +++++++++++++++++++----------------- 3 files changed, 51 insertions(+), 45 deletions(-) create mode 100644 changelog.d/14125.bugfix (limited to 'synapse') diff --git a/changelog.d/14125.bugfix b/changelog.d/14125.bugfix new file mode 100644 index 0000000000..852f00ebb9 --- /dev/null +++ b/changelog.d/14125.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.69.0rc1 where the joined hosts for a given event were not being properly cached. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index da319943cc..f382961099 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -414,7 +414,9 @@ class FederationEventHandler: # First, precalculate the joined hosts so that the federation sender doesn't # need to. - await self._event_creation_handler.cache_joined_hosts_for_event(event, context) + await self._event_creation_handler.cache_joined_hosts_for_events( + [(event, context)] + ) await self._check_for_soft_fail(event, context=context, origin=origin) await self._run_push_actions_and_persist_event(event, context) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da1acea275..4e55ebba0b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1390,7 +1390,7 @@ class EventCreationHandler: extra_users=extra_users, ), run_in_background( - self.cache_joined_hosts_for_event, event, context + self.cache_joined_hosts_for_events, events_and_context ).addErrback( log_failure, "cache_joined_hosts_for_event failed" ), @@ -1491,62 +1491,65 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise - async def cache_joined_hosts_for_event( - self, event: EventBase, context: EventContext + async def cache_joined_hosts_for_events( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Precalculate the joined hosts at the event, when using Redis, so that + """Precalculate the joined hosts at each of the given events, when using Redis, so that external federation senders don't have to recalculate it themselves. """ - if not self._external_cache.is_enabled(): - return - - # If external cache is enabled we should always have this. - assert self._external_cache_joined_hosts_updates is not None + for event, _ in events_and_context: + if not self._external_cache.is_enabled(): + return - # We actually store two mappings, event ID -> prev state group, - # state group -> joined hosts, which is much more space efficient - # than event ID -> joined hosts. - # - # Note: We have to cache event ID -> prev state group, as we don't - # store that in the DB. - # - # Note: We set the state group -> joined hosts cache if it hasn't been - # set for a while, so that the expiry time is reset. + # If external cache is enabled we should always have this. + assert self._external_cache_joined_hosts_updates is not None - state_entry = await self.state.resolve_state_groups_for_events( - event.room_id, event_ids=event.prev_event_ids() - ) + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We set the state group -> joined hosts cache if it hasn't been + # set for a while, so that the expiry time is reset. - if state_entry.state_group: - await self._external_cache.set( - "event_to_prev_state_group", - event.event_id, - state_entry.state_group, - expiry_ms=60 * 60 * 1000, + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() ) - if state_entry.state_group in self._external_cache_joined_hosts_updates: - return + if state_entry.state_group: + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) - state = await state_entry.get_state( - self._storage_controllers.state, StateFilter.all() - ) - with opentracing.start_active_span("get_joined_hosts"): - joined_hosts = await self.store.get_joined_hosts( - event.room_id, state, state_entry + if state_entry.state_group in self._external_cache_joined_hosts_updates: + return + + state = await state_entry.get_state( + self._storage_controllers.state, StateFilter.all() ) + with opentracing.start_active_span("get_joined_hosts"): + joined_hosts = await self.store.get_joined_hosts( + event.room_id, state, state_entry + ) - # Note that the expiry times must be larger than the expiry time in - # _external_cache_joined_hosts_updates. - await self._external_cache.set( - "get_joined_hosts", - str(state_entry.state_group), - list(joined_hosts), - expiry_ms=60 * 60 * 1000, - ) + # Note that the expiry times must be larger than the expiry time in + # _external_cache_joined_hosts_updates. + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) - self._external_cache_joined_hosts_updates[state_entry.state_group] = None + self._external_cache_joined_hosts_updates[ + state_entry.state_group + ] = None async def _validate_canonical_alias( self, -- cgit 1.5.1 From 3bbe532abb7bfc41467597731ac1a18c0331f539 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 08:02:11 -0400 Subject: Add an API for listing threads in a room. (#13394) Implement the /threads endpoint from MSC3856. This is currently unstable and behind an experimental configuration flag. It includes a background update to backfill data, results from the /threads endpoint will be partial until that finishes. --- changelog.d/13394.feature | 1 + synapse/_scripts/synapse_port_db.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/relations.py | 86 ++++++++++- synapse/rest/client/relations.py | 50 ++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events.py | 38 ++++- synapse/storage/databases/main/relations.py | 166 ++++++++++++++++++++- .../schema/main/delta/73/09threads_table.sql | 30 ++++ tests/rest/client/test_relations.py | 151 +++++++++++++++++++ 10 files changed, 522 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13394.feature create mode 100644 synapse/storage/schema/main/delta/73/09threads_table.sql (limited to 'synapse') diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature new file mode 100644 index 0000000000..68de079cf3 --- /dev/null +++ b/changelog.d/13394.feature @@ -0,0 +1 @@ +Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 5fa599e70e..d850e54e17 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore @@ -206,6 +207,7 @@ class Store( PusherWorkerStore, PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, + RelationsWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..1860006536 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -101,6 +101,9 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) + # MSC3856: Threads list API + self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) + # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index cc5e45c241..1fdd7a10bc 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple @@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -32,6 +33,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ThreadsListInclude(str, enum.Enum): + """Valid values for the 'include' flag of /threads.""" + + all = "all" + participated = "participated" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: # The latest event in the thread. @@ -482,3 +490,79 @@ class RelationsHandler: results.setdefault(event_id, BundledAggregations()).replace = edit return results + + async def get_threads( + self, + requester: Requester, + room_id: str, + include: ThreadsListInclude, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + Args: + requester: The user requesting the relations. + room_id: The room the event belongs to. + include: One of "all" or "participated" to indicate which threads should + be returned. + limit: Only fetch the most recent `limit` events. + from_token: Fetch rows from the given token, or from the start if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) + + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). + thread_roots, next_batch = await self._main_store.get_threads( + room_id=room_id, limit=limit, from_token=from_token + ) + + events = await self._main_store.get_events_as_list(thread_roots) + + if include == ThreadsListInclude.participated: + # Pre-seed thread participation with whether the requester sent the event. + participated = {event.event_id: event.sender == user_id for event in events} + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [eid for eid, p in participated.items() if not p], + user_id, + ) + ) + + # Limit the returned threads to those the user has participated in. + events = [event for event in events if participated[event.event_id]] + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + + now = self._clock.time_msec() + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value: JsonDict = {"chunk": serialized_events} + + if next_batch: + return_value["next_batch"] = str(next_batch) + + return return_value diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b31ce5a0d3..d1aa1947a5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -13,12 +13,15 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.storage.databases.main.relations import ThreadsNextBatch from synapse.streams.config import PaginationConfig from synapse.types import JsonDict @@ -78,5 +81,50 @@ class RelationPaginationServlet(RestServlet): return 200, result +class ThreadsServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P[^/]*)/threads" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + include = parse_string( + request, + "include", + default=ThreadsListInclude.all.value, + allowed_values=[v.value for v in ThreadsListInclude], + ) + + # Return the relations + from_token = None + if from_token_str: + from_token = ThreadsNextBatch.from_string(from_token_str) + + result = await self._relations_handler.get_threads( + requester=requester, + room_id=room_id, + include=ThreadsListInclude(include), + limit=limit, + from_token=from_token, + ) + + return 200, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + if hs.config.experimental.msc3856_enabled: + ThreadsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a9f25a5904..0ce3156c9c 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 060fe71454..6698cbf664 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -1616,7 +1616,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1866,6 +1866,34 @@ class PersistEventsStore: }, ) + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ + + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) + def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -1989,13 +2017,14 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ @@ -2024,6 +2053,9 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_threads, (room_id,) + ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index e7fbf950e6..ac9b96ab44 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,6 +14,7 @@ import logging from typing import ( + TYPE_CHECKING, Collection, Dict, FrozenSet, @@ -29,17 +30,46 @@ from typing import ( import attr from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadsNextBatch: + topological_ordering: int + stream_ordering: int + + def __str__(self) -> str: + return f"{self.topological_ordering}_{self.stream_ordering}" + + @classmethod + def from_string(cls, string: str) -> "ThreadsNextBatch": + """ + Creates a ThreadsNextBatch from its textual representation. + """ + try: + keys = (int(s) for s in string.split("_")) + return cls(*keys) + except Exception: + raise SynapseError(400, "Invalid threads token") + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _RelatedEvent: """ @@ -56,6 +86,76 @@ class _RelatedEvent: class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "threads_backfill", self._backfill_threads + ) + + async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int: + """Backfill the threads table.""" + + def threads_backfill_txn(txn: LoggingTransaction) -> int: + last_thread_id = progress.get("last_thread_id", "") + + # Get the latest event in each thread by topo ordering / stream ordering. + # + # Note that the MAX(event_id) is needed to abide by the rules of group by, + # but doesn't actually do anything since there should only be a single event + # ID per topo/stream ordering pair. + sql = f""" + SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id > ? AND + relation_type = '{RelationTypes.THREAD}' + GROUP BY room_id, relates_to_id + ORDER BY relates_to_id + LIMIT ? + """ + txn.execute(sql, (last_thread_id, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + return 0 + + # Insert the rows into the threads table. If a matching thread already exists, + # assume it is from a newer event. + sql = """ + INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id) + VALUES %s + ON CONFLICT (room_id, thread_id) + DO NOTHING + """ + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows) + + # Mark the progress. + self.db_pool.updates._background_update_progress_txn( + txn, "threads_backfill", {"last_thread_id": rows[-1][1]} + ) + + return txn.rowcount + + result = await self.db_pool.runInteraction( + "threads_backfill", threads_backfill_txn + ) + + if not result: + await self.db_pool.updates._end_background_update("threads_backfill") + + return result + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, @@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. + + Args: + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from a previous next_batch, or from the start if None. + + Returns: + A tuple of: + A list of thread root event IDs. + + The next_batch, if one exists. + """ + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is equal / before the last + # thread's topo ordering and earlier in stream ordering. + pagination_clause = "" + pagination_args: tuple = () + if from_token: + pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?" + pagination_args = ( + from_token.topological_ordering, + from_token.stream_ordering, + ) + + sql = f""" + SELECT thread_id, topological_ordering, stream_ordering + FROM threads + WHERE + room_id = ? + {pagination_clause} + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT ? + """ + + def _get_threads_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + txn.execute(sql, (room_id, *pagination_args, limit + 1)) + + rows = cast(List[Tuple[str, int, int]], txn.fetchall()) + thread_ids = [r[0] for r in rows] + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) + @cached() async def get_thread_id(self, event_id: str) -> str: """ diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql new file mode 100644 index 0000000000..aa7c5e9a2e --- /dev/null +++ b/synapse/storage/schema/main/delta/73/09threads_table.sql @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE threads ( + room_id TEXT NOT NULL, + -- The event ID of the root event in the thread. + thread_id TEXT NOT NULL, + -- The latest event ID and corresponding topo / stream ordering. + latest_event_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id) +); + +CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7309, 'threads_backfill', '{}'); diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 988cdb746d..d595295e2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1707,3 +1707,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations[RelationTypes.THREAD]["latest_event"]["event_id"], related_event_id, ) + + +class ThreadsTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_threads(self) -> None: + """Create threads and ensure the ordering is due to their latest event.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1]) + + # Update the first thread, the ordering should swap. + self._send_relation(RelationTypes.THREAD, "m.room.test") + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1, thread_2]) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_pagination(self) -> None: + """Create threads and paginate through them.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2]) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + next_batch = channel.json_body.get("next_batch") + self.assertIsInstance(next_batch, str, channel.json_body) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) + + self.assertNotIn("next_batch", channel.json_body, channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_include(self) -> None: + """Filtering threads to all or participated in should work.""" + # Thread 1 has the user as the root event. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 has the user replying. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Thread 3 has the user not participating in. + res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token) + thread_3 = res["event_id"] + self._send_relation( + RelationTypes.THREAD, + "m.room.test", + access_token=self.user2_token, + parent_id=thread_3, + ) + + # All threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual( + thread_roots, [thread_3, thread_2, thread_1], channel.json_body + ) + + # Only participated threads. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_ignored_user(self) -> None: + """Events from ignored users should be ignored.""" + # Thread 1 has a reply from an ignored user. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 is created by an ignored user. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Ignore user2. + self.get_success( + self.store.add_account_data_for_user( + self.user_id, + AccountDataTypes.IGNORED_USER_LIST, + {"ignored_users": {self.user2_id: {}}}, + ) + ) + + # Only thread 1 is returned. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) -- cgit 1.5.1 From 7d59a515bb97dc4f8253aa9a5a560221a0ef4702 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 12:15:41 -0400 Subject: Properly return the thread ID down sync. (#14159) Fix a broken conflict in e6e876b9b158f47811b6dfedd8783f658ce960a4, by not stomping over a field right after creating it. --- synapse/storage/databases/main/receipts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index b04026c21b..dc6989527e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -416,10 +416,10 @@ class ReceiptsWorkerStore(SQLBaseStore): # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) - if row["thread_id"]: - receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] receipt_type[row["user_id"]] = db_to_json(row["data"]) + if row["thread_id"]: + receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] results = { room_id: [results[room_id]] if room_id in results else [] -- cgit 1.5.1 From 2019b60f3bb5a505fc730f38a4b1accbabe444bf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 12:53:24 -0400 Subject: Fix sqlite syntax for upserts. (#14171) --- changelog.d/14171.feature | 1 + synapse/storage/databases/main/relations.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14171.feature (limited to 'synapse') diff --git a/changelog.d/14171.feature b/changelog.d/14171.feature new file mode 100644 index 0000000000..68de079cf3 --- /dev/null +++ b/changelog.d/14171.feature @@ -0,0 +1 @@ +Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ac9b96ab44..7c54ce0b2e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -138,7 +138,7 @@ class RelationsWorkerStore(SQLBaseStore): if isinstance(txn.database_engine, PostgresEngine): txn.execute_values(sql % ("?",), rows, fetch=False) else: - txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows) + txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows) # Mark the progress. self.db_pool.updates._background_update_progress_txn( -- cgit 1.5.1 From 16c5d95b594e4fe146947c4848057ebe0b9f900b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 13 Oct 2022 18:32:16 +0100 Subject: Optimise the event_push_backfill_thread_id bg job (#14172) Co-authored-by: Erik Johnston --- changelog.d/14172.bugfix | 1 + synapse/storage/databases/main/event_push_actions.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14172.bugfix (limited to 'synapse') diff --git a/changelog.d/14172.bugfix b/changelog.d/14172.bugfix new file mode 100644 index 0000000000..36521c670c --- /dev/null +++ b/changelog.d/14172.bugfix @@ -0,0 +1 @@ +Fix poor performance of the `event_push_backfill_thread_id` background update, which was introduced in Synapse 1.68.0rc1. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 87d07f7d9b..7f7bcb7094 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -297,9 +297,15 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas sql = f""" UPDATE {table_name} SET thread_id = 'main' - WHERE stream_ordering <= ? AND thread_id IS NULL + WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL """ - txn.execute(sql, (max_stream_ordering,)) + txn.execute( + sql, + ( + start_stream_ordering, + max_stream_ordering, + ), + ) # Update progress. processed_rows = txn.rowcount -- cgit 1.5.1 From 9ff4155f6cc9fc0b7aff82da9f0a1cae677dbda5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 07:10:44 -0400 Subject: Properly invalidate get_thread_id cache. (#14163) This was missed in 2b6d41ebd685fb546e52acdbcb0024dfcf5a5db1 (#13824). --- changelog.d/14163.feature | 1 + synapse/storage/databases/main/cache.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/14163.feature (limited to 'synapse') diff --git a/changelog.d/14163.feature b/changelog.d/14163.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14163.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 0ce3156c9c..b47fc606c7 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -244,6 +244,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # redacted. self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) -- cgit 1.5.1 From c3e4edb4d6ba33383bc056e3ff22b2d034d3e248 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 07:16:50 -0400 Subject: Stabilize the threads API. (#14175) Stabilize the threads API (MSC3856) by supporting (only) the v1 path for the endpoint. This also marks the API as safe for workers since it is a read-only API. --- changelog.d/13394.feature | 2 +- changelog.d/14175.feature | 1 + docker/configure_workers_and_start.py | 1 + docs/workers.md | 1 + synapse/config/experimental.py | 3 --- synapse/rest/client/relations.py | 9 ++----- tests/rest/client/test_relations.py | 47 +++++++++++++++++++++-------------- 7 files changed, 35 insertions(+), 29 deletions(-) create mode 100644 changelog.d/14175.feature (limited to 'synapse') diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature index 68de079cf3..df3ce45a76 100644 --- a/changelog.d/13394.feature +++ b/changelog.d/13394.feature @@ -1 +1 @@ -Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. +Support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/changelog.d/14175.feature b/changelog.d/14175.feature new file mode 100644 index 0000000000..df3ce45a76 --- /dev/null +++ b/changelog.d/14175.feature @@ -0,0 +1 @@ +Support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 8e7f605b24..d708237f69 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -118,6 +118,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$", "^/_matrix/client/v1/rooms/.*/hierarchy$", "^/_matrix/client/(v1|unstable)/rooms/.*/relations/", + "^/_matrix/client/v1/rooms/.*/threads$", "^/_matrix/client/(api/v1|r0|v3|unstable)/login$", "^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$", "^/_matrix/client/(api/v1|r0|v3|unstable)/account/whoami$", diff --git a/docs/workers.md b/docs/workers.md index e8d6cbaf8b..c27b3f8bd5 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -204,6 +204,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/v1/rooms/.*/hierarchy$ ^/_matrix/client/(v1|unstable)/rooms/.*/relations/ + ^/_matrix/client/v1/rooms/.*/threads$ ^/_matrix/client/unstable/org.matrix.msc2716/rooms/.*/batch_send$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1860006536..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -101,9 +101,6 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) - # MSC3856: Threads list API - self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) - # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d1aa1947a5..9dd59196d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -82,11 +82,7 @@ class RelationPaginationServlet(RestServlet): class ThreadsServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P[^/]*)/threads" - ), - ) + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P[^/]*)/threads"),) def __init__(self, hs: "HomeServer"): super().__init__() @@ -126,5 +122,4 @@ class ThreadsServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) - if hs.config.experimental.msc3856_enabled: - ThreadsServlet(hs).register(http_server) + ThreadsServlet(hs).register(http_server) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index d595295e2c..f5c1070b2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1710,7 +1710,15 @@ class RelationRedactionTestCase(BaseRelationsTestCase): class ThreadsTestCase(BaseRelationsTestCase): - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def _get_threads(self, body: JsonDict) -> List[Tuple[str, str]]: + return [ + ( + ev["event_id"], + ev["unsigned"]["m.relations"]["m.thread"]["latest_event"]["event_id"], + ) + for ev in body["chunk"] + ] + def test_threads(self) -> None: """Create threads and ensure the ordering is due to their latest event.""" # Create 2 threads. @@ -1718,32 +1726,37 @@ class ThreadsTestCase(BaseRelationsTestCase): res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) thread_2 = res["event_id"] - self._send_relation(RelationTypes.THREAD, "m.room.test") - self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_1 = channel.json_body["event_id"] + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", parent_id=thread_2 + ) + reply_2 = channel.json_body["event_id"] # Request the threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] - self.assertEqual(thread_roots, [thread_2, thread_1]) + threads = self._get_threads(channel.json_body) + self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) # Update the first thread, the ordering should swap. - self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_3 = channel.json_body["event_id"] channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] - self.assertEqual(thread_roots, [thread_1, thread_2]) + # Tuple of (thread ID, latest event ID) for each thread. + threads = self._get_threads(channel.json_body) + self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_pagination(self) -> None: """Create threads and paginate through them.""" # Create 2 threads. @@ -1757,7 +1770,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # Request the threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1771,7 +1784,7 @@ class ThreadsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1780,7 +1793,6 @@ class ThreadsTestCase(BaseRelationsTestCase): self.assertNotIn("next_batch", channel.json_body, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_include(self) -> None: """Filtering threads to all or participated in should work.""" # Thread 1 has the user as the root event. @@ -1807,7 +1819,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # All threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1819,14 +1831,13 @@ class ThreadsTestCase(BaseRelationsTestCase): # Only participated threads. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_ignored_user(self) -> None: """Events from ignored users should be ignored.""" # Thread 1 has a reply from an ignored user. @@ -1852,7 +1863,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # Only thread 1 is returned. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) -- cgit 1.5.1 From 126a15794c95002560709283640ad412636b29b8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 08:30:05 -0400 Subject: Do not allow a None-limit on PaginationConfig. (#14146) The callers either set a default limit or manually handle a None-limit later on (by setting a default value). Update the callers to always instantiate PaginationConfig with a default limit and then assume the limit is non-None. --- changelog.d/14146.removal | 1 + synapse/handlers/account_data.py | 2 +- synapse/handlers/initial_sync.py | 27 ++++----------------------- synapse/handlers/pagination.py | 5 ----- synapse/handlers/presence.py | 4 +++- synapse/handlers/receipts.py | 2 +- synapse/handlers/relations.py | 3 --- synapse/handlers/room.py | 2 +- synapse/handlers/typing.py | 2 +- synapse/rest/client/events.py | 4 +++- synapse/rest/client/initial_sync.py | 4 +++- synapse/rest/client/room.py | 4 +++- synapse/storage/databases/main/stream.py | 2 -- synapse/streams/__init__.py | 2 +- synapse/streams/config.py | 12 +++++------- tests/rest/client/test_typing.py | 3 ++- 16 files changed, 29 insertions(+), 50 deletions(-) create mode 100644 changelog.d/14146.removal (limited to 'synapse') diff --git a/changelog.d/14146.removal b/changelog.d/14146.removal new file mode 100644 index 0000000000..08fa752897 --- /dev/null +++ b/changelog.d/14146.removal @@ -0,0 +1 @@ +Remove the unstable identifier for [MSC3715](https://github.com/matrix-org/matrix-doc/pull/3715). diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 0478448b47..fc21d58001 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 860c82c110..9c335e6863 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -57,13 +57,7 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, - Optional[StreamToken], - Optional[StreamToken], - str, - Optional[int], - bool, - bool, + str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() @@ -154,11 +148,6 @@ class InitialSyncHandler: public_room_ids = await self.store.get_public_room_ids() - if pagin_config.limit is not None: - limit = pagin_config.limit - else: - limit = 10 - serializer_options = SerializeEventConfig(as_client_event=as_client_event) async def handle_room(event: RoomsForUser) -> None: @@ -210,7 +199,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, event.room_id, - limit=limit, + limit=pagin_config.limit, end_token=room_end_token, ), deferred_room_state, @@ -360,15 +349,11 @@ class InitialSyncHandler: member_event_id ) - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - leave_position = await self.store.get_position_for_event(member_event_id) stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( - room_id, limit=limit, end_token=stream_token + room_id, limit=pagin_config.limit, end_token=stream_token ) messages = await filter_events_for_client( @@ -420,10 +405,6 @@ class InitialSyncHandler: now_token = self.hs.get_event_sources().get_current_token() - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - room_members = [ m for m in current_state.values() @@ -467,7 +448,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, room_id, - limit=limit, + limit=pagin_config.limit, end_token=now_token.room_key, ), ), diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1f83bab836..a4ca9cb8b4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -458,11 +458,6 @@ class PaginationHandler: # `/messages` should still works with live tokens when manually provided. assert from_token.room_key.topological is not None - if pagin_config.limit is None: - # This shouldn't happen as we've set a default limit before this - # gets called. - raise Exception("limit not set") - room_token = from_token.room_key async with self.pagination_lock.read(room_id): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4e575ffbaa..2670e561d7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1596,7 +1596,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self, user: UserID, from_key: Optional[int], - limit: Optional[int] = None, + # Having a default limit doesn't match the EventSource API, but some + # callers do not provide it. It is unused in this class. + limit: int = 0, room_ids: Optional[Collection[str]] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4a7ec9e426..ac01582442 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -257,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 1fdd7a10bc..0a0c6d938e 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -116,9 +116,6 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") - # TODO Update pagination config to not allow None limits. - assert pagin_config.limit is not None - # Note that ignored users are not passed into get_relations_for_event # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 57ab05ad25..4e1aacb408 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1646,7 +1646,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): self, user: UserID, from_key: RoomStreamToken, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index f953691669..a0ea719430 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 916f5230f1..782e7d14e8 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet): raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") - pagin_config = await PaginationConfig.from_request(self.store, request) + pagin_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in args: try: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index cfadcb8e50..9b1bb8b521 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) args: Dict[bytes, List[bytes]] = request.args # type: ignore as_client_event = b"raw" not in args - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index b6dedbed04..01e5079963 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -729,7 +729,9 @@ class RoomInitialSyncRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ffeb2b3683..5baffbfe55 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1200,8 +1200,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - assert int(limit) >= 0 - # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 806b671305..2dcd43d0a2 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -27,7 +27,7 @@ class EventSource(Generic[K, R]): self, user: UserID, from_key: K, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/streams/config.py b/synapse/streams/config.py index f6f7bf3d8b..6df2de919c 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -35,14 +35,14 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] direction: str - limit: Optional[int] + limit: int @classmethod async def from_request( cls, store: "DataStore", request: SynapseRequest, - default_limit: Optional[int] = None, + default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": direction = parse_string( @@ -69,12 +69,10 @@ class PaginationConfig: raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) + if limit < 0: + raise SynapseError(400, "Limit must be 0 or above") - if limit: - if limit < 0: - raise SynapseError(400, "Limit must be 0 or above") - - limit = min(int(limit), MAX_LIMIT) + limit = min(limit, MAX_LIMIT) try: return PaginationConfig(from_tok, to_tok, direction, limit) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 61b66d7685..fdc433a8b5 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -59,7 +59,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=UserID.from_string(self.user_id), from_key=0, - limit=None, + # Limit is unused. + limit=0, room_ids=[self.room_id], is_guest=False, ) -- cgit 1.5.1 From 97b3d037c043d5c91c2a36109cab0c668a6a13ed Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 14 Oct 2022 13:48:33 +0100 Subject: Don't require optional `invite_room_state` field on fed v2 invite (#14083) --- changelog.d/14083.bugfix | 1 + synapse/federation/transport/server/federation.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14083.bugfix (limited to 'synapse') diff --git a/changelog.d/14083.bugfix b/changelog.d/14083.bugfix new file mode 100644 index 0000000000..752982b1ca --- /dev/null +++ b/changelog.d/14083.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would error on the optional 'invite_room_state' field not being provided to [`PUT /_matrix/federation/v2/invite/{roomId}/{eventId}`](https://spec.matrix.org/v1.4/server-server-api/#put_matrixfederationv2inviteroomideventid). \ No newline at end of file diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 6bb4659c4c..6f11138b57 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -489,7 +489,7 @@ class FederationV2InviteServlet(BaseFederationServerServlet): room_version = content["room_version"] event = content["event"] - invite_room_state = content["invite_room_state"] + invite_room_state = content.get("invite_room_state", []) # Synapse expects invite_room_state to be in unsigned, as it is in v1 # API -- cgit 1.5.1 From 022f25b3090f7f3a494cecb398bfdbbc2488c2bf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 09:21:55 -0400 Subject: Advertise support for Matrix 1.4. (#14184) All features / changes in Matrix 1.4 are now supported in Synapse. --- changelog.d/14032.feature | 2 +- changelog.d/14184.feature | 1 + synapse/rest/client/versions.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14184.feature (limited to 'synapse') diff --git a/changelog.d/14032.feature b/changelog.d/14032.feature index bb221d3ca6..016c704227 100644 --- a/changelog.d/14032.feature +++ b/changelog.d/14032.feature @@ -1 +1 @@ -Advertise Matrix 1.3 support on `/_matrix/client/versions`. +Advertise support for Matrix 1.3 and 1.4 on `/_matrix/client/versions`. diff --git a/changelog.d/14184.feature b/changelog.d/14184.feature new file mode 100644 index 0000000000..016c704227 --- /dev/null +++ b/changelog.d/14184.feature @@ -0,0 +1 @@ +Advertise support for Matrix 1.3 and 1.4 on `/_matrix/client/versions`. diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index d1d2e5f7e3..4e1fd2bbe7 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -76,6 +76,7 @@ class VersionsRestServlet(RestServlet): "v1.1", "v1.2", "v1.3", + "v1.4", ], # as per MSC1497: "unstable_features": { -- cgit 1.5.1 From d241a1350d5b0e1cf8262114f0cb34325cb91a26 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 14 Oct 2022 14:46:23 +0100 Subject: Fix background update to use an index (#14181) --- changelog.d/14181.bugfix | 1 + .../storage/databases/main/event_push_actions.py | 62 ++++++++++++++++++---- 2 files changed, 52 insertions(+), 11 deletions(-) create mode 100644 changelog.d/14181.bugfix (limited to 'synapse') diff --git a/changelog.d/14181.bugfix b/changelog.d/14181.bugfix new file mode 100644 index 0000000000..36521c670c --- /dev/null +++ b/changelog.d/14181.bugfix @@ -0,0 +1 @@ +Fix poor performance of the `event_push_backfill_thread_id` background update, which was introduced in Synapse 1.68.0rc1. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7f7bcb7094..72cf91eb39 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -269,11 +269,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas event_push_actions_done = progress.get("event_push_actions_done", False) def add_thread_id_txn( - txn: LoggingTransaction, table_name: str, start_stream_ordering: int + txn: LoggingTransaction, start_stream_ordering: int ) -> int: - sql = f""" + sql = """ SELECT stream_ordering - FROM {table_name} + FROM event_push_actions WHERE thread_id IS NULL AND stream_ordering > ? @@ -285,7 +285,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # No more rows to process. rows = txn.fetchall() if not rows: - progress[f"{table_name}_done"] = True + progress["event_push_actions_done"] = True self.db_pool.updates._background_update_progress_txn( txn, "event_push_backfill_thread_id", progress ) @@ -294,8 +294,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Update the thread ID for any of those rows. max_stream_ordering = rows[-1][0] - sql = f""" - UPDATE {table_name} + sql = """ + UPDATE event_push_actions SET thread_id = 'main' WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL """ @@ -309,7 +309,50 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Update progress. processed_rows = txn.rowcount - progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering + progress["max_event_push_actions_stream_ordering"] = max_stream_ordering + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + + return processed_rows + + def add_thread_id_summary_txn(txn: LoggingTransaction) -> int: + min_user_id = progress.get("max_summary_user_id", "") + min_room_id = progress.get("max_summary_room_id", "") + + # Slightly overcomplicated query for getting the Nth user ID / room + # ID tuple, or the last if there are less than N remaining. + sql = """ + SELECT user_id, room_id FROM ( + SELECT user_id, room_id FROM event_push_summary + WHERE (user_id, room_id) > (?, ?) + AND thread_id IS NULL + ORDER BY user_id, room_id + LIMIT ? + ) AS e + ORDER BY user_id DESC, room_id DESC + LIMIT 1 + """ + + txn.execute(sql, (min_user_id, min_room_id, batch_size)) + row = txn.fetchone() + if not row: + return 0 + + max_user_id, max_room_id = row + + sql = """ + UPDATE event_push_summary + SET thread_id = 'main' + WHERE + (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?) + AND thread_id IS NULL + """ + txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id)) + processed_rows = txn.rowcount + + progress["max_summary_user_id"] = max_user_id + progress["max_summary_room_id"] = max_room_id self.db_pool.updates._background_update_progress_txn( txn, "event_push_backfill_thread_id", progress ) @@ -325,15 +368,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas result = await self.db_pool.runInteraction( "event_push_backfill_thread_id", add_thread_id_txn, - "event_push_actions", progress.get("max_event_push_actions_stream_ordering", 0), ) else: result = await self.db_pool.runInteraction( "event_push_backfill_thread_id", - add_thread_id_txn, - "event_push_summary", - progress.get("max_event_push_summary_stream_ordering", 0), + add_thread_id_summary_txn, ) # Only done after the event_push_summary table is done. -- cgit 1.5.1 From d1bdeccb50550ef454067aa01dd9d004c4704633 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 14:05:25 -0400 Subject: Accept threaded receipts for events related to the root event. (#14174) The root node of a thread (and events related to it) are considered "part of a thread" when validating receipts. This allows clients which show the root node in both the main timeline and the threaded timeline to easily send receipts in either. Note that threaded notifications are not created for these events, these events created notifications on the main timeline. --- changelog.d/14174.feature | 1 + synapse/rest/client/receipts.py | 44 ++++++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/relations.py | 98 ++++++++++++++++++++++-- tests/storage/test_relations.py | 111 ++++++++++++++++++++++++++++ 5 files changed, 247 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14174.feature create mode 100644 tests/storage/test_relations.py (limited to 'synapse') diff --git a/changelog.d/14174.feature b/changelog.d/14174.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14174.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 14dec7ac4e..18a282b22c 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet): ) # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): + if not await self._is_event_in_thread(event_id, thread_id): raise SynapseError( 400, f"event_id {event_id} is not related to thread {thread_id}", @@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet): return 200, {} + async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: + """ + The event must be related to the thread ID (in a vague sense) to ensure + clients aren't sending bogus receipts. + + A thread ID is considered valid for a given event E if: + + 1. E has a thread relation which matches the thread ID; + 2. E has another event which has a thread relation to E matching the + thread ID; or + 3. E is recursively related (via any rel_type) to an event which + satisfies 1 or 2. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + It is valid to send a receipt for thread A on A, B, C, D, or E. + + It is valid to send a receipt for the main timeline on A, D, and E. + + Args: + event_id: The event ID to check. + thread_id: The thread ID the event is potentially part of. + + Returns: + True if the event belongs to the given thread, otherwise False. + """ + + # If the receipt is on the main timeline, it is enough to check whether + # the event is directly related to a thread. + if thread_id == MAIN_TIMELINE: + return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id) + + # Otherwise, check if the event is directly part of a thread, or is the + # root message (or related to the root message) of a thread. + return thread_id == await self._main_store.get_thread_id_for_receipts(event_id) + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index b47fc606c7..ed0be4abe5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -245,6 +245,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7c54ce0b2e..1de62ee9df 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -946,6 +946,20 @@ class RelationsWorkerStore(SQLBaseStore): Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. + + See also get_thread_id_for_receipts. + Args: event_id: The event ID to fetch the thread ID for. @@ -953,22 +967,32 @@ class RelationsWorkerStore(SQLBaseStore): The event ID of the root event in the thread, if this event is part of a thread. "main", otherwise. """ - # Since event relations form a tree, we should only ever find 0 or 1 - # results from the below query. + + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type + SELECT event_id, relates_to_id, relation_type, 0 depth FROM event_relations WHERE event_id = ? - UNION SELECT e.event_id, e.relates_to_id, e.relation_type + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id - ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) - # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] @@ -978,6 +1002,68 @@ class RelationsWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction( + "get_related_thread_id", _get_related_thread_id + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py new file mode 100644 index 0000000000..cd1d00208b --- /dev/null +++ b/tests/storage/test_relations.py @@ -0,0 +1,111 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import MAIN_TIMELINE +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest + + +class RelationsStoreTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + """ + Creates a DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + F <--[m.annotation]-- G + + """ + self._main_store = self.hs.get_datastores().main + + self._create_relation("A", "B", "m.thread") + self._create_relation("B", "C", "m.annotation") + self._create_relation("A", "D", "m.reference") + self._create_relation("D", "E", "m.annotation") + self._create_relation("F", "G", "m.annotation") + + def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None: + self.get_success( + self._main_store.db_pool.simple_insert( + table="event_relations", + values={ + "event_id": event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + }, + ) + ) + + def test_get_thread_id(self) -> None: + """ + Ensure that get_thread_id only searches up the tree for threads. + """ + # The thread itself and children of it return the thread. + thread_id = self.get_success(self._main_store.get_thread_id("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("C")) + self.assertEqual("A", thread_id) + + # But the root and events related to the root do not. + thread_id = self.get_success(self._main_store.get_thread_id("A")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("D")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("E")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + def test_get_thread_id_for_receipts(self) -> None: + """ + Ensure that get_thread_id_for_receipts searches up and down the tree for a thread. + """ + # All of the events are considered related to this thread. + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E")) + self.assertEqual("A", thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) -- cgit 1.5.1 From 40bb37eb27e1841754a297ac1277748de7f6c1cb Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Sat, 15 Oct 2022 00:36:49 -0500 Subject: Stop getting missing `prev_events` after we already know their signature is invalid (#13816) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit While https://github.com/matrix-org/synapse/pull/13635 stops us from doing the slow thing after we've already done it once, this PR stops us from doing one of the slow things in the first place. Related to - https://github.com/matrix-org/synapse/issues/13622 - https://github.com/matrix-org/synapse/pull/13635 - https://github.com/matrix-org/synapse/issues/13676 Part of https://github.com/matrix-org/synapse/issues/13356 Follow-up to https://github.com/matrix-org/synapse/pull/13815 which tracks event signature failures. With this PR, we avoid the call to the costly `_get_state_ids_after_missing_prev_event` because the signature failure will count as an attempt before and we filter events based on the backoff before calling `_get_state_ids_after_missing_prev_event` now. For example, this will save us 156s out of the 185s total that this `matrix.org` `/messages` request. If you want to see the full Jaeger trace of this, you can drag and drop this `trace.json` into your own Jaeger, https://gist.github.com/MadLittleMods/4b12d0d0afe88c2f65ffcc907306b761 To explain this exact scenario around `/messages` -> backfill, we call `/backfill` and first check the signatures of the 100 events. We see bad signature for `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` and `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` (both member events). Then we process the 98 events remaining that have valid signatures but one of the events references `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` as a `prev_event`. So we have to do the whole `_get_state_ids_after_missing_prev_event` rigmarole which pulls in those same events which fail again because the signatures are still invalid. - `backfill` - `outgoing-federation-request` `/backfill` - `_check_sigs_and_hash_and_fetch` - `_check_sigs_and_hash_and_fetch_one` for each event received over backfill - ❗ `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` fails with `Signature on retrieved event was invalid.`: `unable to verify signature for sender domain xxx: 401: Failed to find any key to satisfy: _FetchKeyRequest(...)` - ❗ `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` fails with `Signature on retrieved event was invalid.`: `unable to verify signature for sender domain xxx: 401: Failed to find any key to satisfy: _FetchKeyRequest(...)` - `_process_pulled_events` - `_process_pulled_event` for each validated event - ❗ Event `$Q0iMdqtz3IJYfZQU2Xk2WjB5NDF8Gg8cFSYYyKQgKJ0` references `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` as a `prev_event` which is missing so we try to get it - `_get_state_ids_after_missing_prev_event` - `outgoing-federation-request` `/state_ids` - ❗ `get_pdu` for `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` which fails the signature check again - ❗ `get_pdu` for `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` which fails the signature check --- changelog.d/13816.feature | 1 + synapse/api/errors.py | 21 +++ synapse/handlers/federation.py | 16 ++ synapse/handlers/federation_event.py | 31 ++++ synapse/storage/databases/main/event_federation.py | 54 ++++++ tests/handlers/test_federation_event.py | 201 ++++++++++++++++++++- tests/storage/test_event_federation.py | 64 +++++++ 7 files changed, 386 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13816.feature (limited to 'synapse') diff --git a/changelog.d/13816.feature b/changelog.d/13816.feature new file mode 100644 index 0000000000..5eaa936b08 --- /dev/null +++ b/changelog.d/13816.feature @@ -0,0 +1 @@ +Stop fetching missing `prev_events` after we already know their signature is invalid. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c606207569..e0873b1913 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -640,6 +640,27 @@ class FederationError(RuntimeError): } +class FederationPullAttemptBackoffError(RuntimeError): + """ + Raised to indicate that we are are deliberately not attempting to pull the given + event over federation because we've already done so recently and are backing off. + + Attributes: + event_id: The event_id which we are refusing to pull + message: A custom error message that gives more context + """ + + def __init__(self, event_ids: List[str], message: Optional[str]): + self.event_ids = event_ids + + if message: + error_message = message + else: + error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)." + + super().__init__(error_message) + + class HttpResponseException(CodeMessageException): """ Represents an HTTP-level failure of an outbound request diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 44e70c6c3c..5f7e0a1f79 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, LimitExceededError, NotFoundError, @@ -1720,7 +1721,22 @@ class FederationHandler: destination, event ) break + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_sync_partial_state_room: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: + # TODO: We should `record_event_failed_pull_attempt` here, + # see https://github.com/matrix-org/synapse/issues/13700 + if attempt == len(destinations) - 1: # We have tried every remote server for this event. Give up. # TODO(faster_joins) giving up isn't the right thing to do diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f382961099..4300e8dd40 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -44,6 +44,7 @@ from synapse.api.errors import ( AuthError, Codes, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, RequestSendFailed, SynapseError, @@ -567,6 +568,9 @@ class FederationEventHandler: event: partial-state event to be de-partial-stated Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to request state from the remote server. """ logger.info("Updating state for %s", event.event_id) @@ -901,6 +905,18 @@ class FederationEventHandler: context, backfilled=backfilled, ) + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_process_pulled_event: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: await self._store.record_event_failed_pull_attempt( event.room_id, event_id, str(e) @@ -947,6 +963,9 @@ class FederationEventHandler: The event context. Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to get the state from the remote server after any missing `prev_event`s. """ @@ -957,6 +976,18 @@ class FederationEventHandler: seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen + # If we've already recently attempted to pull this missing event, don't + # try it again so soon. Since we have to fetch all of the prev_events, we can + # bail early here if we find any to ignore. + prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff( + room_id, missing_prevs + ) + if len(prevs_to_ignore) > 0: + raise FederationPullAttemptBackoffError( + event_ids=prevs_to_ignore, + message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).", + ) + if not missing_prevs: return await self._state_handler.compute_event_context(event) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6b9a629edd..309a4ba664 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id: The event that failed to be fetched or processed cause: The error message or reason that we failed to pull the event """ + logger.debug( + "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s", + room_id, + event_id, + cause, + ) await self.db_pool.runInteraction( "record_event_failed_pull_attempt", self._record_event_failed_pull_attempt_upsert_txn, @@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + @trace + async def get_event_ids_to_not_pull_from_backoff( + self, + room_id: str, + event_ids: Collection[str], + ) -> List[str]: + """ + Filter down the events to ones that we've failed to pull before recently. Uses + exponential backoff. + + Args: + room_id: The room that the events belong to + event_ids: A list of events to filter down + + Returns: + List of event_ids that should not be attempted to be pulled + """ + event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", + ) + + current_time = self._clock.time_msec() + return [ + event_failed_pull_attempt["event_id"] + for event_failed_pull_attempt in event_failed_pull_attempts + # Exponential back-off (up to the upper bound) so we don't try to + # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + if current_time + < event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) + ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS + ] + async def get_missing_events( self, room_id: str, diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 918010cddb..e448cb1901 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -14,7 +14,7 @@ from typing import Optional from unittest import mock -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, StoreError from synapse.api.room_versions import RoomVersion from synapse.event_auth import ( check_state_dependent_auth_rules, @@ -43,7 +43,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): def make_homeserver(self, reactor, clock): # mock out the federation transport client self.mock_federation_transport_client = mock.Mock( - spec=["get_room_state_ids", "get_room_state", "get_event"] + spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] ) return super().setup_test_homeserver( federation_transport_client=self.mock_federation_transport_client @@ -459,6 +459,203 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) self.assertIsNotNone(persisted, "pulled event was not persisted at all") + def test_backfill_signature_failure_does_not_fetch_same_prev_event_later( + self, + ) -> None: + """ + Test to make sure we backoff and don't try to fetch a missing prev_event when we + already know it has a invalid signature from checking the signatures of all of + the events in the backfill response. + """ + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # Allow the remote user to send state events + self.helper.send_state( + room_id, + "m.room.power_levels", + {"events_default": 0, "state_default": 0}, + tok=tok, + ) + + # Add the remote user to the room + member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + + auth_event_ids = [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + member_event.event_id, + ] + + # We purposely don't run `add_hashes_and_signatures_from_other_server` + # over this because we want the signature check to fail. + pulled_event_without_signatures = make_event_from_dict( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [member_event.event_id], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event_without_signatures"}, + }, + room_version, + ) + + # Create a regular event that should pass except for the + # `pulled_event_without_signatures` in the `prev_event`. + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [ + member_event.event_id, + pulled_event_without_signatures.event_id, + ], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event"}, + } + ), + room_version, + ) + + # We expect an outbound request to /backfill, so stub that out + self.mock_federation_transport_client.backfill.return_value = make_awaitable( + { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + # This is one of the important aspects of this test: we include + # `pulled_event_without_signatures` so it fails the signature check + # when we filter down the backfill response down to events which + # have valid signatures in + # `_check_sigs_and_hash_for_pulled_events_and_fetch` + pulled_event_without_signatures.get_pdu_json(), + # Then later when we process this valid signature event, when we + # fetch the missing `prev_event`s, we want to make sure that we + # backoff and don't try and fetch `pulled_event_without_signatures` + # again since we know it just had an invalid signature. + pulled_event.get_pdu_json(), + ], + } + ) + + # Keep track of the count and make sure we don't make any of these requests + event_endpoint_requested_count = 0 + room_state_ids_endpoint_requested_count = 0 + room_state_endpoint_requested_count = 0 + + async def get_event( + destination: str, event_id: str, timeout: Optional[int] = None + ) -> None: + nonlocal event_endpoint_requested_count + event_endpoint_requested_count += 1 + + async def get_room_state_ids( + destination: str, room_id: str, event_id: str + ) -> None: + nonlocal room_state_ids_endpoint_requested_count + room_state_ids_endpoint_requested_count += 1 + + async def get_room_state( + room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> None: + nonlocal room_state_endpoint_requested_count + room_state_endpoint_requested_count += 1 + + # We don't expect an outbound request to `/event`, `/state_ids`, or `/state` in + # the happy path but if the logic is sneaking around what we expect, stub that + # out so we can detect that failure + self.mock_federation_transport_client.get_event.side_effect = get_event + self.mock_federation_transport_client.get_room_state_ids.side_effect = ( + get_room_state_ids + ) + self.mock_federation_transport_client.get_room_state.side_effect = ( + get_room_state + ) + + # The function under test: try to backfill and process the pulled event + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler().backfill( + self.OTHER_SERVER_NAME, + room_id, + limit=1, + extremities=["$some_extremity"], + ) + ) + + if event_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /event in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + if room_state_ids_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /state_ids in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + if room_state_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /state in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + # Make sure we only recorded a single failure which corresponds to the signature + # failure initially in `_check_sigs_and_hash_for_pulled_events_and_fetch` before + # we process all of the pulled events. + backfill_num_attempts_for_event_without_signatures = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event_without_signatures.event_id}, + retcol="num_attempts", + ) + ) + self.assertEqual(backfill_num_attempts_for_event_without_signatures, 1) + + # And make sure we didn't record a failure for the event that has the missing + # prev_event because we don't want to cause a cascade of failures. Not being + # able to fetch the `prev_events` just means we won't be able to de-outlier the + # pulled event. But we can still use an `outlier` in the state/auth chain for + # another event. So we shouldn't stop a downstream event from trying to pull it. + self.get_failure( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + ), + # StoreError: 404: No row found + StoreError, + ) + def test_process_pulled_event_with_rejected_missing_state(self) -> None: """Ensure that we correctly handle pulled events with missing state containing a rejected state event diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 59b8910907..853db930d6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -27,6 +27,8 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.events import _EventInternalMetadata +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict @@ -43,6 +45,12 @@ class _BackfillSetupInfo: class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main @@ -1122,6 +1130,62 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertEqual(backfill_event_ids, ["insertion_eventA"]) + def test_get_event_ids_to_not_pull_from_backoff( + self, + ): + """ + Test to make sure only event IDs we should backoff from are returned. + """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "$failed_event_id", "fake cause" + ) + ) + + event_ids_to_backoff = self.get_success( + self.store.get_event_ids_to_not_pull_from_backoff( + room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] + ) + ) + + self.assertEqual(event_ids_to_backoff, ["$failed_event_id"]) + + def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( + self, + ): + """ + Test to make sure no event IDs are returned after the backoff duration has + elapsed. + """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "$failed_event_id", "fake cause" + ) + ) + + # Now advance time by 2 hours so we wait long enough for the single failed + # attempt (2^1 hours). + self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) + + event_ids_to_backoff = self.get_success( + self.store.get_event_ids_to_not_pull_from_backoff( + room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] + ) + ) + # Since this function only returns events we should backoff from, time has + # elapsed past the backoff range so there is no events to backoff from. + self.assertEqual(event_ids_to_backoff, []) + @attr.s class FakeEvent: -- cgit 1.5.1 From 2c2c3f8b2c1e33d5aee6d480c60c75c1179e3dba Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 17 Oct 2022 13:27:51 +0100 Subject: Invalidate rooms for user caches when receiving membership events (#14155) This should fix a race where the event notification comes in over replication before the state replication, leaving a window during which a sync may get an incorrect list of rooms for the user. --- changelog.d/14155.misc | 1 + synapse/storage/databases/main/cache.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/14155.misc (limited to 'synapse') diff --git a/changelog.d/14155.misc b/changelog.d/14155.misc new file mode 100644 index 0000000000..79539cdc32 --- /dev/null +++ b/changelog.d/14155.misc @@ -0,0 +1 @@ +Invalidate rooms for user caches on replicated event, fix sync cache race in synapse workers. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ed0be4abe5..ddb7397714 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -252,6 +252,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_invited_rooms_for_local_user", (state_key,) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (state_key,) + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,)) if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) -- cgit 1.5.1 From ccce8cdfc5e567b5b905b58e82a1d725f2647524 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 17 Oct 2022 13:39:12 +0100 Subject: Use Pydantic when PUTting room aliases (#14179) --- changelog.d/14179.feature | 1 + synapse/handlers/directory.py | 19 +++++++------ synapse/rest/client/directory.py | 58 ++++++++++++++++++++++++---------------- 3 files changed, 47 insertions(+), 31 deletions(-) create mode 100644 changelog.d/14179.feature (limited to 'synapse') diff --git a/changelog.d/14179.feature b/changelog.d/14179.feature new file mode 100644 index 0000000000..48f2db91d3 --- /dev/null +++ b/changelog.d/14179.feature @@ -0,0 +1 @@ +Improve the validation of the following PUT endpoints: [`/directory/room/{roomAlias}`](https://spec.matrix.org/v1.4/client-server-api/#put_matrixclientv3directoryroomroomalias), [`/directory/list/room/{roomId}`](https://spec.matrix.org/v1.4/client-server-api/#put_matrixclientv3directorylistroomroomid) and [`/directory/list/appservice/{networkId}/{roomId}`](https://spec.matrix.org/v1.4/application-service-api/#put_matrixclientv3directorylistappservicenetworkidroomid). diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7127d5aefc..d52ebada6b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -16,6 +16,8 @@ import logging import string from typing import TYPE_CHECKING, Iterable, List, Optional +from typing_extensions import Literal + from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( AuthError, @@ -429,7 +431,10 @@ class DirectoryHandler: return await self.auth.check_can_change_room_list(room_id, requester) async def edit_published_room_list( - self, requester: Requester, room_id: str, visibility: str + self, + requester: Requester, + room_id: str, + visibility: Literal["public", "private"], ) -> None: """Edit the entry of the room in the published room list. @@ -451,9 +456,6 @@ class DirectoryHandler: if requester.is_guest: raise AuthError(403, "Guests cannot edit the published room list") - if visibility not in ["public", "private"]: - raise SynapseError(400, "Invalid visibility setting") - if visibility == "public" and not self.enable_room_list_search: # The room list has been disabled. raise AuthError( @@ -505,7 +507,11 @@ class DirectoryHandler: await self.store.set_room_is_public(room_id, making_public) async def edit_published_appservice_room_list( - self, appservice_id: str, network_id: str, room_id: str, visibility: str + self, + appservice_id: str, + network_id: str, + room_id: str, + visibility: Literal["public", "private"], ) -> None: """Add or remove a room from the appservice/network specific public room list. @@ -516,9 +522,6 @@ class DirectoryHandler: room_id visibility: either "public" or "private" """ - if visibility not in ["public", "private"]: - raise SynapseError(400, "Invalid visibility setting") - await self.store.set_room_is_public_appservice( room_id, appservice_id, network_id, visibility == "public" ) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index bc1b18c92d..f17b4c8d22 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -13,15 +13,22 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple + +from pydantic import StrictStr +from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import ( + RestServlet, + parse_and_validate_json_object_from_request, +) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.rest.models import RequestBodyModel from synapse.types import JsonDict, RoomAlias if TYPE_CHECKING: @@ -54,6 +61,12 @@ class ClientDirectoryServer(RestServlet): return 200, res + class PutBody(RequestBodyModel): + # TODO: get Pydantic to validate that this is a valid room id? + room_id: StrictStr + # `servers` is unspecced + servers: Optional[List[StrictStr]] = None + async def on_PUT( self, request: SynapseRequest, room_alias: str ) -> Tuple[int, JsonDict]: @@ -61,31 +74,22 @@ class ClientDirectoryServer(RestServlet): raise SynapseError(400, "Room alias invalid", errcode=Codes.INVALID_PARAM) room_alias_obj = RoomAlias.from_string(room_alias) - content = parse_json_object_from_request(request) - if "room_id" not in content: - raise SynapseError( - 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON - ) + content = parse_and_validate_json_object_from_request(request, self.PutBody) logger.debug("Got content: %s", content) logger.debug("Got room name: %s", room_alias_obj.to_string()) - room_id = content["room_id"] - servers = content["servers"] if "servers" in content else None - - logger.debug("Got room_id: %s", room_id) - logger.debug("Got servers: %s", servers) + logger.debug("Got room_id: %s", content.room_id) + logger.debug("Got servers: %s", content.servers) - # TODO(erikj): Check types. - - room = await self.store.get_room(room_id) + room = await self.store.get_room(content.room_id) if room is None: raise SynapseError(400, "Room does not exist") requester = await self.auth.get_user_by_req(request) await self.directory_handler.create_association( - requester, room_alias_obj, room_id, servers + requester, room_alias_obj, content.room_id, content.servers ) return 200, {} @@ -137,16 +141,18 @@ class ClientDirectoryListServer(RestServlet): return 200, {"visibility": "public" if room["is_public"] else "private"} + class PutBody(RequestBodyModel): + visibility: Literal["public", "private"] = "public" + async def on_PUT( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") + content = parse_and_validate_json_object_from_request(request, self.PutBody) await self.directory_handler.edit_published_room_list( - requester, room_id, visibility + requester, room_id, content.visibility ) return 200, {} @@ -163,12 +169,14 @@ class ClientAppserviceDirectoryListServer(RestServlet): self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() + class PutBody(RequestBodyModel): + visibility: Literal["public", "private"] = "public" + async def on_PUT( self, request: SynapseRequest, network_id: str, room_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") - return await self._edit(request, network_id, room_id, visibility) + content = parse_and_validate_json_object_from_request(request, self.PutBody) + return await self._edit(request, network_id, room_id, content.visibility) async def on_DELETE( self, request: SynapseRequest, network_id: str, room_id: str @@ -176,7 +184,11 @@ class ClientAppserviceDirectoryListServer(RestServlet): return await self._edit(request, network_id, room_id, "private") async def _edit( - self, request: SynapseRequest, network_id: str, room_id: str, visibility: str + self, + request: SynapseRequest, + network_id: str, + room_id: str, + visibility: Literal["public", "private"], ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not requester.app_service: -- cgit 1.5.1 From 4283bd1cf9c3da2157c3642a7c4f105e9fac2636 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Oct 2022 11:32:11 -0400 Subject: Support filtering the /messages API by relation type (MSC3874). (#14148) Gated behind an experimental configuration flag. --- changelog.d/14148.feature | 1 + synapse/api/filtering.py | 27 +++++- synapse/config/experimental.py | 3 + synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/stream.py | 29 ++++++- tests/api/test_filtering.py | 63 +++++++++++++- tests/rest/client/test_relations.py | 1 - tests/rest/client/test_rooms.py | 145 ++----------------------------- tests/storage/test_stream.py | 118 ++++++++++++++++++------- 9 files changed, 212 insertions(+), 177 deletions(-) create mode 100644 changelog.d/14148.feature (limited to 'synapse') diff --git a/changelog.d/14148.feature b/changelog.d/14148.feature new file mode 100644 index 0000000000..951d0cac80 --- /dev/null +++ b/changelog.d/14148.feature @@ -0,0 +1 @@ +Experimental support for [MSC3874](https://github.com/matrix-org/matrix-spec-proposals/pull/3874). diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cc31cf8cc7..26be377d03 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -36,7 +36,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -53,6 +53,12 @@ FILTER_SCHEMA = { # check types are valid event types "types": {"type": "array", "items": {"type": "string"}}, "not_types": {"type": "array", "items": {"type": "string"}}, + # MSC3874, filtering /messages. + "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc3874.not_rel_types": { + "type": "array", + "items": {"type": "string"}, + }, }, } @@ -334,8 +340,15 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - self.related_by_senders = self.filter_json.get("related_by_senders", None) - self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + self.related_by_senders = filter_json.get("related_by_senders", None) + self.related_by_rel_types = filter_json.get("related_by_rel_types", None) + + # For compatibility with _check_fields. + self.rel_types = None + self.not_rel_types = [] + if hs.config.experimental.msc3874_enabled: + self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None) + self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -386,11 +399,19 @@ class Filter: # check if there is a string url field in the content for filtering purposes labels = content.get(EventContentFields.LABELS, []) + # Check if the event has a relation. + rel_type = None + if isinstance(event, EventBase): + relation = relation_from_event(event) + if relation: + rel_type = relation.rel_type + field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, + "rel_types": lambda v: rel_type == v, } result = self._check_fields(field_matchers) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..f9a49451d8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -117,3 +117,6 @@ class ExperimentalConfig(Config): self.msc3882_token_timeout = self.parse_duration( experimental.get("msc3882_token_timeout", "5m") ) + + # MSC3874: Filtering /messages with rel_types / not_rel_types. + self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4e1fd2bbe7..4b87ee978a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -114,6 +114,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, + # Adds support for filtering /messages by event relation. + "org.matrix.msc3874": self.config.experimental.msc3874_enabled, }, }, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 5baffbfe55..09ce855aa8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1278,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1294,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index a269c477fb..a82c4eed86 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -35,6 +35,8 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" + if "content" not in kwargs: + kwargs["content"] = {} return make_event_from_dict(kwargs) @@ -357,6 +359,66 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_rel_type(self): + definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_not_rel_type(self): + definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} filter_id = self.get_success( @@ -456,7 +518,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): events = [ # An event without a relation. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f5c1070b2c..ddf315b894 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1677,7 +1677,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ Test that thread replies are still available when the root event is redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 3612ebe7b9..71b1637be8 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -35,7 +35,6 @@ from synapse.api.constants import ( EventTypes, Membership, PublicRoomsFilterFields, - RelationTypes, RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException @@ -50,6 +49,7 @@ from synapse.util.stringutils import random_string from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -2915,149 +2915,20 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id -class RelationsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - self.helper.join( - room=self.room_id, user=self.second_user_id, tok=self.second_tok - ) - - self.third_user_id = self.register_user("third", "test") - self.third_tok = self.login("third", "test") - self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) - - # An initial event with a relation from second user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 1"}, - tok=self.tok, - ) - self.event_id_1 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.ANNOTATION, - "event_id": self.event_id_1, - "key": "👍", - } - }, - tok=self.second_tok, - ) - - # Another event with a relation from third user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 2"}, - tok=self.tok, - ) - self.event_id_2 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.REFERENCE, - "event_id": self.event_id_2, - } - }, - tok=self.third_tok, - ) - - # An event with no relations. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "No relations"}, - tok=self.tok, - ) - - def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: +class RelationsTestCase(PaginationTestCase): + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" + from_token = self.get_success( + self.from_token.to_string(self.hs.get_datastores().main) + ) channel = self.make_request( "GET", - "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}", access_token=self.tok, ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - return channel.json_body["chunk"] - - def test_filter_relation_senders(self) -> None: - # Messages which second user reacted to. - filter = {"related_by_senders": [self.second_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which third user reacted to. - filter = {"related_by_senders": [self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which either user reacted to. - filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_type(self) -> None: - # Messages which have annotations. - filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which have references. - filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which have either annotations or references. - filter = { - "related_by_rel_types": [ - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - ] - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_senders_and_type(self) -> None: - # Messages which second user reacted to. - filter = { - "related_by_senders": [self.second_user_id], - "related_by_rel_types": [RelationTypes.ANNOTATION], - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) + return [ev["event_id"] for ev in channel.json_body["chunk"]] class ContextTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 78663a53fe..34fa810cf6 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -16,7 +16,6 @@ from typing import List from synapse.api.constants import EventTypes, RelationTypes from synapse.api.filtering import Filter -from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.types import JsonDict @@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase): def default_config(self): config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} + config["experimental_features"] = {"msc3874_enabled": True} return config def prepare(self, reactor, clock, homeserver): @@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase): self.third_tok = self.login("third", "test") self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + # Store a token which is after all the room creation events. + self.from_token = self.get_success( + self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) + ) + # An initial event with a relation from second user. res = self.helper.send_event( room_id=self.room_id, @@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_1 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.second_tok, ) + self.event_id_annotation = res["event_id"] # Another event with a relation from third user. res = self.helper.send_event( @@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_2 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.third_tok, ) + self.event_id_reference = res["event_id"] # An event with no relations. - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, content={"msgtype": "m.text", "body": "No relations"}, tok=self.tok, ) + self.event_id_none = res["event_id"] - def _filter_messages(self, filter: JsonDict) -> List[EventBase]: + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" - from_token = self.get_success( - self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) - ) - events, next_key = self.get_success( self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, - from_key=from_token.room_key, + from_key=self.from_token.room_key, to_key=None, - direction="b", + direction="f", limit=10, event_filter=Filter(self.hs, filter), ) ) - return events + return [ev.event_id for ev in events] def test_filter_relation_senders(self): # Messages which second user reacted to. filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which third user reacted to. filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which either user reacted to. filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_type(self): # Messages which have annotations. filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which have references. filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which have either annotations or references. filter = { @@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase): ] } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. @@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase): "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) def test_duplicate_relation(self): """An event should only be returned once if there are multiple relations to it.""" @@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase): filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) + + def test_filter_rel_types(self) -> None: + # Messages which are annotations. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_annotation]) + + # Messages which are references. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_reference]) + + # Messages which are either annotations or references. + filter = { + "org.matrix.msc3874.rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertCountEqual( + chunk, + [self.event_id_annotation, self.event_id_reference], + ) + + def test_filter_not_rel_types(self) -> None: + # Messages which are not annotations. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_2, + self.event_id_reference, + self.event_id_none, + ], + ) + + # Messages which are not references. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_annotation, + self.event_id_2, + self.event_id_none, + ], + ) + + # Messages which are neither annotations or references. + filter = { + "org.matrix.msc3874.not_rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none]) -- cgit 1.5.1 From 2c63cdcc3f1aa4625e947de3c23e0a8133c61286 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 17 Oct 2022 16:02:39 -0500 Subject: Add debug logs to figure out why an event was filtered (#14095) Spawned while investigating https://github.com/matrix-org/synapse/issues/13944 This way we might get some more context whenever an `403 Forbidden - body: {"errcode":"M_FORBIDDEN","error":"You don't have permission to access that event."}` error is produced. `log_config.yaml` ```yaml loggers: synapse: level: INFO synapse.visibility: level: DEBUG ``` --- changelog.d/14095.misc | 1 + synapse/visibility.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14095.misc (limited to 'synapse') diff --git a/changelog.d/14095.misc b/changelog.d/14095.misc new file mode 100644 index 0000000000..3483201d5f --- /dev/null +++ b/changelog.d/14095.misc @@ -0,0 +1 @@ +Add debug logs to figure out why an event was filtered out of the client response. diff --git a/synapse/visibility.py b/synapse/visibility.py index c4048d2477..40a9c5b53f 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -84,7 +84,15 @@ async def filter_events_for_client( """ # Filter out events that have been soft failed so that we don't relay them # to clients. + events_before_filtering = events events = [e for e in events if not e.internal_metadata.is_soft_failed()] + if len(events_before_filtering) != len(events): + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "filter_events_for_client: Filtered out soft-failed events: Before=%s, After=%s", + [event.event_id for event in events_before_filtering], + [event.event_id for event in events], + ) types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) @@ -301,6 +309,10 @@ def _check_client_allowed_to_see_event( _check_filter_send_to_client(event, clock, retention_policy, sender_ignored) == _CheckFilter.DENIED ): + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because `_check_filter_send_to_client` returned `_CheckFilter.DENIED`", + event.event_id, + ) return None if event.event_id in always_include_ids: @@ -312,9 +324,17 @@ def _check_client_allowed_to_see_event( # for out-of-band membership events (eg, incoming invites, or rejections of # said invite) for the user themselves. if event.type == EventTypes.Member and event.state_key == user_id: - logger.debug("Returning out-of-band-membership event %s", event) + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Returning out-of-band-membership event %s", + event.event_id, + event, + ) return event + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because it's an outlier", + event.event_id, + ) return None if state is None: @@ -337,11 +357,21 @@ def _check_client_allowed_to_see_event( membership_result = _check_membership(user_id, event, visibility, state, is_peeking) if not membership_result.allowed: + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because the user can't see the event because of their membership, membership_result.allowed=%s membership_result.joined=%s", + event.event_id, + membership_result.allowed, + membership_result.joined, + ) return None # If the sender has been erased and the user was not joined at the time, we # must only return the redacted form. if sender_erased and not membership_result.joined: + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Returning pruned event because `sender_erased` and the user was not joined at the time", + event.event_id, + ) event = prune_event(event) return event -- cgit 1.5.1 From 828b5502cfdf4f1b20750941714ce95cdb242f0d Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 Oct 2022 10:33:21 +0100 Subject: Remove `_get_events_cache` check optimisation from `_have_seen_events_dict` (#14161) --- changelog.d/14161.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 31 +++++++++------------- tests/storage/databases/main/test_events_worker.py | 12 --------- 3 files changed, 14 insertions(+), 30 deletions(-) create mode 100644 changelog.d/14161.bugfix (limited to 'synapse') diff --git a/changelog.d/14161.bugfix b/changelog.d/14161.bugfix new file mode 100644 index 0000000000..aed4d9e386 --- /dev/null +++ b/changelog.d/14161.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.30.0 where purging and rejoining a room without restarting in-between would result in a broken room. \ No newline at end of file diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d4104462b5..cfd4780add 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1502,21 +1502,15 @@ class EventsWorkerStore(SQLBaseStore): Returns: a dict {event_id -> bool} """ - # if the event cache contains the event, obviously we've seen it. - - cache_results = { - event_id - for event_id in event_ids - if await self._get_event_cache.contains((event_id,)) - } - results = dict.fromkeys(cache_results, True) - remaining = [ - event_id for event_id in event_ids if event_id not in cache_results - ] - if not remaining: - return results + # TODO: We used to query the _get_event_cache here as a fast-path before + # hitting the database. For if an event were in the cache, we've presumably + # seen it before. + # + # But this is currently an invalid assumption due to the _get_event_cache + # not being invalidated when purging events from a room. The optimisation can + # be re-added after https://github.com/matrix-org/synapse/issues/13476 - def have_seen_events_txn(txn: LoggingTransaction) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1524,16 +1518,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", remaining + txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update({eid: (eid in found_events) for eid in remaining}) + return {eid: (eid in found_events) for eid in event_ids} - await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) - return results + return await self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn + ) @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 32a798d74b..5773172ab8 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -90,18 +90,6 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) - def test_query_via_event_cache(self): - # fetch an event into the event cache - self.get_success(self.store.get_event(self.event_ids[0])) - - # looking it up should now cause no db hits - with LoggingContext(name="test") as ctx: - res = self.get_success( - self.store.have_seen_events(self.room_id, [self.event_ids[0]]) - ) - self.assertEqual(res, {self.event_ids[0]}) - self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) - def test_persisting_event_invalidates_cache(self): """ Test to make sure that the `have_seen_event` cache -- cgit 1.5.1 From dc02d9f8c54576d4b41ce51a2704fdd43b582d66 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 Oct 2022 10:33:35 +0100 Subject: Avoid checking the event cache when backfilling events (#14164) --- changelog.d/14164.bugfix | 1 + synapse/handlers/federation_event.py | 47 ++++++++--- synapse/storage/databases/main/events_worker.py | 2 +- tests/handlers/test_federation.py | 105 +++++++++++++++++++++++- 4 files changed, 140 insertions(+), 15 deletions(-) create mode 100644 changelog.d/14164.bugfix (limited to 'synapse') diff --git a/changelog.d/14164.bugfix b/changelog.d/14164.bugfix new file mode 100644 index 0000000000..aed4d9e386 --- /dev/null +++ b/changelog.d/14164.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.30.0 where purging and rejoining a room without restarting in-between would result in a broken room. \ No newline at end of file diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 4300e8dd40..06e41b5cc0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -798,9 +798,42 @@ class FederationEventHandler: ], ) + # Check if we already any of these have these events. + # Note: we currently make a lookup in the database directly here rather than + # checking the event cache, due to: + # https://github.com/matrix-org/synapse/issues/13476 + existing_events_map = await self._store._get_events_from_db( + [event.event_id for event in events] + ) + + new_events = [] + for event in events: + event_id = event.event_id + + # If we've already seen this event ID... + if event_id in existing_events_map: + existing_event = existing_events_map[event_id] + + # ...and the event itself was not previously stored as an outlier... + if not existing_event.event.internal_metadata.is_outlier(): + # ...then there's no need to persist it. We have it already. + logger.info( + "_process_pulled_event: Ignoring received event %s which we " + "have already seen", + event.event_id, + ) + continue + + # While we have seen this event before, it was stored as an outlier. + # We'll now persist it as a non-outlier. + logger.info("De-outliering event %s", event_id) + + # Continue on with the events that are new to us. + new_events.append(event) + # We want to sort these by depth so we process them and # tell clients about them in order. - sorted_events = sorted(events, key=lambda x: x.depth) + sorted_events = sorted(new_events, key=lambda x: x.depth) for ev in sorted_events: with nested_logging_context(ev.event_id): await self._process_pulled_event(origin, ev, backfilled=backfilled) @@ -852,18 +885,6 @@ class FederationEventHandler: event_id = event.event_id - existing = await self._store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "_process_pulled_event: Ignoring received event %s which we have already seen", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - try: self._sanity_check_event(event) except SynapseError as err: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cfd4780add..7bc7f2f33e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -374,7 +374,7 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - The event, or None if the event was not found. + The event, or None if the event was not found and allow_none is `True`. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 745750b1d7..d00c69c229 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -19,7 +19,13 @@ from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + LimitExceededError, + NotFoundError, + SynapseError, +) from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.federation.federation_base import event_from_pdu_json @@ -28,6 +34,7 @@ from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer +from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.util import Clock from synapse.util.stringutils import random_string @@ -322,6 +329,102 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) self.get_success(d) + def test_backfill_ignores_known_events(self) -> None: + """ + Tests that events that we already know about are ignored when backfilling. + """ + # Set up users + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # Create a room to backfill events into + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + # Build an event to backfill + event = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"body": "hello world", "msgtype": "m.text"}, + "room_id": room_id, + "sender": other_user, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + # Ensure the event is not already in the DB + self.get_failure( + self.store.get_event(event.event_id), + NotFoundError, + ) + + # Backfill the event and check that it has entered the DB. + + # We mock out the FederationClient.backfill method, to pretend that a remote + # server has returned our fake event. + federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) + self.hs.get_federation_client().backfill = federation_client_backfill_mock + + # We also mock the persist method with a side effect of itself. This allows us + # to track when it has been called while preserving its function. + persist_events_and_notify_mock = Mock( + side_effect=self.hs.get_federation_event_handler().persist_events_and_notify + ) + self.hs.get_federation_event_handler().persist_events_and_notify = ( + persist_events_and_notify_mock + ) + + # Small side-tangent. We populate the event cache with the event, even though + # it is not yet in the DB. This is an invalid scenario that can currently occur + # due to not properly invalidating the event cache. + # See https://github.com/matrix-org/synapse/issues/13476. + # + # As a result, backfill should not rely on the event cache to check whether + # we already have an event in the DB. + # TODO: Remove this bit when the event cache is properly invalidated. + cache_entry = EventCacheEntry( + event=event, + redacted_event=None, + ) + self.store._get_event_cache.set_local((event.event_id,), cache_entry) + + # We now call FederationEventHandler.backfill (a separate method) to trigger + # a backfill request. It should receive the fake event. + self.get_success( + self.hs.get_federation_event_handler().backfill( + other_user, + room_id, + limit=10, + extremities=[], + ) + ) + + # Check that our fake event was persisted. + persist_events_and_notify_mock.assert_called_once() + persist_events_and_notify_mock.reset_mock() + + # Now we repeat the backfill, having the homeserver receive the fake event + # again. + self.get_success( + self.hs.get_federation_event_handler().backfill( + other_user, + room_id, + limit=10, + extremities=[], + ), + ) + + # This time, we expect no event persistence to have occurred, as we already + # have this event. + persist_events_and_notify_mock.assert_not_called() + @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) -- cgit 1.5.1 From c3a4780080a5bcb04132283c0f32f7452655792a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 18 Oct 2022 12:33:18 +0100 Subject: When restarting a partial join resync, prioritise the server which actioned a partial join (#14126) --- changelog.d/14126.misc | 1 + synapse/handlers/device.py | 5 +- synapse/handlers/federation.py | 57 +++++++++++++--------- synapse/storage/database.py | 2 +- synapse/storage/databases/main/room.py | 43 +++++++++++++--- .../delta/73/09partial_joined_via_destination.sql | 18 +++++++ 6 files changed, 95 insertions(+), 31 deletions(-) create mode 100644 changelog.d/14126.misc create mode 100644 synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql (limited to 'synapse') diff --git a/changelog.d/14126.misc b/changelog.d/14126.misc new file mode 100644 index 0000000000..30b3482fbd --- /dev/null +++ b/changelog.d/14126.misc @@ -0,0 +1 @@ +Faster joins: prioritise the server we joined by when restarting a partial join resync. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f9cc5bddbc..c597639a7f 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -937,7 +937,10 @@ class DeviceListUpdater: # Check if we are partially joining any rooms. If so we need to store # all device list updates so that we can handle them correctly once we # know who is in the room. - partial_rooms = await self.store.get_partial_state_rooms_and_servers() + # TODO(faster joins): this fetches and processes a bunch of data that we don't + # use. Could be replaced by a tighter query e.g. + # SELECT EXISTS(SELECT 1 FROM partial_state_rooms) + partial_rooms = await self.store.get_partial_state_room_resync_info() if partial_rooms: await self.store.add_remote_device_list_to_pending( user_id, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5f7e0a1f79..ccc045d36f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -632,6 +632,7 @@ class FederationHandler: room_id=room_id, servers=ret.servers_in_room, device_lists_stream_id=self.store.get_device_stream_token(), + joined_via=origin, ) try: @@ -1615,13 +1616,13 @@ class FederationHandler: """Resumes resyncing of all partial-state rooms after a restart.""" assert not self.config.worker.worker_app - partial_state_rooms = await self.store.get_partial_state_rooms_and_servers() - for room_id, servers_in_room in partial_state_rooms.items(): + partial_state_rooms = await self.store.get_partial_state_room_resync_info() + for room_id, resync_info in partial_state_rooms.items(): run_as_background_process( desc="sync_partial_state_room", func=self._sync_partial_state_room, - initial_destination=None, - other_destinations=servers_in_room, + initial_destination=resync_info.joined_via, + other_destinations=resync_info.servers_in_room, room_id=room_id, ) @@ -1650,28 +1651,12 @@ class FederationHandler: # really leave, that might mean we have difficulty getting the room state over # federation. # https://github.com/matrix-org/synapse/issues/12802 - # - # TODO(faster_joins): we need some way of prioritising which homeservers in - # `other_destinations` to try first, otherwise we'll spend ages trying dead - # homeservers for large rooms. - # https://github.com/matrix-org/synapse/issues/12999 - - if initial_destination is None and len(other_destinations) == 0: - raise ValueError( - f"Cannot resync state of {room_id}: no destinations provided" - ) # Make an infinite iterator of destinations to try. Once we find a working # destination, we'll stick with it until it flakes. - destinations: Collection[str] - if initial_destination is not None: - # Move `initial_destination` to the front of the list. - destinations = list(other_destinations) - if initial_destination in destinations: - destinations.remove(initial_destination) - destinations = [initial_destination] + destinations - else: - destinations = other_destinations + destinations = _prioritise_destinations_for_partial_state_resync( + initial_destination, other_destinations, room_id + ) destination_iter = itertools.cycle(destinations) # `destination` is the current remote homeserver we're pulling from. @@ -1769,3 +1754,29 @@ class FederationHandler: room_id, destination, ) + + +def _prioritise_destinations_for_partial_state_resync( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, +) -> Collection[str]: + """Work out the order in which we should ask servers to resync events. + + If an `initial_destination` is given, it takes top priority. Otherwise + all servers are treated equally. + + :raises ValueError: if no destination is provided at all. + """ + if initial_destination is None and len(other_destinations) == 0: + raise ValueError(f"Cannot resync state of {room_id}: no destinations provided") + + if initial_destination is None: + return other_destinations + + # Move `initial_destination` to the front of the list. + destinations = list(other_destinations) + if initial_destination in destinations: + destinations.remove(initial_destination) + destinations = [initial_destination] + destinations + return destinations diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 7bb21f8f81..4717c9728a 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1658,7 +1658,7 @@ class DatabasePool: table: string giving the table name keyvalues: dict of column names and values to select the row with retcol: string giving the name of the column to return - allow_none: If true, return None instead of failing if the SELECT + allow_none: If true, return None instead of raising StoreError if the SELECT statement returns no rows desc: description of the transaction, for logging and metrics """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index e41c99027a..7d97f8f60e 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -97,6 +97,12 @@ class RoomSortOrder(Enum): STATE_EVENTS = "state_events" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PartialStateResyncInfo: + joined_via: Optional[str] + servers_in_room: List[str] = attr.ib(factory=list) + + class RoomWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1160,17 +1166,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): desc="get_partial_state_servers_at_join", ) - async def get_partial_state_rooms_and_servers( + async def get_partial_state_room_resync_info( self, - ) -> Mapping[str, Collection[str]]: - """Get all rooms containing events with partial state, and the servers known - to be in the room. + ) -> Mapping[str, PartialStateResyncInfo]: + """Get all rooms containing events with partial state, and the information + needed to restart a "resync" of those rooms. Returns: A dictionary of rooms with partial state, with room IDs as keys and lists of servers in rooms as values. """ - room_servers: Dict[str, List[str]] = {} + room_servers: Dict[str, PartialStateResyncInfo] = {} + + rows = await self.db_pool.simple_select_list( + table="partial_state_rooms", + keyvalues={}, + retcols=("room_id", "joined_via"), + desc="get_server_which_served_partial_join", + ) + + for row in rows: + room_id = row["room_id"] + joined_via = row["joined_via"] + room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) rows = await self.db_pool.simple_select_list( "partial_state_rooms_servers", @@ -1182,7 +1200,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): for row in rows: room_id = row["room_id"] server_name = row["server_name"] - room_servers.setdefault(room_id, []).append(server_name) + entry = room_servers.get(room_id) + if entry is None: + # There is a foreign key constraint which enforces that every room_id in + # partial_state_rooms_servers appears in partial_state_rooms. So we + # expect `entry` to be non-null. (This reasoning fails if we've + # partial-joined between the two SELECTs, but this is unlikely to happen + # in practice.) + continue + entry.servers_in_room.append(server_name) return room_servers @@ -1827,6 +1853,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: """Mark the given room as containing events with partial state. @@ -1842,6 +1869,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): servers: other servers known to be in the room device_lists_stream_id: the device_lists stream ID at the time when we first joined the room. + joined_via: the server name we requested a partial join from. """ await self.db_pool.runInteraction( "store_partial_state_room", @@ -1849,6 +1877,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id, servers, device_lists_stream_id, + joined_via, ) def _store_partial_state_room_txn( @@ -1857,6 +1886,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: DatabasePool.simple_insert_txn( txn, @@ -1866,6 +1896,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "device_lists_stream_id": device_lists_stream_id, # To be updated later once the join event is persisted. "join_event_id": None, + "joined_via": joined_via, }, ) DatabasePool.simple_insert_many_txn( diff --git a/synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql b/synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql new file mode 100644 index 0000000000..066d602b18 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql @@ -0,0 +1,18 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- When we resync partial state, we prioritise doing so using the server we +-- partial-joined from. To do this we need to record which server that was! +ALTER TABLE partial_state_rooms ADD COLUMN joined_via TEXT; -- cgit 1.5.1 From 8e50299d8b112364b011ca8f89bc19a97e9622ec Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 18 Oct 2022 13:59:04 +0100 Subject: Fix `track_memory_usage` on poetry-core 1.3.x installations (#14221) * Fix `track_memory_usage` on poetry-core 1.3.x installations The same kind of problem as discussed in #14085: 1. we defined an extra with an underscore 2. we look it up at runtime with an underscore 3. but poetry-core 1.3.x. installs it with a dash, causing (2) to fail. Fix by using a dash everywhere. * Changelog --- changelog.d/14221.misc | 1 + pyproject.toml | 4 ++-- synapse/config/cache.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14221.misc (limited to 'synapse') diff --git a/changelog.d/14221.misc b/changelog.d/14221.misc new file mode 100644 index 0000000000..fe7afac245 --- /dev/null +++ b/changelog.d/14221.misc @@ -0,0 +1 @@ +Rename the `cache_memory` extra to `cache-memory`, for compatability with poetry-core 1.3.0 and [PEP 685](https://peps.python.org/pep-0685/). From-source installations using this extra will need to install using the new name. diff --git a/pyproject.toml b/pyproject.toml index 7fbbc08915..8bc24c556a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -227,7 +227,7 @@ jwt = ["authlib"] # (if it is not installed, we fall back to slow code.) redis = ["txredisapi", "hiredis"] # Required to use experimental `caches.track_memory_usage` config option. -cache_memory = ["pympler"] +cache-memory = ["pympler"] test = ["parameterized", "idna"] # The duplication here is awful. I hate hate hate hate hate it. However, for now I want @@ -258,7 +258,7 @@ all = [ "jaeger-client", "opentracing", # redis "txredisapi", "hiredis", - # cache_memory + # cache-memory "pympler", # omitted: # - test: it's useful to have this separate from dev deps in the olddeps job diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 2db8cfb005..eb4194a5a9 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -159,7 +159,7 @@ class CacheConfig(Config): self.track_memory_usage = cache_config.get("track_memory_usage", False) if self.track_memory_usage: - check_requirements("cache_memory") + check_requirements("cache-memory") expire_caches = cache_config.get("expire_caches", True) cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m") -- cgit 1.5.1 From dbf18f514ea5d2539ba3148049eae5a6793f1d60 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Oct 2022 10:55:41 -0400 Subject: Update the thread_id right before use (in case the bg update hasn't finished) (#14222) This avoids running a forced-update of a null thread_id rows. An index is added (in the background) to hopefully make this easier in the future. --- changelog.d/14222.feature | 1 + .../storage/databases/main/event_push_actions.py | 103 +++++++++++++++++++++ .../delta/73/06thread_notifications_backfill.sql | 29 ------ .../73/06thread_notifications_thread_id_idx.sql | 23 +++++ .../07thread_notifications_not_null.sql.postgres | 19 ---- .../73/07thread_notifications_not_null.sql.sqlite | 101 -------------------- 6 files changed, 127 insertions(+), 149 deletions(-) create mode 100644 changelog.d/14222.feature delete mode 100644 synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql create mode 100644 synapse/storage/schema/main/delta/73/06thread_notifications_thread_id_idx.sql delete mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres delete mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite (limited to 'synapse') diff --git a/changelog.d/14222.feature b/changelog.d/14222.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14222.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index f070e6e88a..b283ab0f9c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -294,6 +294,44 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self._background_backfill_thread_id, ) + # Indexes which will be used to quickly make the thread_id column non-null. + self.db_pool.updates.register_background_index_update( + "event_push_actions_thread_id_null", + index_name="event_push_actions_thread_id_null", + table="event_push_actions", + columns=["thread_id"], + where_clause="thread_id IS NULL", + ) + self.db_pool.updates.register_background_index_update( + "event_push_summary_thread_id_null", + index_name="event_push_summary_thread_id_null", + table="event_push_summary", + columns=["thread_id"], + where_clause="thread_id IS NULL", + ) + + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates the event_push_actions and event_push_summary tables. + self._clock.call_later(0.0, self._check_event_push_backfill_thread_id) + self._event_push_backfill_thread_id_done = False + + @wrap_as_background_process("check_event_push_backfill_thread_id") + async def _check_event_push_backfill_thread_id(self) -> None: + """ + Has thread_id finished backfilling? + + If not, we need to just-in-time update it so the queries work. + """ + done = await self.db_pool.updates.has_completed_background_update( + "event_push_backfill_thread_id" + ) + + if done: + self._event_push_backfill_thread_id_done = True + else: + # Reschedule to run. + self._clock.call_later(15.0, self._check_event_push_backfill_thread_id) + async def _background_backfill_thread_id( self, progress: JsonDict, batch_size: int ) -> int: @@ -526,6 +564,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) + # First ensure that the existing rows have an updated thread_id field. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + # First we pull the counts from the summary table. # # We check that `last_receipt_stream_ordering` matches the stream ordering of the @@ -1341,6 +1398,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (room_id, user_id, stream_ordering, *thread_args), ) + # First ensure that the existing rows have an updated thread_id field. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. unread_counts = self._get_notif_unread_count_for_user_room( @@ -1475,6 +1551,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas rotate_to_stream_ordering: The new maximum event stream ordering to summarise. """ + # Ensure that any new actions have an updated thread_id. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL + """, + (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + # XXX Do we need to update summaries here too? + # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, thread_id, @@ -1537,6 +1626,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) + # Ensure that any updated threads have the proper thread_id. + if not self._event_push_backfill_thread_id_done: + txn.execute_batch( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + [ + (MAIN_TIMELINE, room_id, user_id) + for user_id, room_id, _ in summaries + ], + ) + self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql deleted file mode 100644 index 0ffde9bbeb..0000000000 --- a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2022 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Forces the background updates from 06thread_notifications.sql to run in the --- foreground as code will now require those to be "done". - -DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id'; - --- Overwrite any null thread_id columns. -UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL; -UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL; -UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL; - --- Do not run the event_push_summary_unique_index job if it is pending; the --- thread_id field will be made required. -DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index'; -DROP INDEX IF EXISTS event_push_summary_unique_index; diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_thread_id_idx.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_thread_id_idx.sql new file mode 100644 index 0000000000..8b3c636594 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/06thread_notifications_thread_id_idx.sql @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Allow there to be multiple summaries per user/room. +DROP INDEX IF EXISTS event_push_summary_unique_index; + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7306, 'event_push_actions_thread_id_null', '{}', 'event_push_backfill_thread_id'); + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7306, 'event_push_summary_thread_id_null', '{}', 'event_push_backfill_thread_id'); diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres deleted file mode 100644 index 33674f8c62..0000000000 --- a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2022 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- The columns can now be made non-nullable. -ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL; -ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL; -ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite deleted file mode 100644 index 5322ad77a4..0000000000 --- a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2022 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- SQLite doesn't support modifying columns to an existing table, so it must --- be recreated. - --- Create the new tables. -CREATE TABLE event_push_actions_staging_new ( - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - actions TEXT NOT NULL, - notif SMALLINT NOT NULL, - highlight SMALLINT NOT NULL, - unread SMALLINT, - thread_id TEXT NOT NULL, - inserted_ts BIGINT -); - -CREATE TABLE event_push_actions_new ( - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - profile_tag VARCHAR(32), - actions TEXT NOT NULL, - topological_ordering BIGINT, - stream_ordering BIGINT, - notif SMALLINT, - highlight SMALLINT, - unread SMALLINT, - thread_id TEXT NOT NULL, - CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) -); - -CREATE TABLE event_push_summary_new ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - notif_count BIGINT NOT NULL, - stream_ordering BIGINT NOT NULL, - unread_count BIGINT, - last_receipt_stream_ordering BIGINT, - thread_id TEXT NOT NULL -); - --- Swap the indexes. -DROP INDEX IF EXISTS event_push_actions_staging_id; -CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging_new(event_id); - -DROP INDEX IF EXISTS event_push_actions_room_id_user_id; -DROP INDEX IF EXISTS event_push_actions_rm_tokens; -DROP INDEX IF EXISTS event_push_actions_stream_ordering; -DROP INDEX IF EXISTS event_push_actions_u_highlight; -DROP INDEX IF EXISTS event_push_actions_highlights_index; -CREATE INDEX event_push_actions_room_id_user_id on event_push_actions_new(room_id, user_id); -CREATE INDEX event_push_actions_rm_tokens on event_push_actions_new( user_id, room_id, topological_ordering, stream_ordering ); -CREATE INDEX event_push_actions_stream_ordering on event_push_actions_new( stream_ordering, user_id ); -CREATE INDEX event_push_actions_u_highlight ON event_push_actions_new (user_id, stream_ordering); -CREATE INDEX event_push_actions_highlights_index ON event_push_actions_new (user_id, room_id, topological_ordering, stream_ordering); - --- Copy the data. -INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts) - SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts - FROM event_push_actions_staging; - -INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id) - SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id - FROM event_push_actions; - -INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id) - SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id - FROM event_push_summary; - --- Drop the old tables. -DROP TABLE event_push_actions_staging; -DROP TABLE event_push_actions; -DROP TABLE event_push_summary; - --- Rename the tables. -ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging; -ALTER TABLE event_push_actions_new RENAME TO event_push_actions; -ALTER TABLE event_push_summary_new RENAME TO event_push_summary; - --- Re-run background updates from 72/02event_push_actions_index.sql and --- 72/06thread_notifications.sql. -INSERT INTO background_updates (ordering, update_name, progress_json) VALUES - (7307, 'event_push_summary_unique_index2', '{}') - ON CONFLICT (update_name) DO NOTHING; -INSERT INTO background_updates (ordering, update_name, progress_json) VALUES - (7307, 'event_push_actions_stream_highlight_index', '{}') - ON CONFLICT (update_name) DO NOTHING; -- cgit 1.5.1 From 4eaf3eb840b8cfa78d970216c74fc128495f08a5 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Tue, 18 Oct 2022 16:52:25 +0100 Subject: Implementation of HTTP 307 response for MSC3886 POST endpoint (#14018) Co-authored-by: reivilibre Co-authored-by: Andrew Morgan --- changelog.d/14018.feature | 1 + synapse/config/experimental.py | 7 +- synapse/config/server.py | 4 ++ synapse/handlers/sso.py | 2 +- synapse/http/server.py | 48 ++++++++++--- synapse/http/site.py | 3 + synapse/rest/__init__.py | 2 + synapse/rest/client/rendezvous.py | 74 +++++++++++++++++++ synapse/rest/client/versions.py | 3 + synapse/rest/key/v2/local_key_resource.py | 4 +- synapse/rest/synapse/client/new_user_consent.py | 3 +- synapse/rest/well_known.py | 3 +- tests/logging/test_terse_json.py | 1 + tests/rest/client/test_rendezvous.py | 45 ++++++++++++ tests/server.py | 8 ++- tests/test_server.py | 94 ++++++++++++++++++------- 16 files changed, 257 insertions(+), 45 deletions(-) create mode 100644 changelog.d/14018.feature create mode 100644 synapse/rest/client/rendezvous.py create mode 100644 tests/rest/client/test_rendezvous.py (limited to 'synapse') diff --git a/changelog.d/14018.feature b/changelog.d/14018.feature new file mode 100644 index 0000000000..c8454607eb --- /dev/null +++ b/changelog.d/14018.feature @@ -0,0 +1 @@ +Support for redirecting to an implementation of a [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) HTTP rendezvous service. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f9a49451d8..4009add01d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import attr @@ -120,3 +120,8 @@ class ExperimentalConfig(Config): # MSC3874: Filtering /messages with rel_types / not_rel_types. self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) + + # MSC3886: Simple client rendezvous capability + self.msc3886_endpoint: Optional[str] = experimental.get( + "msc3886_endpoint", None + ) diff --git a/synapse/config/server.py b/synapse/config/server.py index f2353ce5fb..ec46ca63ad 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -207,6 +207,9 @@ class HttpListenerConfig: additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None request_id_header: Optional[str] = None + # If true, the listener will return CORS response headers compatible with MSC3886: + # https://github.com/matrix-org/matrix-spec-proposals/pull/3886 + experimental_cors_msc3886: bool = False @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -935,6 +938,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), request_id_header=listener.get("request_id_header"), + experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False), ) return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e035677b8a..5943f08e91 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -874,7 +874,7 @@ class SsoHandler: ) async def handle_terms_accepted( - self, request: Request, session_id: str, terms_version: str + self, request: SynapseRequest, session_id: str, terms_version: str ) -> None: """Handle a request to the new-user 'consent' endpoint diff --git a/synapse/http/server.py b/synapse/http/server.py index bcbfac2c9f..b26e34bceb 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -19,6 +19,7 @@ import logging import types import urllib from http import HTTPStatus +from http.client import FOUND from inspect import isawaitable from typing import ( TYPE_CHECKING, @@ -339,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - _unrecognised_request_handler(request) + return _unrecognised_request_handler(request) @abc.abstractmethod def _send_response( @@ -598,7 +599,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: Request) -> bytes: + def render_OPTIONS(self, request: SynapseRequest) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -763,7 +764,7 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, + request: SynapseRequest, code: int, json_bytes: bytes, send_cors: bool = False, @@ -859,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request) -> None: +def set_cors_headers(request: SynapseRequest) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -870,10 +871,20 @@ def set_cors_headers(request: Request) -> None: request.setHeader( b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date", - ) + if request.experimental_cors_msc3886: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", + ) + request.setHeader( + b"Access-Control-Expose-Headers", + b"ETag, Location, X-Max-Bytes", + ) + else: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date", + ) def set_corp_headers(request: Request) -> None: @@ -942,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None: request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") -def respond_with_redirect(request: Request, url: bytes) -> None: - """Write a 302 response to the request, if it is still alive.""" +def respond_with_redirect( + request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False +) -> None: + """ + Write a 302 (or other specified status code) response to the request, if it is still alive. + + Args: + request: The http request to respond to. + url: The URL to redirect to. + statusCode: The HTTP status code to use for the redirect (defaults to 302). + cors: Whether to set CORS headers on the response. + """ logger.debug("Redirect to %s", url.decode("utf-8")) - request.redirect(url) + + if cors: + set_cors_headers(request) + + request.setResponseCode(statusCode) + request.setHeader(b"location", url) finish_request(request) diff --git a/synapse/http/site.py b/synapse/http/site.py index 55a6afce35..3dbd541fed 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -82,6 +82,7 @@ class SynapseRequest(Request): self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 + self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -622,6 +623,8 @@ class SynapseSite(Site): request_id_header = config.http_options.request_id_header + self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 9a2ab99ede..28542cd774 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client import ( receipts, register, relations, + rendezvous, report_event, room, room_batch, @@ -132,3 +133,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) login_token_request.register_servlets(hs, client_resource) + rendezvous.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py new file mode 100644 index 0000000000..89176b1ffa --- /dev/null +++ b/synapse/rest/client/rendezvous.py @@ -0,0 +1,74 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from http.client import TEMPORARY_REDIRECT +from typing import TYPE_CHECKING, Optional + +from synapse.http.server import HttpServer, respond_with_redirect +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RendezvousServlet(RestServlet): + """ + This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) + simple client rendezvous capability that is used by the "Sign in with QR" functionality. + + This implementation only serves as a 307 redirect to a configured server rather than being a full implementation. + + A module that implements the full functionality is available at: https://pypi.org/project/matrix-http-rendezvous-synapse/. + + Request: + + POST /rendezvous HTTP/1.1 + Content-Type: ... + + ... + + Response: + + HTTP/1.1 307 + Location: + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3886/rendezvous$", releases=[], v1=False, unstable=True + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + redirection_target: Optional[str] = hs.config.experimental.msc3886_endpoint + assert ( + redirection_target is not None + ), "Servlet is only registered if there is a redirection target" + self.endpoint = redirection_target.encode("utf-8") + + async def on_POST(self, request: SynapseRequest) -> None: + respond_with_redirect( + request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True + ) + + # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target. + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3886_endpoint is not None: + RendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4b87ee978a..9b1b72c68a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -116,6 +116,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3881": self.config.experimental.msc3881_enabled, # Adds support for filtering /messages by event relation. "org.matrix.msc3874": self.config.experimental.msc3874_enabled, + # Adds support for simple HTTP rendezvous as per MSC3886 + "org.matrix.msc3886": self.config.experimental.msc3886_endpoint + is not None, }, }, ) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 0c9f042c84..095993415c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -20,9 +20,9 @@ from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 from twisted.web.resource import Resource -from twisted.web.server import Request from synapse.http.server import respond_with_json_bytes +from synapse.http.site import SynapseRequest from synapse.types import JsonDict if TYPE_CHECKING: @@ -99,7 +99,7 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> Optional[int]: + def render_GET(self, request: SynapseRequest) -> Optional[int]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 1c1c7b3613..22784157e6 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest from synapse.types import UserID from synapse.util.templates import build_jinja_env @@ -88,7 +89,7 @@ class NewUserConsentResource(DirectServeHtmlResource): html = template.render(template_params) respond_with_html(request, 200, html) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 6f7ac54c65..e2174fdfea 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -18,6 +18,7 @@ from twisted.web.resource import Resource from twisted.web.server import Request from synapse.http.server import set_cors_headers +from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.stringutils import parse_server_name @@ -63,7 +64,7 @@ class ClientWellKnownResource(Resource): Resource.__init__(self) self._well_known_builder = WellKnownBuilder(hs) - def render_GET(self, request: Request) -> bytes: + def render_GET(self, request: SynapseRequest) -> bytes: set_cors_headers(request) r = self._well_known_builder.get_well_known() if not r: diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 96f399b7ab..0b0d8737c1 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -153,6 +153,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.site_tag = "test-site" site.server_version_string = "Server v1" site.reactor = Mock() + site.experimental_cors_msc3886 = False request = SynapseRequest(FakeChannel(site, None), site) # Call requestReceived to finish instantiating the object. request.content = BytesIO() diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py new file mode 100644 index 0000000000..ad00a476e1 --- /dev/null +++ b/tests/rest/client/test_rendezvous.py @@ -0,0 +1,45 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.client import rendezvous +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + +endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" + + +class RendezvousServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + rendezvous.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + return self.hs + + def test_disabled(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) + def test_redirect(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 307) + self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"]) diff --git a/tests/server.py b/tests/server.py index c447d5e4c4..8b1d186219 100644 --- a/tests/server.py +++ b/tests/server.py @@ -266,7 +266,12 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") - def __init__(self, resource: IResource, reactor: IReactorTime): + def __init__( + self, + resource: IResource, + reactor: IReactorTime, + experimental_cors_msc3886: bool = False, + ): """ Args: @@ -274,6 +279,7 @@ class FakeSite: """ self._resource = resource self.reactor = reactor + self.experimental_cors_msc3886 = experimental_cors_msc3886 def getResourceFor(self, request): return self._resource diff --git a/tests/test_server.py b/tests/test_server.py index 7c66448245..2d9a0257d4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -222,13 +222,22 @@ class OptionsResourceTests(unittest.TestCase): self.resource = OptionsResource() self.resource.putChild(b"res", DummyResource()) - def _make_request(self, method: bytes, path: bytes) -> FakeChannel: + def _make_request( + self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False + ) -> FakeChannel: """Create a request from the method/path and return a channel with the response.""" # Create a site and query for the resource. site = SynapseSite( "test", "site_tag", - parse_listener_def(0, {"type": "http", "port": 0}), + parse_listener_def( + 0, + { + "type": "http", + "port": 0, + "experimental_cors_msc3886": experimental_cors_msc3886, + }, + ), self.resource, "1.0", max_request_body_size=4096, @@ -239,25 +248,58 @@ class OptionsResourceTests(unittest.TestCase): channel = make_request(self.reactor, site, method, path, shorthand=False) return channel + def _check_cors_standard_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://spec.matrix.org/v1.4/client-server-api/#web-browser-clients + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [b"X-Requested-With, Content-Type, Authorization, Date"], + "has correct CORS Headers header", + ) + + def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://github.com/matrix-org/matrix-spec-proposals/blob/hughns/simple-rendezvous-capability/proposals/3886-simple-rendezvous-capability.md#cors + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [ + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match" + ], + "has correct CORS Headers header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"), + [b"ETag, Location, X-Max-Bytes"], + "has correct CORS Expose Headers header", + ) + def test_unknown_options_request(self) -> None: """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", - ) + self._check_cors_standard_headers(channel) def test_known_options_request(self) -> None: """An OPTIONS requests to an known URL still returns 204 No Content.""" @@ -265,19 +307,17 @@ class OptionsResourceTests(unittest.TestCase): self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", + self._check_cors_standard_headers(channel) + + def test_known_options_request_msc3886(self) -> None: + """An OPTIONS requests to an known URL still returns 204 No Content.""" + channel = self._make_request( + b"OPTIONS", b"/res/", experimental_cors_msc3886=True ) + self.assertEqual(channel.code, 204) + self.assertNotIn("body", channel.result) + + self._check_cors_msc3886_headers(channel) def test_unknown_request(self) -> None: """A non-OPTIONS request to an unknown URL should 404.""" -- cgit 1.5.1 From 847e2393f3198b88809c9b99de5c681efbf1c92e Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 18 Oct 2022 09:58:47 -0700 Subject: Prepatory work for adding power level event to batched events (#14214) --- changelog.d/14214.misc | 1 + synapse/event_auth.py | 19 ++++++++++++++++++- synapse/handlers/event_auth.py | 18 +++++++++++++----- synapse/handlers/federation.py | 12 +++++------- synapse/handlers/message.py | 10 +++++++++- synapse/handlers/room.py | 4 +--- 6 files changed, 47 insertions(+), 17 deletions(-) create mode 100644 changelog.d/14214.misc (limited to 'synapse') diff --git a/changelog.d/14214.misc b/changelog.d/14214.misc new file mode 100644 index 0000000000..102928b575 --- /dev/null +++ b/changelog.d/14214.misc @@ -0,0 +1 @@ +When authenticating batched events, check for auth events in batch as well as DB. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c7d5ef92fc..bab31e33c5 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -15,7 +15,18 @@ import logging import typing -from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -134,6 +145,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: async def check_state_independent_auth_rules( store: _EventSourceStore, event: "EventBase", + batched_auth_events: Optional[Mapping[str, "EventBase"]] = None, ) -> None: """Check that an event complies with auth rules that are independent of room state @@ -143,6 +155,8 @@ async def check_state_independent_auth_rules( Args: store: the datastore; used to fetch the auth events for validation event: the event being checked. + batched_auth_events: if the event being authed is part of a batch, any events + from the same batch that may be necessary to auth the current event Raises: AuthError if the checks fail @@ -162,6 +176,9 @@ async def check_state_independent_auth_rules( redact_behaviour=EventRedactBehaviour.as_is, allow_rejected=True, ) + if batched_auth_events: + auth_events.update(batched_auth_events) + room_id = event.room_id auth_dict: MutableStateMap[str] = {} expected_auth_types = auth_types_for_event(event.room_version, event) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 8249ca1ed2..3bbad0271b 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Optional, Union +from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,6 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.events.snapshot import EventContext from synapse.types import StateMap, get_domain_from_id if TYPE_CHECKING: @@ -51,12 +50,21 @@ class EventAuthHandler: async def check_auth_rules_from_context( self, event: EventBase, - context: EventContext, + batched_auth_events: Optional[Mapping[str, EventBase]] = None, ) -> None: - """Check an event passes the auth rules at its own auth events""" - await check_state_independent_auth_rules(self._store, event) + """Check an event passes the auth rules at its own auth events + Args: + event: event to be authed + batched_auth_events: if the event being authed is part of a batch, any events + from the same batch that may be necessary to auth the current event + """ + await check_state_independent_auth_rules( + self._store, event, batched_auth_events + ) auth_event_ids = event.auth_event_ids() auth_events_by_id = await self._store.get_events(auth_event_ids) + if batched_auth_events: + auth_events_by_id.update(batched_auth_events) check_state_dependent_auth_rules(event, auth_events_by_id.values()) def compute_auth_events( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ccc045d36f..275a37a575 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -942,7 +942,7 @@ class FederationHandler: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) return event async def on_invite_request( @@ -1123,7 +1123,7 @@ class FederationHandler: try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Failed to create new leave %r because %s", event, e) raise e @@ -1182,7 +1182,7 @@ class FederationHandler: try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Failed to create new knock %r because %s", event, e) raise e @@ -1348,9 +1348,7 @@ class FederationHandler: try: validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context( - event, context - ) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e @@ -1400,7 +1398,7 @@ class FederationHandler: try: validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) raise e diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4e55ebba0b..15b828dd74 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1360,8 +1360,16 @@ class EventCreationHandler: else: try: validate_event_for_room_version(event) + # If we are persisting a batch of events the event(s) needed to auth the + # current event may be part of the batch and will not be in the DB yet + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + batched_auth_events = {} + for event_id in event.auth_event_ids(): + auth_event = event_id_to_event.get(event_id) + if auth_event: + batched_auth_events[event_id] = auth_event await self._event_auth_handler.check_auth_rules_from_context( - event, context + event, batched_auth_events ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 4e1aacb408..638f54051a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -229,9 +229,7 @@ class RoomCreationHandler: }, ) validate_event_for_room_version(tombstone_event) - await self._event_auth_handler.check_auth_rules_from_context( - tombstone_event, tombstone_context - ) + await self._event_auth_handler.check_auth_rules_from_context(tombstone_event) # Upgrade the room # -- cgit 1.5.1 From 1c777ef1e87d7be39a2b8f6fb119fa4b51e2be4c Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 18 Oct 2022 13:40:50 -0700 Subject: Fix docstring in EventContext (#14145) --- changelog.d/14145.doc | 2 ++ synapse/events/snapshot.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14145.doc (limited to 'synapse') diff --git a/changelog.d/14145.doc b/changelog.d/14145.doc new file mode 100644 index 0000000000..8f876e08fc --- /dev/null +++ b/changelog.d/14145.doc @@ -0,0 +1,2 @@ +Clarify comment on event contexts. + diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index d3c8083e4a..1c0e96bec7 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -65,7 +65,8 @@ class EventContext: None does not necessarily mean that ``state_group`` does not have a prev_group! - If the event is a state event, this is normally the same as ``prev_group``. + If the event is a state event, this is normally the same as + ``state_group_before_event``. If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` will always also be ``None``. -- cgit 1.5.1 From 2a76a7369fc54477185f53f6e81897fa84e24de5 Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Tue, 18 Oct 2022 14:54:27 -0600 Subject: Fix hiding devices names over federation (#10015) And don't include blank opentracing stuff in device list updates. Signed-off-by: Aaron Raimist --- changelog.d/10015.bugfix | 1 + synapse/storage/databases/main/devices.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10015.bugfix (limited to 'synapse') diff --git a/changelog.d/10015.bugfix b/changelog.d/10015.bugfix new file mode 100644 index 0000000000..cbebd97e58 --- /dev/null +++ b/changelog.d/10015.bugfix @@ -0,0 +1 @@ +Prevent device names from appearing in device list updates when `allow_device_name_lookup_over_federation` is `false`. \ No newline at end of file diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 18358eca46..830b076a32 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -539,9 +539,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "device_id": device_id, "prev_id": [prev_id] if prev_id else [], "stream_id": stream_id, - "org.matrix.opentracing_context": opentracing_context, } + if opentracing_context != "{}": + result["org.matrix.opentracing_context"] = opentracing_context + prev_id = stream_id if device is not None: @@ -549,7 +551,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if keys: result["keys"] = keys - device_display_name = device.display_name + device_display_name = None + if ( + self.hs.config.federation.allow_device_name_lookup_over_federation + ): + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name else: -- cgit 1.5.1 From fa8616e65c82367712a7b75c62682a89541b6330 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 18 Oct 2022 19:46:25 -0500 Subject: Fix MSC3030 `/timestamp_to_event` returning `outliers` that it has no idea whether are near a gap or not (#14215) Fix MSC3030 `/timestamp_to_event` endpoint returning `outliers` that it has no idea whether are near a gap or not (and therefore unable to determine whether it's actually the closest event). The reason Synapse doesn't know whether an `outlier` is next to a gap is because our gap checks rely on entries in the `event_edges`, `event_forward_extremeties`, and `event_backward_extremities` tables which is [not the case for `outliers`](https://github.com/matrix-org/synapse/blob/2c63cdcc3f1aa4625e947de3c23e0a8133c61286/docs/development/room-dag-concepts.md#outliers). Also fixes MSC3030 Complement `can_paginate_after_getting_remote_event_from_timestamp_to_event_endpoint` test flake. Although this acted flakey in Complement, if `sync_partial_state` raced and beat us before `/timestamp_to_event`, then even if we retried the failing `/context` request it wouldn't work until we made this Synapse change. With this PR, Synapse will never return an `outlier` event so that test will always go and ask over federation. Fix https://github.com/matrix-org/synapse/issues/13944 ### Why did this fail before? Why was it flakey? Sleuthing the server logs on the [CI failure](https://github.com/matrix-org/synapse/actions/runs/3149623842/jobs/5121449357#step:5:5805), it looks like `hs2:/timestamp_to_event` found `$NP6-oU7mIFVyhtKfGvfrEQX949hQX-T-gvuauG6eurU` as an `outlier` event locally. Then when we went and asked for it via `/context`, since it's an `outlier`, it was filtered out of the results -> `You don't have permission to access that event.` This is reproducible when `sync_partial_state` races and persists `$NP6-oU7mIFVyhtKfGvfrEQX949hQX-T-gvuauG6eurU` as an `outlier` before we evaluate `get_event_for_timestamp(...)`. To consistently reproduce locally, just add a delay at the [start of `get_event_for_timestamp(...)`](https://github.com/matrix-org/synapse/blob/cb20b885cb4bd1648581dd043a184d86fc8c7a00/synapse/handlers/room.py#L1470-L1496) so it always runs after `sync_partial_state` completes. ```py from twisted.internet import task as twisted_task d = twisted_task.deferLater(self.hs.get_reactor(), 3.5) await d ``` In a run where it passes, on `hs2`, `get_event_for_timestamp(...)` finds a different event locally which is next to a gap and we request from a closer one from `hs1` which gets backfilled. And since the backfilled event is not an `outlier`, it's returned as expected during `/context`. With this PR, Synapse will never return an `outlier` event so that test will always go and ask over federation. --- changelog.d/14215.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 59 ++++++++++++++-------- tests/rest/client/test_rooms.py | 65 +++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 21 deletions(-) create mode 100644 changelog.d/14215.bugfix (limited to 'synapse') diff --git a/changelog.d/14215.bugfix b/changelog.d/14215.bugfix new file mode 100644 index 0000000000..31c109f534 --- /dev/null +++ b/changelog.d/14215.bugfix @@ -0,0 +1 @@ +Fix [MSC3030](https://github.com/matrix-org/matrix-spec-proposals/pull/3030) `/timestamp_to_event` endpoint returning potentially inaccurate closest events with `outliers` present. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7bc7f2f33e..69fea452ad 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1971,12 +1971,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_backward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_backward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question has any of its prev_events listed as a # backward extremity, it's next to a gap. @@ -2026,12 +2031,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_forward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question is a forward extremity, we will just # consider any potential forward gap as not a gap since it's one of @@ -2112,13 +2122,33 @@ class EventsWorkerStore(SQLBaseStore): The closest event_id otherwise None if we can't find any event in the given direction. """ + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" - sql_template = """ + sql_template = f""" SELECT event_id FROM events LEFT JOIN rejections USING (event_id) WHERE - origin_server_ts %s ? - AND room_id = ? + room_id = ? + AND origin_server_ts {comparison_operator} ? + /** + * Make sure the event isn't an `outlier` because we have no way + * to later check whether it's next to a gap. `outliers` do not + * have entries in the `event_edges`, `event_forward_extremeties`, + * and `event_backward_extremities` tables to check against + * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`). + */ + AND NOT outlier /* Make sure event is not rejected */ AND rejections.event_id IS NULL /** @@ -2128,27 +2158,14 @@ class EventsWorkerStore(SQLBaseStore): * Finally, we can tie-break based on when it was received on the server * (`stream_ordering`). */ - ORDER BY origin_server_ts %s, depth %s, stream_ordering %s + ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order} LIMIT 1; """ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - txn.execute( - sql_template % (comparison_operator, order, order, order), - (timestamp, room_id), + sql_template, + (room_id, timestamp), ) row = txn.fetchone() if row: diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 71b1637be8..716366eb90 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -39,6 +39,8 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, register, room, sync @@ -51,6 +53,7 @@ from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import create_event PATH_PREFIX = b"/_matrix/client/api/v1" @@ -3486,3 +3489,65 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM") + + +class TimestampLookupTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc3030_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._storage_controllers = self.hs.get_storage_controllers() + + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + def _inject_outlier(self, room_id: str) -> EventBase: + event, _context = self.get_success( + create_event( + self.hs, + room_id=room_id, + type="m.test", + sender="@test_remote_user:remote", + ) + ) + + event.internal_metadata.outlier = True + self.get_success( + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) + ) + ) + return event + + def test_no_outliers(self) -> None: + """ + Test to make sure `/timestamp_to_event` does not return `outlier` events. + We're unable to determine whether an `outlier` is next to a gap so we + don't know whether it's actually the closest event. Instead, let's just + ignore `outliers` with this endpoint. + + This test is really seeing that we choose the non-`outlier` event behind the + `outlier`. Since the gap checking logic considers the latest message in the room + as *not* next to a gap, asking over federation does not come into play here. + """ + room_id = self.helper.create_room_as(self.room_owner, tok=self.room_owner_tok) + + outlier_event = self._inject_outlier(room_id) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", + access_token=self.room_owner_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Make sure the outlier event is not returned + self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id) -- cgit 1.5.1 From fe50738e597817735aa910e3cd1e13e4792f7d9f Mon Sep 17 00:00:00 2001 From: Finn Date: Wed, 19 Oct 2022 11:08:40 -0700 Subject: let update_synapse_database run on a multi-database configurations (#13422) * Allow sharded database in db migrate script Signed-off-by: Finn Herzfeld * Update changelog.d/13422.bugfix Co-authored-by: Patrick Cloke * Remove check entirely * remove unused import Signed-off-by: Finn Herzfeld Co-authored-by: finn Co-authored-by: Patrick Cloke --- changelog.d/13422.bugfix | 1 + synapse/_scripts/update_synapse_database.py | 8 -------- 2 files changed, 1 insertion(+), 8 deletions(-) create mode 100644 changelog.d/13422.bugfix mode change 100755 => 100644 synapse/_scripts/update_synapse_database.py (limited to 'synapse') diff --git a/changelog.d/13422.bugfix b/changelog.d/13422.bugfix new file mode 100644 index 0000000000..3a099acbe6 --- /dev/null +++ b/changelog.d/13422.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the `update_synapse_database` script could not be run with multiple databases. Contributed by @thefinn93 @ Beeper. \ No newline at end of file diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py old mode 100755 new mode 100644 index fb1fb83f50..0adf94bba6 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -15,7 +15,6 @@ import argparse import logging -import sys from typing import cast import yaml @@ -100,13 +99,6 @@ def main() -> None: # Load, process and sanity-check the config. hs_config = yaml.safe_load(args.database_config) - if "database" not in hs_config and "databases" not in hs_config: - sys.stderr.write( - "The configuration file must have a 'database' or 'databases' section. " - "See https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#database" - ) - sys.exit(4) - config = HomeServerConfig() config.parse_config_dict(hs_config, "", "") -- cgit 1.5.1 From 0b7830e457359ce651b293c8748bf636973404a9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 19 Oct 2022 19:38:24 +0000 Subject: Bump flake8-bugbear from 21.3.2 to 22.9.23 (#14042) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Erik Johnston Co-authored-by: David Robertson --- .flake8 | 9 ++++++++- changelog.d/14042.misc | 1 + poetry.lock | 8 ++++---- synapse/storage/databases/main/roommember.py | 4 ++-- synapse/util/caches/deferred_cache.py | 4 ++-- synapse/util/caches/descriptors.py | 2 +- tests/federation/transport/test_client.py | 7 +++---- tests/util/caches/test_descriptors.py | 2 +- 8 files changed, 22 insertions(+), 15 deletions(-) create mode 100644 changelog.d/14042.misc (limited to 'synapse') diff --git a/.flake8 b/.flake8 index acb118c86e..4c6a4d5843 100644 --- a/.flake8 +++ b/.flake8 @@ -8,4 +8,11 @@ # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 +# +# flake8-bugbear runs extra checks. Its error codes are described at +# https://github.com/PyCQA/flake8-bugbear#list-of-warnings +# B019: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks +# B023: Functions defined inside a loop must not use variables redefined in the loop +# B024: Abstract base class with no abstract method. + +ignore=W503,W504,E203,E731,E501,B019,B023,B024 diff --git a/changelog.d/14042.misc b/changelog.d/14042.misc new file mode 100644 index 0000000000..868d55e76a --- /dev/null +++ b/changelog.d/14042.misc @@ -0,0 +1 @@ +Bump flake8-bugbear from 21.3.2 to 22.9.23. diff --git a/poetry.lock b/poetry.lock index ed0b59fbe5..0a2f9ab69e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -260,7 +260,7 @@ pyflakes = ">=2.4.0,<2.5.0" [[package]] name = "flake8-bugbear" -version = "21.3.2" +version = "22.9.23" description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." category = "dev" optional = false @@ -271,7 +271,7 @@ attrs = ">=19.2.0" flake8 = ">=3.0.0" [package.extras] -dev = ["black", "coverage", "hypothesis", "hypothesmith"] +dev = ["coverage", "hypothesis", "hypothesmith (>=0.2)", "pre-commit"] [[package]] name = "flake8-comprehensions" @@ -1826,8 +1826,8 @@ flake8 = [ {file = "flake8-4.0.1.tar.gz", hash = "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d"}, ] flake8-bugbear = [ - {file = "flake8-bugbear-21.3.2.tar.gz", hash = "sha256:cadce434ceef96463b45a7c3000f23527c04ea4b531d16c7ac8886051f516ca0"}, - {file = "flake8_bugbear-21.3.2-py36.py37.py38-none-any.whl", hash = "sha256:5d6ccb0c0676c738a6e066b4d50589c408dcc1c5bf1d73b464b18b73cd6c05c2"}, + {file = "flake8-bugbear-22.9.23.tar.gz", hash = "sha256:17b9623325e6e0dcdcc80ed9e4aa811287fcc81d7e03313b8736ea5733759937"}, + {file = "flake8_bugbear-22.9.23-py3-none-any.whl", hash = "sha256:cd2779b2b7ada212d7a322814a1e5651f1868ab0d3f24cc9da66169ab8fda474"}, ] flake8-comprehensions = [ {file = "flake8-comprehensions-3.8.0.tar.gz", hash = "sha256:8e108707637b1d13734f38e03435984f6b7854fa6b5a4e34f93e69534be8e521"}, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2ed6ad754f..32e1e983a5 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -707,8 +707,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): # 250 users is pretty arbitrary but the data can be quite large if users # are in many rooms. - for user_ids in batch_iter(user_ids, 250): - all_user_rooms.update(await self._get_rooms_for_users(user_ids)) + for batch_user_ids in batch_iter(user_ids, 250): + all_user_rooms.update(await self._get_rooms_for_users(batch_user_ids)) return all_user_rooms diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 6425f851ea..bcb1cba362 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -395,8 +395,8 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache.pop should either return a CacheEntry, or, in the # case of a TreeCache, a dict of keys to cache entries. Either way calling # iterate_tree_cache_entry on it will do the right thing. - for entry in iterate_tree_cache_entry(entry): - for cb in entry.get_invalidation_callbacks(key): + for iter_entry in iterate_tree_cache_entry(entry): + for cb in iter_entry.get_invalidation_callbacks(key): cb() def invalidate_all(self) -> None: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 0391966462..b3c748ef44 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -432,7 +432,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): num_args = cached_method.num_args if num_args != self.num_args: - raise Exception( + raise TypeError( "Number of args (%s) does not match underlying cache_method_name=%s (%s)." % (self.num_args, self.cached_method_name, num_args) ) diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index 0926e0583d..dd4d1b56de 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -17,6 +17,7 @@ from unittest.mock import Mock from synapse.api.room_versions import RoomVersions from synapse.federation.transport.client import SendJoinParser +from synapse.util import ExceptionBundle from tests.unittest import TestCase @@ -121,10 +122,8 @@ class SendJoinParserTestCase(TestCase): # Send half of the data to the parser parser.write(serialisation[: len(serialisation) // 2]) - # Close the parser. There should be _some_ kind of exception, but it need not - # be that RuntimeError directly. E.g. we might want to raise a wrapper - # encompassing multiple errors from multiple coroutines. - with self.assertRaises(Exception): + # Close the parser. There should be _some_ kind of exception. + with self.assertRaises(ExceptionBundle): parser.finish() # In any case, we should have tried to close both coros. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 90861fe522..78fd7b6961 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -1037,5 +1037,5 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj = Cls() # Make sure this raises an error about the arg mismatch - with self.assertRaises(Exception): + with self.assertRaises(TypeError): obj.list_fn([("foo", "bar")]) -- cgit 1.5.1 From 70b33965065f0e93eaba68e371896149c9405f51 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 19 Oct 2022 15:39:43 -0500 Subject: Explain `SynapseError` and `FederationError` better (#14191) Explain `SynapseError` and `FederationError` better Spawning from https://github.com/matrix-org/synapse/pull/13816#discussion_r993262622 --- changelog.d/14191.doc | 1 + synapse/api/errors.py | 24 +++++++++++++++++++++--- synapse/federation/federation_server.py | 8 ++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14191.doc (limited to 'synapse') diff --git a/changelog.d/14191.doc b/changelog.d/14191.doc new file mode 100644 index 0000000000..6b0eeb1ae1 --- /dev/null +++ b/changelog.d/14191.doc @@ -0,0 +1 @@ +Update docstrings of `SynapseError` and `FederationError` to bettter describe what they are used for and the effects of using them are. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e0873b1913..400dd12aba 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -155,7 +155,13 @@ class RedirectException(CodeMessageException): class SynapseError(CodeMessageException): """A base exception type for matrix errors which have an errcode and error - message (as well as an HTTP status code). + message (as well as an HTTP status code). These often bubble all the way up to the + client API response so the error code and status often reach the client directly as + defined here. If the error doesn't make sense to present to a client, then it + probably shouldn't be a `SynapseError`. For example, if we contact another + homeserver over federation, we shouldn't automatically ferry response errors back to + the client on our end (a 500 from a remote server does not make sense to a client + when our server did not experience a 500). Attributes: errcode: Matrix error code e.g 'M_FORBIDDEN' @@ -600,8 +606,20 @@ def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict": class FederationError(RuntimeError): - """This class is used to inform remote homeservers about erroneous - PDUs they sent us. + """ + Raised when we process an erroneous PDU. + + There are two kinds of scenarios where this exception can be raised: + + 1. We may pull an invalid PDU from a remote homeserver (e.g. during backfill). We + raise this exception to signal an error to the rest of the application. + 2. We may be pushed an invalid PDU as part of a `/send` transaction from a remote + homeserver. We raise so that we can respond to the transaction and include the + error string in the "PDU Processing Result". The message which will likely be + ignored by the remote homeserver and is not machine parse-able since it's just a + string. + + TODO: In the future, we should split these usage scenarios into their own error types. FATAL: The remote server could not interpret the source event. (e.g., it was missing a required field) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 28097664b4..59e351595b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -481,6 +481,14 @@ class FederationServer(FederationBase): pdu_results[pdu.event_id] = await process_pdu(pdu) async def process_pdu(pdu: EventBase) -> JsonDict: + """ + Processes a pushed PDU sent to us via a `/send` transaction + + Returns: + JsonDict representing a "PDU Processing Result" that will be bundled up + with the other processed PDU's in the `/send` transaction and sent back + to remote homeserver. + """ event_id = pdu.event_id with nested_logging_context(event_id): try: -- cgit 1.5.1 From da2c93d4b69200c1ea9fb94ec3c951fd4b424864 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 20 Oct 2022 15:17:45 +0100 Subject: Stop returning `unsigned.invite_room_state` in `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}` responses (#14064) Co-authored-by: David Robertson --- changelog.d/14064.bugfix | 1 + synapse/federation/transport/server/federation.py | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 changelog.d/14064.bugfix (limited to 'synapse') diff --git a/changelog.d/14064.bugfix b/changelog.d/14064.bugfix new file mode 100644 index 0000000000..cce6ef3b71 --- /dev/null +++ b/changelog.d/14064.bugfix @@ -0,0 +1 @@ + Fix a long-standing bug where Synapse would accidentally include extra information in the response to [`PUT /_matrix/federation/v2/invite/{roomId}/{eventId}`](https://spec.matrix.org/v1.4/server-server-api/#put_matrixfederationv2inviteroomideventid). \ No newline at end of file diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 6f11138b57..205fd16daa 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -499,6 +499,11 @@ class FederationV2InviteServlet(BaseFederationServerServlet): result = await self.handler.on_invite_request( origin, event, room_version_id=room_version ) + + # We only store invite_room_state for internal use, so remove it before + # returning the event to the remote homeserver. + result["event"].get("unsigned", {}).pop("invite_room_state", None) + return 200, result -- cgit 1.5.1 From 755bfeee3a1ac7077045ab9e5a994b6ca89afba3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Oct 2022 11:32:47 -0400 Subject: Use servlets for /key/ endpoints. (#14229) To fix the response for unknown endpoints under that prefix. See MSC3743. --- changelog.d/14229.misc | 1 + synapse/api/urls.py | 2 +- synapse/app/generic_worker.py | 20 +++----- synapse/app/homeserver.py | 26 ++++------ synapse/rest/key/v2/__init__.py | 19 ++++--- synapse/rest/key/v2/local_key_resource.py | 22 ++++---- synapse/rest/key/v2/remote_key_resource.py | 73 +++++++++++++++------------ tests/app/test_openid_listener.py | 2 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- 9 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 changelog.d/14229.misc (limited to 'synapse') diff --git a/changelog.d/14229.misc b/changelog.d/14229.misc new file mode 100644 index 0000000000..b9cd9a34d5 --- /dev/null +++ b/changelog.d/14229.misc @@ -0,0 +1 @@ +Refactor `/key/` endpoints to use `RestServlet` classes. diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bd49fa6a5f..a918579f50 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -28,7 +28,7 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" -SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" +SERVER_KEY_PREFIX = "/_matrix/key" MEDIA_R0_PREFIX = "/_matrix/media/r0" MEDIA_V3_PREFIX = "/_matrix/media/v3" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index dc49840f73..2a9f039367 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -28,7 +28,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, ) from synapse.app import _base from synapse.app._base import ( @@ -89,7 +89,7 @@ from synapse.rest.client.register import ( RegistrationTokenValidityRestServlet, ) from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -325,13 +325,13 @@ class GenericWorkerServer(HomeServer): presence.register_servlets(self, resource) - resources.update({CLIENT_API_PREFIX: resource}) + resources[CLIENT_API_PREFIX] = resource resources.update(build_synapse_client_resource_tree(self)) - resources.update({"/.well-known": well_known_resource(self)}) + resources["/.well-known"] = well_known_resource(self) elif name == "federation": - resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) + resources[FEDERATION_PREFIX] = TransportLayerServer(self) elif name == "media": if self.config.media.can_load_media_repo: media_repo = self.get_media_repository_resource() @@ -359,16 +359,12 @@ class GenericWorkerServer(HomeServer): # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "replication": resources[REPLICATION_PREFIX] = ReplicationRestResource(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 883f2fd2ec..de3f08876f 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -31,7 +31,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, STATIC_PREFIX, ) from synapse.app import _base @@ -60,7 +60,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -215,30 +215,22 @@ class SynapseHomeServer(HomeServer): consent_resource: Resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) - resources.update({"/_matrix/consent": consent_resource}) + resources["/_matrix/consent"] = consent_resource if name == "federation": federation_resource: Resource = TransportLayerServer(self) if compress: federation_resource = gz_wrap(federation_resource) - resources.update({FEDERATION_PREFIX: federation_resource}) + resources[FEDERATION_PREFIX] = federation_resource if name == "openid": - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["static", "client"]: - resources.update( - { - STATIC_PREFIX: StaticResource( - os.path.join(os.path.dirname(synapse.__file__), "static") - ) - } + resources[STATIC_PREFIX] = StaticResource( + os.path.join(os.path.dirname(synapse.__file__), "static") ) if name in ["media", "federation", "client"]: @@ -257,7 +249,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "metrics" and self.config.metrics.enable_metrics: metrics_resource: Resource = MetricsResource(RegistryProxy) diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index 7f8c1de1ff..26403facb8 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -14,17 +14,20 @@ from typing import TYPE_CHECKING -from twisted.web.resource import Resource - -from .local_key_resource import LocalKey -from .remote_key_resource import RemoteKey +from synapse.http.server import HttpServer, JsonResource +from synapse.rest.key.v2.local_key_resource import LocalKey +from synapse.rest.key.v2.remote_key_resource import RemoteKey if TYPE_CHECKING: from synapse.server import HomeServer -class KeyApiV2Resource(Resource): +class KeyResource(JsonResource): def __init__(self, hs: "HomeServer"): - Resource.__init__(self) - self.putChild(b"server", LocalKey(hs)) - self.putChild(b"query", RemoteKey(hs)) + super().__init__(hs, canonical_json=True) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None: + LocalKey(hs).register(http_server) + RemoteKey(hs).register(http_server) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 095993415c..d03e728d42 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -13,16 +13,15 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +import re +from typing import TYPE_CHECKING, Optional, Tuple -from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from twisted.web.resource import Resource +from twisted.web.server import Request -from synapse.http.server import respond_with_json_bytes -from synapse.http.site import SynapseRequest +from synapse.http.servlet import RestServlet from synapse.types import JsonDict if TYPE_CHECKING: @@ -31,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LocalKey(Resource): +class LocalKey(RestServlet): """HTTP resource containing encoding the TLS X.509 certificate and NACL signature verification keys for this server:: @@ -61,18 +60,17 @@ class LocalKey(Resource): } """ - isLeaf = True + PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P[^/]*))?$"),) def __init__(self, hs: "HomeServer"): self.config = hs.config self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) - Resource.__init__(self) def update_response_body(self, time_now_msec: int) -> None: refresh_interval = self.config.key.key_refresh_interval self.valid_until_ts = int(time_now_msec + refresh_interval) - self.response_body = encode_canonical_json(self.response_json_object()) + self.response_body = self.response_json_object() def response_json_object(self) -> JsonDict: verify_keys = {} @@ -99,9 +97,11 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: SynapseRequest) -> Optional[int]: + def on_GET( + self, request: Request, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) - return respond_with_json_bytes(request, 200, self.response_body) + return 200, self.response_body diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7f8ad29566..19820886f5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,15 +13,20 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Set +import re +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from signedjson.sign import sign_json -from synapse.api.errors import Codes, SynapseError +from twisted.web.server import Request + from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.servlet import parse_integer, parse_json_object_from_request -from synapse.http.site import SynapseRequest +from synapse.http.server import HttpServer +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, +) from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -32,7 +37,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RemoteKey(DirectServeJsonResource): +class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks @@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource): } """ - isLeaf = True - def __init__(self, hs: "HomeServer"): - super().__init__() - self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -101,36 +102,48 @@ class RemoteKey(DirectServeJsonResource): ) self.config = hs.config - async def _async_render_GET(self, request: SynapseRequest) -> None: - assert request.postpath is not None - if len(request.postpath) == 1: - (server,) = request.postpath - query: dict = {server.decode("ascii"): {}} - elif len(request.postpath) == 2: - server, key_id = request.postpath + def register(self, http_server: HttpServer) -> None: + http_server.register_paths( + "GET", + ( + re.compile( + "^/_matrix/key/v2/query/(?P[^/]*)(/(?P[^/]*))?$" + ), + ), + self.on_GET, + self.__class__.__name__, + ) + http_server.register_paths( + "POST", + (re.compile("^/_matrix/key/v2/query$"),), + self.on_POST, + self.__class__.__name__, + ) + + async def on_GET( + self, request: Request, server: str, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: + if server and key_id: minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} + query = {server: {key_id: arguments}} else: - raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) + query = {server: {}} - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) - async def _async_render_POST(self, request: SynapseRequest) -> None: + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) query = content["server_keys"] - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, - request: SynapseRequest, - query: JsonDict, - query_remote_on_cache_miss: bool = False, - ) -> None: + self, query: JsonDict, query_remote_on_cache_miss: bool = False + ) -> JsonDict: logger.info("Handling query for keys %r", query) store_queries = [] @@ -232,7 +245,7 @@ class RemoteKey(DirectServeJsonResource): for server_name, keys in cache_misses.items() ), ) - await self.query_keys(request, query, query_remote_on_cache_miss=False) + return await self.query_keys(query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json_raw in json_results: @@ -244,6 +257,4 @@ class RemoteKey(DirectServeJsonResource): signed_keys.append(key_json) - response = {"server_keys": signed_keys} - - respond_with_json(request, 200, response, canonical_json=True) + return {"server_keys": signed_keys} diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index c7dae58eb5..8d03da7f96 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): self.assertEqual(channel.code, 401) -@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock()) +@patch("synapse.app.homeserver.KeyResource", new=Mock()) class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index ac0ac06b7e..7f1fba1086 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict @@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): def create_test_resource(self) -> Resource: return create_resource_tree( - {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() + {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource() ) def expect_outgoing_key_request( -- cgit 1.5.1 From fab495a9e1442d99e922367f65f41de5eaa488eb Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Fri, 21 Oct 2022 08:49:47 +0000 Subject: Fix event size checks (#13710) --- changelog.d/13710.bugfix | 1 + synapse/event_auth.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 changelog.d/13710.bugfix (limited to 'synapse') diff --git a/changelog.d/13710.bugfix b/changelog.d/13710.bugfix new file mode 100644 index 0000000000..4c318d15f5 --- /dev/null +++ b/changelog.d/13710.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would count codepoints instead of bytes when validating the size of some fields. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index bab31e33c5..5036604036 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -342,15 +342,15 @@ def check_state_dependent_auth_rules( def _check_size_limits(event: "EventBase") -> None: - if len(event.user_id) > 255: + if len(event.user_id.encode("utf-8")) > 255: raise EventSizeError("'user_id' too large") - if len(event.room_id) > 255: + if len(event.room_id.encode("utf-8")) > 255: raise EventSizeError("'room_id' too large") - if event.is_state() and len(event.state_key) > 255: + if event.is_state() and len(event.state_key.encode("utf-8")) > 255: raise EventSizeError("'state_key' too large") - if len(event.type) > 255: + if len(event.type.encode("utf-8")) > 255: raise EventSizeError("'type' too large") - if len(event.event_id) > 255: + if len(event.event_id.encode("utf-8")) > 255: raise EventSizeError("'event_id' too large") if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: raise EventSizeError("event too large") -- cgit 1.5.1 From 1433b5d5b64c3a6624e6e4ff4fef22127c49df86 Mon Sep 17 00:00:00 2001 From: Tadeusz Sośnierz Date: Fri, 21 Oct 2022 14:52:44 +0200 Subject: Show erasure status when listing users in the Admin API (#14205) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Show erasure status when listing users in the Admin API * Use USING when joining erased_users * Add changelog entry * Revert "Use USING when joining erased_users" This reverts commit 30bd2bf106415caadcfdbdd1b234ef2b106cc394. * Make the erased check work on postgres * Add a testcase for showing erased user status * Appease the style linter * Explicitly convert `erased` to bool to make SQLite consistent with Postgres This also adds us an easy way in to fix the other accidentally integered columns. * Move erasure status test to UsersListTestCase * Include user erased status when fetching user info via the admin API * Document the erase status in user_admin_api * Appease the linter and mypy * Signpost comments in tests Co-authored-by: Tadeusz Sośnierz Co-authored-by: David Robertson --- changelog.d/14205.feature | 1 + docs/admin_api/user_admin_api.md | 4 ++++ synapse/handlers/admin.py | 1 + synapse/storage/databases/main/__init__.py | 13 +++++++++-- tests/rest/admin/test_user.py | 35 +++++++++++++++++++++++++++++- 5 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14205.feature (limited to 'synapse') diff --git a/changelog.d/14205.feature b/changelog.d/14205.feature new file mode 100644 index 0000000000..6692063352 --- /dev/null +++ b/changelog.d/14205.feature @@ -0,0 +1 @@ +Show erasure status when listing users in the Admin API. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 3625c7b6c5..c95d6c9b05 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -37,6 +37,7 @@ It returns a JSON body like the following: "is_guest": 0, "admin": 0, "deactivated": 0, + "erased": false, "shadow_banned": 0, "creation_ts": 1560432506, "appservice_id": null, @@ -167,6 +168,7 @@ A response body like the following is returned: "admin": 0, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": null, @@ -177,6 +179,7 @@ A response body like the following is returned: "admin": 1, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": "", @@ -247,6 +250,7 @@ The following fields are returned in the JSON response body: - `user_type` - string - Type of the user. Normal users are type `None`. This allows user type specific behaviour. There are also types `support` and `bot`. - `deactivated` - bool - Status if that user has been marked as deactivated. + - `erased` - bool - Status if that user has been marked as erased. - `shadow_banned` - bool - Status if that user has been marked as shadow banned. - `displayname` - string - The user's display name if they have set one. - `avatar_url` - string - The user's avatar URL if they have set one. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f2989cc4a2..5bf8e86387 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -100,6 +100,7 @@ class AdminHandler: user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["threepids"] = threepids user_info_dict["external_ids"] = external_ids + user_info_dict["erased"] = await self.store.is_user_erased(user.to_string()) return user_info_dict diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index a62b4abd4e..cfaedf5e0c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -201,7 +201,7 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, - order_by: str = UserSortOrder.USER_ID.value, + order_by: str = UserSortOrder.NAME.value, direction: str = "f", approved: bool = True, ) -> Tuple[List[JsonDict], int]: @@ -261,6 +261,7 @@ class DataStore( sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + LEFT JOIN erased_users AS eu ON u.name = eu.user_id {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base @@ -269,7 +270,8 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts, approved + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved, + eu.user_id is not null as erased {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? @@ -277,6 +279,13 @@ class DataStore( args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) + + # some of those boolean values are returned as integers when we're on SQLite + columns_to_boolify = ["erased"] + for user in users: + for column in columns_to_boolify: + user[column] = bool(user[column]) + return users, count return await self.db_pool.runInteraction( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4c1ce33463..63410ffdf1 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -31,7 +31,7 @@ from synapse.api.room_versions import RoomVersions from synapse.rest.client import devices, login, logout, profile, register, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -924,6 +924,36 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def test_erasure_status(self) -> None: + # Create a new user. + user_id = self.register_user("eraseme", "eraseme") + + # They should appear in the list users API, marked as not erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], False) + + # Deactivate that user, requesting erasure. + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account( + user_id, erase_data=True, requester=create_requester(user_id) + ) + ) + + # Repeat the list users query. They should now be marked as erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], True) + def _order_test( self, expected_user_list: List[str], @@ -1195,6 +1225,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) + self.assertFalse(channel.json_body["erased"]) # Deactivate and erase user channel = self.make_request( @@ -1219,6 +1250,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self.assertIsNone(channel.json_body["avatar_url"]) self.assertIsNone(channel.json_body["displayname"]) + self.assertTrue(channel.json_body["erased"]) self._is_erased("@user:test", True) @@ -2757,6 +2789,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("avatar_url", content) self.assertIn("admin", content) self.assertIn("deactivated", content) + self.assertIn("erased", content) self.assertIn("shadow_banned", content) self.assertIn("creation_ts", content) self.assertIn("appservice_id", content) -- cgit 1.5.1 From 4dd7aa371b6bc746fa4b0a9af220b2013b17a45d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Oct 2022 09:11:19 -0400 Subject: Properly update the threads table when thread events are redacted. (#14248) When the last event in a thread is redacted we need to update the threads table: * Find the new latest event in the thread and store it into the table; or * Remove the thread from the table if it is no longer a thread (i.e. all events in the thread were redacted). --- changelog.d/14248.bugfix | 1 + synapse/storage/databases/main/events.py | 61 ++++++++++++++--- tests/rest/client/test_relations.py | 110 +++++++++++++++++++++---------- 3 files changed, 129 insertions(+), 43 deletions(-) create mode 100644 changelog.d/14248.bugfix (limited to 'synapse') diff --git a/changelog.d/14248.bugfix b/changelog.d/14248.bugfix new file mode 100644 index 0000000000..203c52c16b --- /dev/null +++ b/changelog.d/14248.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where the information returned from the `/threads` API could be stale when threaded events are redacted. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6698cbf664..00880bb37d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2028,25 +2028,37 @@ class PersistEventsStore: redacted_event_id: The event that was redacted. """ - # Fetch the current relation of the event being redacted. - redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + # Fetch the relation of the event being redacted. + row = self.db_pool.simple_select_one_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id}, - retcol="relates_to_id", + retcols=("relates_to_id", "relation_type"), allow_none=True, ) + # Nothing to do if no relation is found. + if row is None: + return + + redacted_relates_to = row["relates_to_id"] + rel_type = row["relation_type"] + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + # Any relation information for the related event must be cleared. - if redacted_relates_to is not None: - self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) - ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + if rel_type == RelationTypes.ANNOTATION: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) ) + if rel_type == RelationTypes.THREAD: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_summary, (redacted_relates_to,) ) @@ -2057,9 +2069,38 @@ class PersistEventsStore: txn, self.store.get_threads, (room_id,) ) - self.db_pool.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + # Find the new latest event in the thread. + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) + + # If a latest event is found, update the threads table, this might + # be the same current latest event (if an earlier event in the thread + # was redacted). + latest_event_row = txn.fetchone() + if latest_event_row: + self.db_pool.simple_upsert_txn( + txn, + table="threads", + keyvalues={"room_id": room_id, "thread_id": redacted_relates_to}, + values={ + "latest_event_id": latest_event_row[0], + "topological_ordering": latest_event_row[1], + "stream_ordering": latest_event_row[2], + }, + ) + + # Otherwise, delete the thread: it no longer exists. + else: + self.db_pool.simple_delete_one_txn( + txn, table="threads", keyvalues={"thread_id": redacted_relates_to} + ) def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index ddf315b894..e3d801f7a8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1523,6 +1523,26 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) + def _get_threads(self) -> List[Tuple[str, str]]: + """Request the threads in the room and returns a list of thread ID and latest event ID.""" + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + threads = channel.json_body["chunk"] + return [ + ( + t["event_id"], + t["unsigned"]["m.relations"][RelationTypes.THREAD]["latest_event"][ + "event_id" + ], + ) + for t in threads + ] + def test_redact_relation_annotation(self) -> None: """ Test that annotations of an event are properly handled after the @@ -1567,58 +1587,82 @@ class RelationRedactionTestCase(BaseRelationsTestCase): The redacted event should not be included in bundled aggregations or the response to relations. """ - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 1", "msgtype": "m.text"}, - ) - unredacted_event_id = channel.json_body["event_id"] + # Create a thread with a few events in it. + thread_replies = [] + for i in range(3): + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": f"reply {i}", "msgtype": "m.text"}, + ) + thread_replies.append(channel.json_body["event_id"]) - # Note that the *last* event in the thread is redacted, as that gets - # included in the bundled aggregation. - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 2", "msgtype": "m.text"}, + ################################################## + # Check the test data is configured as expected. # + ################################################## + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + relations = self._get_bundled_aggregations() + self.assertDictContainsSubset( + {"count": 3, "current_user_participated": True}, + relations[RelationTypes.THREAD], + ) + # The latest event is the last sent event. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + thread_replies[-1], ) - to_redact_event_id = channel.json_body["event_id"] - # Both relations exist. - event_ids = self._get_related_events() + # There should be one thread, the latest event is the event that will be redacted. + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + ########################## + # Redact the last event. # + ########################## + self._redact(thread_replies.pop()) + + # The thread should still exist, but the latest event should be updated. + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 2, - "current_user_participated": True, - }, + {"count": 2, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event returned is the event that will be redacted. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - to_redact_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) - # Redact one of the reactions. - self._redact(to_redact_event_id) + ########################################### + # Redact the *first* event in the thread. # + ########################################### + self._redact(thread_replies.pop(0)) - # The unredacted relation should still exist. - event_ids = self._get_related_events() + # Nothing should have changed (except the thread count). + self.assertEquals(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 1, - "current_user_participated": True, - }, + {"count": 1, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event is now the unredacted event. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - unredacted_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + #################################### + # Redact the last remaining event. # + #################################### + self._redact(thread_replies.pop(0)) + self.assertEquals(thread_replies, []) + + # The event should no longer be considered a thread. + self.assertEquals(self._get_related_events(), []) + self.assertEquals(self._get_bundled_aggregations(), {}) + self.assertEqual(self._get_threads(), []) def test_redact_parent_edit(self) -> None: """Test that edits of an event are redacted when the original event -- cgit 1.5.1 From d24346f53055eae7fb8e9038ef35fa843790742b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 21 Oct 2022 16:03:44 +0100 Subject: Fix logging error on SIGHUP (#14258) --- changelog.d/14258.bugfix | 2 ++ synapse/app/_base.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14258.bugfix (limited to 'synapse') diff --git a/changelog.d/14258.bugfix b/changelog.d/14258.bugfix new file mode 100644 index 0000000000..de97945844 --- /dev/null +++ b/changelog.d/14258.bugfix @@ -0,0 +1,2 @@ +Fix a bug introduced in Synapse 1.60.0 which caused an error to be logged when Synapse received a SIGHUP signal, and debug logging was enabled. + diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 000912e86e..a683ebf4cb 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -558,7 +558,7 @@ def reload_cache_config(config: HomeServerConfig) -> None: logger.warning(f) else: logger.debug( - "New cache config. Was:\n %s\nNow:\n", + "New cache config. Was:\n %s\nNow:\n %s", previous_cache_config.__dict__, config.caches.__dict__, ) -- cgit 1.5.1 From 1d45ad8b2ab1c41dd489ccd581d027077bc917e5 Mon Sep 17 00:00:00 2001 From: Germain Date: Fri, 21 Oct 2022 18:44:00 +0100 Subject: Improve aesthetics and reusability of HTML templates. (#13652) Use a base template to create a cohesive feel across the HTML templates provided by Synapse. Adds basic styling to the base template for a more user-friendly look and feel. --- changelog.d/13652.feature | 1 + synapse/res/templates/_base.html | 29 ++ .../res/templates/account_previously_renewed.html | 18 +- synapse/res/templates/account_renewed.html | 18 +- synapse/res/templates/add_threepid.html | 22 +- synapse/res/templates/add_threepid_failure.html | 20 +- synapse/res/templates/add_threepid_success.html | 18 +- synapse/res/templates/auth_success.html | 28 +- synapse/res/templates/invalid_token.html | 17 +- synapse/res/templates/notice_expiry.html | 93 +++--- synapse/res/templates/notif_mail.html | 116 ++++--- synapse/res/templates/password_reset.html | 19 +- .../res/templates/password_reset_confirmation.html | 14 +- synapse/res/templates/password_reset_failure.html | 14 +- synapse/res/templates/password_reset_success.html | 12 +- synapse/res/templates/recaptcha.html | 19 +- synapse/res/templates/registration.html | 21 +- synapse/res/templates/registration_failure.html | 12 +- synapse/res/templates/registration_success.html | 13 +- synapse/res/templates/registration_token.html | 16 +- synapse/res/templates/sso_account_deactivated.html | 49 ++- .../res/templates/sso_auth_account_details.html | 372 ++++++++++----------- synapse/res/templates/sso_auth_bad_user.html | 52 ++- synapse/res/templates/sso_auth_confirm.html | 56 ++-- synapse/res/templates/sso_auth_success.html | 54 ++- synapse/res/templates/sso_error.html | 34 +- synapse/res/templates/sso_login_idp_picker.html | 114 +++---- synapse/res/templates/sso_new_user_consent.html | 60 ++-- synapse/res/templates/sso_redirect_confirm.html | 75 ++--- synapse/res/templates/style.css | 29 ++ synapse/res/templates/terms.html | 16 +- 31 files changed, 691 insertions(+), 740 deletions(-) create mode 100644 changelog.d/13652.feature create mode 100644 synapse/res/templates/_base.html create mode 100644 synapse/res/templates/style.css (limited to 'synapse') diff --git a/changelog.d/13652.feature b/changelog.d/13652.feature new file mode 100644 index 0000000000..bc7f2926dc --- /dev/null +++ b/changelog.d/13652.feature @@ -0,0 +1 @@ +Improve aesthetics of HTML templates. Note that these changes do not retroactively apply to templates which have been [customised](https://matrix-org.github.io/synapse/latest/templates.html#templates) by server admins. \ No newline at end of file diff --git a/synapse/res/templates/_base.html b/synapse/res/templates/_base.html new file mode 100644 index 0000000000..46439fce6a --- /dev/null +++ b/synapse/res/templates/_base.html @@ -0,0 +1,29 @@ + + + + + + + {% block title %}{% endblock %} + + {% block header %}{% endblock %} + + +
+ {% if app_name == "Riot" %} + [Riot] + {% elif app_name == "Vector" %} + [Vector] + {% elif app_name == "Element" %} + [Element] + {% else %} + [matrix] + {% endif %} +
+ +{% block body %}{% endblock %} + + + diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html index bd4f7cea97..91582a8af0 100644 --- a/synapse/res/templates/account_previously_renewed.html +++ b/synapse/res/templates/account_previously_renewed.html @@ -1,12 +1,6 @@ - - - - - - - Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. - - - Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. - - \ No newline at end of file +{% extends "_base.html" %} +{% block title %}Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.{% endblock %} + +{% block body %} +

Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.

+{% endblock %} diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html index 57b319f375..18a57833f1 100644 --- a/synapse/res/templates/account_renewed.html +++ b/synapse/res/templates/account_renewed.html @@ -1,12 +1,6 @@ - - - - - - - Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. - - - Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. - - \ No newline at end of file +{% extends "_base.html" %} +{% block title %}Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.{% endblock %} + +{% block body %} +

Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.

+{% endblock %} diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html index 71f2215b7a..33c883936a 100644 --- a/synapse/res/templates/add_threepid.html +++ b/synapse/res/templates/add_threepid.html @@ -1,14 +1,8 @@ - - - - - - - Request to add an email address to your Matrix account - - -

A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:

- {{ link }} -

If this was not you, you can safely ignore this email. Thank you.

- - +{% extends "_base.html" %} +{% block title %}Request to add an email address to your Matrix account{% endblock %} + +{% block body %} +

A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:

+{{ link }} +

If this was not you, you can safely ignore this email. Thank you.

+{% endblock %} diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html index bd627ee9ce..f6d7e33825 100644 --- a/synapse/res/templates/add_threepid_failure.html +++ b/synapse/res/templates/add_threepid_failure.html @@ -1,13 +1,7 @@ - - - - - - - Request failed - - -

The request failed for the following reason: {{ failure_reason }}.

-

No changes have been made to your account.

- - +{% extends "_base.html" %} +{% block title %}Request failed{% endblock %} + +{% block body %} +

The request failed for the following reason: {{ failure_reason }}.

+

No changes have been made to your account.

+{% endblock %} diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html index 49170c138e..6d45111796 100644 --- a/synapse/res/templates/add_threepid_success.html +++ b/synapse/res/templates/add_threepid_success.html @@ -1,12 +1,6 @@ - - - - - - - Your email has now been validated - - -

Your email has now been validated, please return to your client. You may now close this window.

- - \ No newline at end of file +{% extends "_base.html" %} +{% block title %}Your email has now been validated{% endblock %} + +{% block body %} +

Your email has now been validated, please return to your client. You may now close this window.

+{% endblock %} diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html index 2d6ac44a0e..9178332f59 100644 --- a/synapse/res/templates/auth_success.html +++ b/synapse/res/templates/auth_success.html @@ -1,21 +1,21 @@ - - -Success! - - +{% extends "_base.html" %} +{% block title %}Success!{% endblock %} + +{% block header %} - - -
-

Thank you

-

You may now close this window and return to the application

-
- - +{% endblock %} + +{% block body %} +
+

Thank you

+

You may now close this window and return to the application

+
+ +{% endblock %} diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html index 2c7c384fe3..d0b1dae669 100644 --- a/synapse/res/templates/invalid_token.html +++ b/synapse/res/templates/invalid_token.html @@ -1,12 +1,5 @@ - - - - - - - Invalid renewal token. - - - Invalid renewal token. - - +{% block title %}Invalid renewal token.{% endblock %} + +{% block body %} +

Invalid renewal token.

+{% endblock %} diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html index 865f9f7ada..406397aaca 100644 --- a/synapse/res/templates/notice_expiry.html +++ b/synapse/res/templates/notice_expiry.html @@ -1,47 +1,46 @@ - - - - - - - - - - - - - - -
- - - - - - - - -
-
Hi {{ display_name }},
-
-
Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
-
To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):
- -
-
- - +{% extends "_base.html" %} +{% block title %}Notice of expiry{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} + + + + + + +
+ + + + + + + + +
+
Hi {{ display_name }},
+
+
Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
+
To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):
+ +
+
+{% endblock %} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html index 9dba0c0253..939d40315f 100644 --- a/synapse/res/templates/notif_mail.html +++ b/synapse/res/templates/notif_mail.html @@ -1,59 +1,57 @@ - - - - - - - - - - - - - - -
- - - - - -
-
Hi {{ user_display_name }},
-
{{ summary_text }}
-
- {%- for room in rooms %} - {%- include 'room.html' with context %} - {%- endfor %} - -
- - +{% block title %}New activity in room{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} + + + + + + +
+ + + + + +
+
Hi {{ user_display_name }},
+
{{ summary_text }}
+
+ {%- for room in rooms %} + {%- include 'room.html' with context %} + {%- endfor %} + +
+{% endblock %} diff --git a/synapse/res/templates/password_reset.html b/synapse/res/templates/password_reset.html index a8bdce357b..de5a9ec68f 100644 --- a/synapse/res/templates/password_reset.html +++ b/synapse/res/templates/password_reset.html @@ -1,14 +1,9 @@ - - - Password reset - - - - -

A password reset request has been received for your Matrix account. If this was you, please click the link below to confirm resetting your password:

+{% block title %}Password reset{% endblock %} - {{ link }} +{% block body %} +

A password reset request has been received for your Matrix account. If this was you, please click the link below to confirm resetting your password:

-

If this was not you, do not click the link above and instead contact your server administrator. Thank you.

- - +{{ link }} + +

If this was not you, do not click the link above and instead contact your server administrator. Thank you.

+{% endblock %} diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html index 2e3fd2ec1e..0eac64b6a8 100644 --- a/synapse/res/templates/password_reset_confirmation.html +++ b/synapse/res/templates/password_reset_confirmation.html @@ -1,10 +1,6 @@ - - - Password reset confirmation - - - - +{% block title %}Password reset confirmation{% endblock %} + +{% block body %}
@@ -15,6 +11,4 @@ If you did not mean to do this, please close this page and your password will not be changed.

- - - +{% endblock %} diff --git a/synapse/res/templates/password_reset_failure.html b/synapse/res/templates/password_reset_failure.html index 2d59c463f0..977babdb40 100644 --- a/synapse/res/templates/password_reset_failure.html +++ b/synapse/res/templates/password_reset_failure.html @@ -1,12 +1,6 @@ - - - Password reset failure - - - - -

The request failed for the following reason: {{ failure_reason }}.

+{% block title %}Password reset failure{% endblock %} +{% block body %} +

The request failed for the following reason: {{ failure_reason }}.

Your password has not been reset.

- - +{% endblock %} diff --git a/synapse/res/templates/password_reset_success.html b/synapse/res/templates/password_reset_success.html index 5165bd1fa2..0e99fad7ff 100644 --- a/synapse/res/templates/password_reset_success.html +++ b/synapse/res/templates/password_reset_success.html @@ -1,9 +1,5 @@ - - - - - - +{% block title %}Password reset success{% endblock %} + +{% block body %}

Your email has now been validated, please return to your client to reset your password. You may now close this window.

- - +{% endblock %} diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html index 615d3239c6..feaf3f6aed 100644 --- a/synapse/res/templates/recaptcha.html +++ b/synapse/res/templates/recaptcha.html @@ -1,10 +1,7 @@ - - -Authentication - - - +{% block title %}Authentication{% endblock %} + +{% block header %} + - - +{% endblock %} + +{% block body %}
{% if error is defined %} @@ -37,5 +35,4 @@ function captchaDone() {
- - +{% endblock %} \ No newline at end of file diff --git a/synapse/res/templates/registration.html b/synapse/res/templates/registration.html index 20e831ff4a..189960a832 100644 --- a/synapse/res/templates/registration.html +++ b/synapse/res/templates/registration.html @@ -1,16 +1,11 @@ - - - Registration - - - - -

You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:

+{% block title %}Registration{% endblock %} - Verify Your Email Address +{% block body %} +

You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:

-

If this was not you, you can safely disregard this email.

+Verify Your Email Address -

Thank you.

- - +

If this was not you, you can safely disregard this email.

+ +

Thank you.

+{% endblock %} diff --git a/synapse/res/templates/registration_failure.html b/synapse/res/templates/registration_failure.html index a6ed22bc90..3debe9301d 100644 --- a/synapse/res/templates/registration_failure.html +++ b/synapse/res/templates/registration_failure.html @@ -1,9 +1,5 @@ - - - - - - +{% block title %}Registration failure{% endblock %} + +{% block body %}

Validation failed for the following reason: {{ failure_reason }}.

- - +{% endblock %} diff --git a/synapse/res/templates/registration_success.html b/synapse/res/templates/registration_success.html index d51d5549d8..e2dd020a9e 100644 --- a/synapse/res/templates/registration_success.html +++ b/synapse/res/templates/registration_success.html @@ -1,10 +1,5 @@ - - - Your email has now been validated - - - - +{% block title %}Your email has now been validated{% endblock %} + +{% block body %}

Your email has now been validated, please return to your client. You may now close this window.

- - +{% endblock %} diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html index 59a98f564c..2ee5866ba5 100644 --- a/synapse/res/templates/registration_token.html +++ b/synapse/res/templates/registration_token.html @@ -1,11 +1,10 @@ - - -Authentication - - +{% block title %}Authentication{% endblock %} + +{% block header %} - - +{% endblock %} + +{% block body %}
{% if error is defined %} @@ -19,5 +18,4 @@
- - +{% endblock %} diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html index 075f801cec..c634229840 100644 --- a/synapse/res/templates/sso_account_deactivated.html +++ b/synapse/res/templates/sso_account_deactivated.html @@ -1,25 +1,24 @@ - - - - - SSO account deactivated - - - - -
-

Your account has been deactivated

-

- No account found -

-

- Your account might have been deactivated by the server administrator. - You can either try to create a new account or contact the server’s - administrator. -

-
- {% include "sso_footer.html" without context %} - - +{% block title %}SSO account deactivated{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} +
+
+

Your account has been deactivated

+

+ No account found +

+

+ Your account might have been deactivated by the server administrator. + You can either try to create a new account or contact the server’s + administrator. +

+
+
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index 2d1db386e1..b516333373 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -1,189 +1,185 @@ - - - - Create your account - - - - - - - -
-

Create your account

-

This is required. Continue to create your account on {{ server_name }}. You can't change this later.

-
-
-
-
- -
@
- -
:{{ server_name }}
+{% block title %}Create your account{% endblock %} + +{% block header %} + + +{% endblock %} + +{% block body %} +
+

Create your account

+

This is required. Continue to create your account on {{ server_name }}. You can't change this later.

+
+
+ +
+ +
@
+ +
:{{ server_name }}
+
+ + + {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %} +
+

{% if idp.idp_icon %}{% endif %}Optional data from {{ idp.idp_name }}

+ {% if user_attributes.avatar_url %} +
- {% include "sso_footer.html" without context %} - - - + + + {% endif %} + {% if user_attributes.display_name %} + + {% endif %} + {% for email in user_attributes.emails %} + + {% endfor %} + + {% endif %} + +
+{% include "sso_footer.html" without context %} + +{% endblock %} diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html index 94403fc3ce..69fdcc9ef0 100644 --- a/synapse/res/templates/sso_auth_bad_user.html +++ b/synapse/res/templates/sso_auth_bad_user.html @@ -1,27 +1,25 @@ - - - - - Authentication failed - - - - - -
-

That doesn't look right

-

- We were unable to validate your {{ server_name }} account - via single sign‑on (SSO), because the SSO Identity - Provider returned different details than when you logged in. -

-

- Try the operation again, and ensure that you use the same details on - the Identity Provider as when you log into your account. -

-
- {% include "sso_footer.html" without context %} - - +{% block title %}Authentication failed{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} +
+
+

That doesn't look right

+

+ We were unable to validate your {{ server_name }} account + via single sign‑on (SSO), because the SSO Identity + Provider returned different details than when you logged in. +

+

+ Try the operation again, and ensure that you use the same details on + the Identity Provider as when you log into your account. +

+
+
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html index aa1c974a6b..2d106e0ae4 100644 --- a/synapse/res/templates/sso_auth_confirm.html +++ b/synapse/res/templates/sso_auth_confirm.html @@ -1,30 +1,26 @@ - - - - - Confirm it's you - - - - - -
-

Confirm it's you to continue

-

- A client is trying to {{ description }}. To confirm this action - re-authorize your account with single sign-on. -

-

- If you did not expect this, your account may be compromised. -

-
-
- - Continue with {{ idp.idp_name }} - -
- {% include "sso_footer.html" without context %} - - +{% block title %}Confirm it's you{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} +
+

Confirm it's you to continue

+

+ A client is trying to {{ description }}. To confirm this action + re-authorize your account with single sign-on. +

+

+ If you did not expect this, your account may be compromised. +

+
+
+ + Continue with {{ idp.idp_name }} + +
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html index 4898af6011..56150eaefe 100644 --- a/synapse/res/templates/sso_auth_success.html +++ b/synapse/res/templates/sso_auth_success.html @@ -1,29 +1,25 @@ - - - - - Authentication successful - - - - - - -
-

Thank you

-

- Now we know it’s you, you can close this window and return to the - application. -

-
- {% include "sso_footer.html" without context %} - - +{% block title %}Authentication successful{% endblock %} + +{% block header %} + + +{% endblock %} + +{% block body %} +
+

Thank you

+

+ Now we know it’s you, you can close this window and return to the + application. +

+
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index 19992ff2ad..e394a92623 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -1,19 +1,19 @@ - - - - - Authentication failed - - - - - +{% block header %} +{% if error == "unauthorised" %} + +{% endif %} +{% endblock %} + +{% block body %} +
{# If an error of unauthorised is returned it means we have actively rejected their login #} {% if error == "unauthorised" %}
@@ -66,5 +66,5 @@ } {% endif %} - - +
+{% endblock %} diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index 56fabfa3d2..a2772ca9ef 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -1,63 +1,59 @@ - - - - - - - Choose identity provider - - - -
-

Log in to {{ server_name }}

-

Choose an identity provider to log in

-
-
- -
- {% include "sso_footer.html" without context %} - - + .providers a { + display: block; + border-radius: 4px; + border: 1px solid #17191C; + padding: 8px; + text-align: center; + text-decoration: none; + color: #17191C; + display: flex; + align-items: center; + font-weight: bold; + } + + .providers a img { + width: 24px; + height: 24px; + } + .providers a span { + flex: 1; + } + +{% endblock %} + +{% block body %} +
+

Log in to {{ server_name }}

+

Choose an identity provider to log in

+
+
+ +
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html index 523f64c4fc..126887d26c 100644 --- a/synapse/res/templates/sso_new_user_consent.html +++ b/synapse/res/templates/sso_new_user_consent.html @@ -1,33 +1,29 @@ - - - - - Agree to terms and conditions - - - - - -
-

Your account is nearly ready

-

Agree to the terms to create your account.

-
-
- {% include "sso_partial_profile.html" %} - -
- {% include "sso_footer.html" without context %} - - +{% block header %} + +{% endblock %} + +{% block body %} +
+

Your account is nearly ready

+

Agree to the terms to create your account.

+
+
+ {% include "sso_partial_profile.html" %} + +
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html index 1049a9bd92..887ee0d294 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -1,41 +1,38 @@ - - - - - Continue to your account - - - - - -
-

Continue to your account

-
-
- {% include "sso_partial_profile.html" %} -

Continuing will grant {{ display_url }} access to your account.

- Continue -
- {% include "sso_footer.html" without context %} - - + .confirm-trust { + margin: 34px 0; + color: #8D99A5; + } + .confirm-trust strong { + color: #17191C; + } + + .confirm-trust::before { + content: ""; + background-image: url('data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMTgiIGhlaWdodD0iMTgiIHZpZXdCb3g9IjAgMCAxOCAxOCIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZmlsbC1ydWxlPSJldmVub2RkIiBjbGlwLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xNi41IDlDMTYuNSAxMy4xNDIxIDEzLjE0MjEgMTYuNSA5IDE2LjVDNC44NTc4NiAxNi41IDEuNSAxMy4xNDIxIDEuNSA5QzEuNSA0Ljg1Nzg2IDQuODU3ODYgMS41IDkgMS41QzEzLjE0MjEgMS41IDE2LjUgNC44NTc4NiAxNi41IDlaTTcuMjUgOUM3LjI1IDkuNDY1OTYgNy41Njg2OSA5Ljg1NzQ4IDggOS45Njg1VjEyLjM3NUM4IDEyLjkyNzMgOC40NDc3MiAxMy4zNzUgOSAxMy4zNzVIMTAuMTI1QzEwLjY3NzMgMTMuMzc1IDExLjEyNSAxMi45MjczIDExLjEyNSAxMi4zNzVDMTEuMTI1IDExLjgyMjcgMTAuNjc3MyAxMS4zNzUgMTAuMTI1IDExLjM3NUgxMFY5QzEwIDguOTY1NDggOS45OTgyNSA4LjkzMTM3IDkuOTk0ODQgOC44OTc3NkM5Ljk0MzYzIDguMzkzNSA5LjUxNzc3IDggOSA4SDguMjVDNy42OTc3MiA4IDcuMjUgOC40NDc3MiA3LjI1IDlaTTkgNy41QzkuNjIxMzIgNy41IDEwLjEyNSA2Ljk5NjMyIDEwLjEyNSA2LjM3NUMxMC4xMjUgNS43NTM2OCA5LjYyMTMyIDUuMjUgOSA1LjI1QzguMzc4NjggNS4yNSA3Ljg3NSA1Ljc1MzY4IDcuODc1IDYuMzc1QzcuODc1IDYuOTk2MzIgOC4zNzg2OCA3LjUgOSA3LjVaIiBmaWxsPSIjQzFDNkNEIi8+Cjwvc3ZnPgoK'); + background-repeat: no-repeat; + width: 24px; + height: 24px; + display: block; + float: left; + } + +{% endblock %} + +{% block body %} +
+

Continue to your account

+
+
+ {% include "sso_partial_profile.html" %} +

Continuing will grant {{ display_url }} access to your account.

+ Continue +
+{% include "sso_footer.html" without context %} + +{% endblock %} diff --git a/synapse/res/templates/style.css b/synapse/res/templates/style.css new file mode 100644 index 0000000000..097b235ae5 --- /dev/null +++ b/synapse/res/templates/style.css @@ -0,0 +1,29 @@ +html { + height: 100%; +} + +body { + background: #f9fafb; + max-width: 680px; + margin: auto; + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol"; +} + +.mx_Header { + border-bottom: 3px solid #ddd; + margin-bottom: 1rem; + padding-top: 1rem; + padding-bottom: 1rem; + text-align: center; +} + +@media screen and (max-width: 1120px) { + body { + font-size: 20px; + } + + h1 { font-size: 1rem; } + h2 { font-size: .9rem; } + h3 { font-size: .85rem; } + h4 { font-size: .8rem; } +} diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html index 2081d990ab..977c3d0bc7 100644 --- a/synapse/res/templates/terms.html +++ b/synapse/res/templates/terms.html @@ -1,11 +1,10 @@ - - -Authentication - - +{% block title %}Authentication{% endblock %} + +{% block header %} - - +{% endblock %} + +{% block body %}
{% if error is defined %} @@ -19,5 +18,4 @@
- - +{% endblock %} -- cgit 1.5.1 From b7a7ff6ee39da4981dcfdce61bf8ac4735e3d047 Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 21 Oct 2022 10:46:22 -0700 Subject: Add initial power level event to batch of bulk persisted events when creating a new room. (#14228) --- changelog.d/14228.misc | 1 + synapse/handlers/federation.py | 4 +- synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 14 ++---- synapse/handlers/room.py | 39 ++++----------- synapse/push/bulk_push_rule_evaluator.py | 74 ++++++++++++++++++++++++----- tests/push/test_bulk_push_rule_evaluator.py | 2 +- tests/replication/_base.py | 2 +- 8 files changed, 82 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14228.misc (limited to 'synapse') diff --git a/changelog.d/14228.misc b/changelog.d/14228.misc new file mode 100644 index 0000000000..14fe31a8bc --- /dev/null +++ b/changelog.d/14228.misc @@ -0,0 +1 @@ +Add initial power level event to batch of bulk persisted events when creating a new room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 275a37a575..4fbc79a6cb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1017,7 +1017,9 @@ class FederationHandler: context = EventContext.for_outlier(self._storage_controllers) - await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] + ) try: await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 06e41b5cc0..7da6316a82 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2171,8 +2171,8 @@ class FederationEventHandler: min_depth, ) else: - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] ) try: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15b828dd74..468900a07f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1433,17 +1433,9 @@ class EventCreationHandler: a room that has been un-partial stated. """ - for event, context in events_and_context: - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + events_and_context + ) try: # If we're a worker we need to hit out to the master. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..cc1e5c8f97 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1055,9 +1055,6 @@ class RoomCreationHandler: event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 - # the last event sent/persisted to the db - last_sent_event_id: Optional[str] = None - # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1102,26 +1099,6 @@ class RoomCreationHandler: return new_event, new_context - async def send( - event: EventBase, - context: synapse.events.snapshot.EventContext, - creator: Requester, - ) -> int: - nonlocal last_sent_event_id - - ev = await self.event_creation_handler.handle_new_client_event( - requester=creator, - events_and_context=[(event, context)], - ratelimit=False, - ignore_shadow_ban=True, - ) - - last_sent_event_id = ev.event_id - - # we know it was persisted, so must have a stream ordering - assert ev.internal_metadata.stream_ordering - return ev.internal_metadata.stream_ordering - try: config = self._presets_dict[preset_config] except KeyError: @@ -1135,10 +1112,14 @@ class RoomCreationHandler: ) logger.debug("Sending %s in new room", EventTypes.Member) - await send(creation_event, creation_context, creator) + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + events_and_context=[(creation_event, creation_context)], + ratelimit=False, + ignore_shadow_ban=True, + ) + last_sent_event_id = ev.event_id - # Room create event must exist at this point - assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( creator, creator.user, @@ -1157,6 +1138,7 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) @@ -1165,7 +1147,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - await send(power_event, power_context, creator) + events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1214,9 +1196,8 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - await send(pl_event, pl_context, creator) + events_to_send.append((pl_event, pl_context)) - events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a75386f6a0..d7795a9080 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -165,8 +165,21 @@ class BulkPushRuleEvaluator: return rules_by_user async def _get_power_levels_and_sender_level( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> Tuple[dict, Optional[int]]: + """ + Given an event and an event context, get the power level event relevant to the event + and the power level of the sender of the event. + Args: + event: event to check + context: context of event to check + event_id_to_event: a mapping of event_id to event for a set of events being + batch persisted. This is needed as the sought-after power level event may + be in this batch rather than the DB + """ # There are no power levels and sender levels possible to get from outlier if event.internal_metadata.is_outlier(): return {}, None @@ -177,15 +190,26 @@ class BulkPushRuleEvaluator: ) pl_event_id = prev_state_ids.get(POWER_KEY) + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case if pl_event_id: - # fastpath: if there's a power level event, that's all we need, and - # not having a power level event is an extreme edge case - auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} + # Get the power level event from the batch, or fall back to the database. + pl_event = event_id_to_event.get(pl_event_id) + if pl_event: + auth_events = {POWER_KEY: pl_event} + else: + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) + # Some needed auth events might be in the batch, combine them with those + # fetched from the database. + for auth_event_id in auth_events_ids: + auth_event = event_id_to_event.get(auth_event_id) + if auth_event: + auth_events_dict[auth_event_id] = auth_event auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -194,16 +218,38 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - @measure_func("action_for_event_by_user") - async def action_for_event_by_user( - self, event: EventBase, context: EventContext + async def action_for_events_by_user( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Given an event and context, evaluate the push rules, check if the message - should increment the unread count, and insert the results into the - event_push_actions_staging table. + """Given a list of events and their associated contexts, evaluate the push rules + for each event, check if the message should increment the unread count, and + insert the results into the event_push_actions_staging table. """ - if not event.internal_metadata.is_notifiable(): - # Push rules for events that aren't notifiable can't be processed by this + # For batched events the power level events may not have been persisted yet, + # so we pass in the batched events. Thus if the event cannot be found in the + # database we can check in the batch. + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + for event, context in events_and_context: + await self._action_for_event_by_user(event, context, event_id_to_event) + + @measure_func("action_for_event_by_user") + async def _action_for_event_by_user( + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], + ) -> None: + + if ( + not event.internal_metadata.is_notifiable() + or event.internal_metadata.is_historical() + ): + # Push rules for events that aren't notifiable can't be processed by this and + # we want to skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). return # Disable counting as unread unless the experimental configuration is @@ -223,7 +269,9 @@ class BulkPushRuleEvaluator: ( power_levels, sender_power_level, - ) = await self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level( + event, context, event_id_to_event + ) # Find the event's thread ID. relation = relation_from_event(event) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 675d7df2ac..594e7937a8 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -71,4 +71,4 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise - self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ce53f808db..121f3d8d65 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -371,7 +371,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", - max_request_body_size=4096, + max_request_body_size=8192, reactor=self.reactor, ) -- cgit 1.5.1 From 1469fed0e39d31a063e8a54c2ea027774eec6acb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 24 Oct 2022 10:45:10 +0100 Subject: Add debugging to help diagnose lost device-list-update (#14268) --- changelog.d/14268.misc | 1 + synapse/storage/databases/main/devices.py | 54 +++++++++++++++++++++---------- 2 files changed, 38 insertions(+), 17 deletions(-) create mode 100644 changelog.d/14268.misc (limited to 'synapse') diff --git a/changelog.d/14268.misc b/changelog.d/14268.misc new file mode 100644 index 0000000000..894b1e1d4c --- /dev/null +++ b/changelog.d/14268.misc @@ -0,0 +1 @@ +Add debugging to help diagnose lost device-list-update. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 830b076a32..979dd4e17e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -274,6 +274,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): destination, int(from_stream_id) ) if not has_changed: + # debugging for https://github.com/matrix-org/synapse/issues/14251 + issue_8631_logger.debug( + "%s: no change between %i and %i", + destination, + from_stream_id, + now_stream_id, + ) return now_stream_id, [] updates = await self.db_pool.runInteraction( @@ -1848,7 +1855,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn: LoggingTransaction, user_id: str, - device_ids: Iterable[str], + device_id: str, hosts: Collection[str], stream_ids: List[int], context: Optional[Dict[str, str]], @@ -1864,6 +1871,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_id_iterator = iter(stream_ids) encoded_context = json_encoder.encode(context) + mark_sent = not self.hs.is_mine_id(user_id) + + values = [ + ( + destination, + next(stream_id_iterator), + user_id, + device_id, + mark_sent, + now, + encoded_context if whitelisted_homeserver(destination) else "{}", + ) + for destination in hosts + ] + self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", @@ -1876,23 +1898,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "ts", "opentracing_context", ), - values=[ - ( - destination, - next(stream_id_iterator), - user_id, - device_id, - not self.hs.is_mine_id( - user_id - ), # We only need to send out update for *our* users - now, - encoded_context if whitelisted_homeserver(destination) else "{}", - ) - for destination in hosts - for device_id in device_ids - ], + values=values, ) + # debugging for https://github.com/matrix-org/synapse/issues/14251 + if issue_8631_logger.isEnabledFor(logging.DEBUG): + issue_8631_logger.debug( + "Recorded outbound pokes for %s:%s with device stream ids %s", + user_id, + device_id, + { + stream_id: destination + for (destination, stream_id, _, _, _, _, _) in values + }, + ) + def _add_device_outbound_room_poke_txn( self, txn: LoggingTransaction, @@ -1997,7 +2017,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._add_device_outbound_poke_to_stream_txn( txn, user_id=user_id, - device_ids=[device_id], + device_id=device_id, hosts=hosts, stream_ids=stream_ids, context=context, -- cgit 1.5.1 From 09b588854e3a6abc4ea2eaa68bb0345f23be5ce8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 24 Oct 2022 13:05:14 +0100 Subject: Fix `TypeError: 'dict_keys' object is not reversible` (#14280) --- changelog.d/14280.bugfix | 1 + synapse/federation/sender/__init__.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14280.bugfix (limited to 'synapse') diff --git a/changelog.d/14280.bugfix b/changelog.d/14280.bugfix new file mode 100644 index 0000000000..c546d2be48 --- /dev/null +++ b/changelog.d/14280.bugfix @@ -0,0 +1 @@ +Fix broken outbound federation when using Python 3.7. Broke in v1.70.0rc1. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 774ecd81b6..3ad483efe0 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -536,8 +536,7 @@ class FederationSender(AbstractFederationSender): if event_entries: now = self.clock.time_msec() - last_id = next(reversed(event_ids)) - ts = event_to_received_ts[last_id] + ts = max(t for t in event_to_received_ts.values() if t) assert ts is not None synapse.metrics.event_processing_lag.labels( -- cgit 1.5.1 From 19c0e55ef7742d67cff1cb6fb7c3e862b86ea788 Mon Sep 17 00:00:00 2001 From: Ryan Miguel <1818590+renegaderyu@users.noreply.github.com> Date: Mon, 24 Oct 2022 08:55:06 -0700 Subject: Return NOT_JSON if decode fails and defer set_timeline_upper_limit ca… (#14262) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Return NOT_JSON if decode fails and defer set_timeline_upper_limit call until after check_valid_filter. Fixes #13661. Signed-off-by: Ryan Miguel . * Reword changelog --- changelog.d/14262.misc | 1 + synapse/rest/client/sync.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14262.misc (limited to 'synapse') diff --git a/changelog.d/14262.misc b/changelog.d/14262.misc new file mode 100644 index 0000000000..c1d23bc67d --- /dev/null +++ b/changelog.d/14262.misc @@ -0,0 +1 @@ +Provide a specific error code when a `/sync` request provides a filter which doesn't represent a JSON object. diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8a16459105..f2013faeb2 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -146,12 +146,12 @@ class SyncRestServlet(RestServlet): elif filter_id.startswith("{"): try: filter_object = json_decoder.decode(filter_id) - set_timeline_upper_limit( - filter_object, self.hs.config.server.filter_timeline_limit - ) except Exception: - raise SynapseError(400, "Invalid filter JSON") + raise SynapseError(400, "Invalid filter JSON", errcode=Codes.NOT_JSON) self.filtering.check_valid_filter(filter_object) + set_timeline_upper_limit( + filter_object, self.hs.config.server.filter_timeline_limit + ) filter_collection = FilterCollection(self.hs, filter_object) else: try: -- cgit 1.5.1 From 581b37b5d6c1c9430108930a4fe409cf3f86332f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Oct 2022 12:07:16 -0400 Subject: Revert behavior change for bundling edits of non-message events (#14283) --- changelog.d/14283.bugfix | 1 + synapse/storage/databases/main/relations.py | 11 +++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14283.bugfix (limited to 'synapse') diff --git a/changelog.d/14283.bugfix b/changelog.d/14283.bugfix new file mode 100644 index 0000000000..a80a8c0361 --- /dev/null +++ b/changelog.d/14283.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where edits to non-message events were aggregated by the homeserver. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 1de62ee9df..c022510e76 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -484,11 +484,12 @@ class RelationsWorkerStore(SQLBaseStore): the event will map to None. """ - # We only allow edits for events that have the same sender and event type. - # We can't assert these things during regular event auth so we have to do - # the checks post hoc. + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. - # Fetches latest edit that has the same type and sender as the original. + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. if isinstance(self.database_engine, PostgresEngine): # The `DISTINCT ON` clause will pick the *first* row it encounters, # so ordering by origin server ts + event ID desc will ensure we get @@ -504,6 +505,7 @@ class RelationsWorkerStore(SQLBaseStore): WHERE %s AND relation_type = ? + AND edit.type = 'm.room.message' ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC """ else: @@ -522,6 +524,7 @@ class RelationsWorkerStore(SQLBaseStore): WHERE %s AND relation_type = ? + AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts, edit.event_id """ -- cgit 1.5.1 From 8c94dd3a277d4e11192f98a9ca32cb6638606b66 Mon Sep 17 00:00:00 2001 From: asymmetric Date: Tue, 25 Oct 2022 11:22:55 +0200 Subject: Enable WAL for SQLite (#13897) Signed-off-by: Lorenzo Manacorda --- changelog.d/13897.feature | 1 + synapse/storage/engines/sqlite.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/13897.feature (limited to 'synapse') diff --git a/changelog.d/13897.feature b/changelog.d/13897.feature new file mode 100644 index 0000000000..d46fdf9fa5 --- /dev/null +++ b/changelog.d/13897.feature @@ -0,0 +1 @@ +Enable Write-Ahead Logging for SQLite installs. Contributed by [asymmetric](https://github.com/asymmetric). diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index faa574dbfd..14260442b6 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -88,6 +88,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): db_conn.create_function("rank", 1, _rank) db_conn.execute("PRAGMA foreign_keys = ON;") + + # Enable WAL. + # see https://www.sqlite.org/wal.html + db_conn.execute("PRAGMA journal_mode = WAL;") db_conn.commit() def is_deadlock(self, error: Exception) -> bool: -- cgit 1.5.1 From c9dffd5b330553c5803784be5bc0e2479fab79b0 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 25 Oct 2022 11:39:25 +0100 Subject: Remove unused `@lru_cache` decorator (#13595) * Remove unused `@lru_cache` decorator Spotted this working on something else. Co-authored-by: David Robertson --- changelog.d/13595.misc | 1 + synapse/util/caches/descriptors.py | 104 ---------------------------------- tests/util/caches/test_descriptors.py | 40 ++----------- 3 files changed, 5 insertions(+), 140 deletions(-) create mode 100644 changelog.d/13595.misc (limited to 'synapse') diff --git a/changelog.d/13595.misc b/changelog.d/13595.misc new file mode 100644 index 0000000000..71959a6ee7 --- /dev/null +++ b/changelog.d/13595.misc @@ -0,0 +1 @@ +Remove unused `@lru_cache` decorator. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b3c748ef44..75428d19ba 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import functools import inspect import logging @@ -146,109 +145,6 @@ class _CacheDescriptorBase: ) -class _LruCachedFunction(Generic[F]): - cache: LruCache[CacheKey, Any] - __call__: F - - -def lru_cache( - *, max_entries: int = 1000, cache_context: bool = False -) -> Callable[[F], _LruCachedFunction[F]]: - """A method decorator that applies a memoizing cache around the function. - - This is more-or-less a drop-in equivalent to functools.lru_cache, although note - that the signature is slightly different. - - The main differences with functools.lru_cache are: - (a) the size of the cache can be controlled via the cache_factor mechanism - (b) the wrapped function can request a "cache_context" which provides a - callback mechanism to indicate that the result is no longer valid - (c) prometheus metrics are exposed automatically. - - The function should take zero or more arguments, which are used as the key for the - cache. Single-argument functions use that argument as the cache key; otherwise the - arguments are built into a tuple. - - Cached functions can be "chained" (i.e. a cached function can call other cached - functions and get appropriately invalidated when they called caches are - invalidated) by adding a special "cache_context" argument to the function - and passing that as a kwarg to all caches called. For example: - - @lru_cache(cache_context=True) - def foo(self, key, cache_context): - r1 = self.bar1(key, on_invalidate=cache_context.invalidate) - r2 = self.bar2(key, on_invalidate=cache_context.invalidate) - return r1 + r2 - - The wrapped function also has a 'cache' property which offers direct access to the - underlying LruCache. - """ - - def func(orig: F) -> _LruCachedFunction[F]: - desc = LruCacheDescriptor( - orig, - max_entries=max_entries, - cache_context=cache_context, - ) - return cast(_LruCachedFunction[F], desc) - - return func - - -class LruCacheDescriptor(_CacheDescriptorBase): - """Helper for @lru_cache""" - - class _Sentinel(enum.Enum): - sentinel = object() - - def __init__( - self, - orig: Callable[..., Any], - max_entries: int = 1000, - cache_context: bool = False, - ): - super().__init__( - orig, num_args=None, uncached_args=None, cache_context=cache_context - ) - self.max_entries = max_entries - - def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: - cache: LruCache[CacheKey, Any] = LruCache( - cache_name=self.name, - max_size=self.max_entries, - ) - - get_cache_key = self.cache_key_builder - sentinel = LruCacheDescriptor._Sentinel.sentinel - - @functools.wraps(self.orig) - def _wrapped(*args: Any, **kwargs: Any) -> Any: - invalidate_callback = kwargs.pop("on_invalidate", None) - callbacks = (invalidate_callback,) if invalidate_callback else () - - cache_key = get_cache_key(args, kwargs) - - ret = cache.get(cache_key, default=sentinel, callbacks=callbacks) - if ret != sentinel: - return ret - - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - if self.add_cache_context: - kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) - - ret2 = self.orig(obj, *args, **kwargs) - cache.set(cache_key, ret2, callbacks=callbacks) - - return ret2 - - wrapped = cast(CachedFunction, _wrapped) - wrapped.cache = cache - obj.__dict__[self.name] = wrapped - - return wrapped - - class DeferredCacheDescriptor(_CacheDescriptorBase): """A method decorator that applies a memoizing cache around the function. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 78fd7b6961..43475a307f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -28,7 +28,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, cachedList, lru_cache +from synapse.util.caches.descriptors import cached, cachedList from tests import unittest from tests.test_utils import get_awaitable_result @@ -36,38 +36,6 @@ from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) -class LruCacheDecoratorTestCase(unittest.TestCase): - def test_base(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @lru_cache() - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) - - obj = Cls() - obj.mock.return_value = "fish" - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(1, 3) - obj.mock.reset_mock() - - # the two values should now be cached - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) @@ -478,10 +446,10 @@ class DescriptorTestCase(unittest.TestCase): @cached(cache_context=True) async def func2(self, key, cache_context): - return self.func3(key, on_invalidate=cache_context.invalidate) + return await self.func3(key, on_invalidate=cache_context.invalidate) - @lru_cache(cache_context=True) - def func3(self, key, cache_context): + @cached(cache_context=True) + async def func3(self, key, cache_context): self.invalidate = cache_context.invalidate return 42 -- cgit 1.5.1 From 2d0ba3f89aaf9545d81c4027500e543ec70b68a6 Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Tue, 25 Oct 2022 13:38:01 +0000 Subject: Implementation for MSC3664: Pushrules for relations (#11804) --- changelog.d/11804.feature | 1 + rust/src/push/base_rules.rs | 17 +++ rust/src/push/evaluator.rs | 99 ++++++++++++- rust/src/push/mod.rs | 61 ++++++-- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 3 + synapse/push/bulk_push_rule_evaluator.py | 49 ++++++- synapse/rest/client/capabilities.py | 5 + synapse/storage/databases/main/push_rule.py | 15 +- tests/push/test_push_rule_evaluator.py | 215 +++++++++++++++++++++++++++- 10 files changed, 454 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11804.feature (limited to 'synapse') diff --git a/changelog.d/11804.feature b/changelog.d/11804.feature new file mode 100644 index 0000000000..6420393541 --- /dev/null +++ b/changelog.d/11804.feature @@ -0,0 +1 @@ +Implement [MSC3664](https://github.com/matrix-org/matrix-doc/pull/3664). Contributed by Nico. diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 63240cacfc..49802fa4eb 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -25,6 +25,7 @@ use crate::push::Action; use crate::push::Condition; use crate::push::EventMatchCondition; use crate::push::PushRule; +use crate::push::RelatedEventMatchCondition; use crate::push::SetTweak; use crate::push::TweakValue; @@ -114,6 +115,22 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/override/.im.nheko.msc3664.reply"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatch( + RelatedEventMatchCondition { + key: Some(Cow::Borrowed("sender")), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_id")), + rel_type: Cow::Borrowed("m.in_reply_to"), + include_fallbacks: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 0365dd01dc..cedd42c54d 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -23,6 +23,7 @@ use regex::Regex; use super::{ utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, + RelatedEventMatchCondition, }; lazy_static! { @@ -49,6 +50,13 @@ pub struct PushRuleEvaluator { /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, + + /// The related events, indexed by relation type. Flattened in the same manner as + /// `flattened_keys`. + related_events_flattened: BTreeMap>, + + /// If msc3664, push rules for related events, is enabled. + related_event_match_enabled: bool, } #[pymethods] @@ -60,6 +68,8 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, + related_events_flattened: BTreeMap>, + related_event_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -72,6 +82,8 @@ impl PushRuleEvaluator { room_member_count, notification_power_levels, sender_power_level, + related_events_flattened, + related_event_match_enabled, }) } @@ -156,6 +168,9 @@ impl PushRuleEvaluator { KnownCondition::EventMatch(event_match) => { self.match_event_match(event_match, user_id)? } + KnownCondition::RelatedEventMatch(event_match) => { + self.match_related_event_match(event_match, user_id)? + } KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -239,6 +254,79 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `related_event_match` condition. (MSC3664) + fn match_related_event_match( + &self, + event_match: &RelatedEventMatchCondition, + user_id: Option<&str>, + ) -> Result { + // First check if related event matching is enabled... + if !self.related_event_match_enabled { + return Ok(false); + } + + // get the related event, fail if there is none. + let event = if let Some(event) = self.related_events_flattened.get(&*event_match.rel_type) { + event + } else { + return Ok(false); + }; + + // If we are not matching fallbacks, don't match if our special key indicating this is a + // fallback relation is not present. + if !event_match.include_fallbacks.unwrap_or(false) + && event.contains_key("im.vector.is_falling_back") + { + return Ok(false); + } + + // if we have no key, accept the event as matching, if it existed without matching any + // fields. + let key = if let Some(key) = &event_match.key { + key + } else { + return Ok(true); + }; + + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + let haystack = if let Some(haystack) = event.get(&**key) { + haystack + } else { + return Ok(false); + }; + + // For the content.body we match against "words", but for everything + // else we match against the entire value. + let match_type = if key == "content.body" { + GlobMatchType::Word + } else { + GlobMatchType::Whole + }; + + let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; + compiled_pattern.is_match(haystack) + } + /// Match the member count against an 'is' condition /// The `is` condition can be things like '>2', '==3' or even just '4'. fn match_member_count(&self, is: &str) -> Result { @@ -267,8 +355,15 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = - PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 0dabfab8b8..d57800aa4a 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -267,6 +267,8 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), + #[serde(rename = "im.nheko.msc3664.related_event_match")] + RelatedEventMatch(RelatedEventMatchCondition), ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -299,6 +301,20 @@ pub struct EventMatchCondition { pub pattern_type: Option>, } +/// The body of a [`Condition::RelatedEventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RelatedEventMatchCondition { + #[serde(skip_serializing_if = "Option::is_none")] + pub key: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern_type: Option>, + pub rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_fallbacks: Option, +} + /// The collection of push rules for a user. #[derive(Debug, Clone, Default)] #[pyclass(frozen)] @@ -391,15 +407,21 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, + msc3664_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { + pub fn py_new( + push_rules: PushRules, + enabled_map: BTreeMap, + msc3664_enabled: bool, + ) -> Self { Self { push_rules, enabled_map, + msc3664_enabled, } } @@ -414,13 +436,25 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules.iter().map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules + .iter() + .filter(|rule| { + // Ignore disabled experimental push rules + if !self.msc3664_enabled + && rule.rule_id == "global/override/.im.nheko.msc3664.reply" + { + return false; + } + + true + }) + .map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } @@ -446,6 +480,17 @@ fn test_deserialize_condition() { let _: Condition = serde_json::from_str(json).unwrap(); } +#[test] +fn test_deserialize_unstable_msc3664_condition() { + let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::RelatedEventMatch(_)) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index f2a61df660..f3b6d6c933 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,7 +25,9 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... + def __init__( + self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool + ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -37,6 +39,8 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], + related_events_flattened: Mapping[str, Mapping[str, str]], + related_event_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4009add01d..d9bdd66d55 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -98,6 +98,9 @@ class ExperimentalConfig(Config): # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) + # MSC3664: Pushrules to match on related events + self.msc3664_enabled: bool = experimental.get("msc3664_enabled", False) + # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d7795a9080..75b7e126ca 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -45,7 +45,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - push_rules_invalidation_counter = Counter( "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" ) @@ -107,6 +106,8 @@ class BulkPushRuleEvaluator: self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() + self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled + self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", @@ -218,6 +219,48 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level + async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation + + Returns: + Mapping of relation type to flattened events. + """ + related_events: Dict[str, Dict[str, str]] = {} + if self._related_event_match_enabled: + related_event_id = event.content.get("m.relates_to", {}).get("event_id") + relation_type = event.content.get("m.relates_to", {}).get("rel_type") + if related_event_id is not None and relation_type is not None: + related_event = await self.store.get_event( + related_event_id, allow_none=True + ) + if related_event is not None: + related_events[relation_type] = _flatten_dict(related_event) + + reply_event_id = ( + event.content.get("m.relates_to", {}) + .get("m.in_reply_to", {}) + .get("event_id") + ) + + # convert replies to pseudo relations + if reply_event_id is not None: + related_event = await self.store.get_event( + reply_event_id, allow_none=True + ) + + if related_event is not None: + related_events["m.in_reply_to"] = _flatten_dict(related_event) + + # indicate that this is from a fallback relation. + if relation_type == "m.thread" and event.content.get( + "m.relates_to", {} + ).get("is_falling_back", False): + related_events["m.in_reply_to"][ + "im.vector.is_falling_back" + ] = "" + + return related_events + async def action_for_events_by_user( self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: @@ -286,6 +329,8 @@ class BulkPushRuleEvaluator: # the parent is part of a thread. thread_id = await self.store.get_thread_id(relation.parent_id) + related_events = await self._related_events(event) + # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. notification_levels = power_levels.get("notifications", {}) @@ -298,6 +343,8 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, + related_events, + self._related_event_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 4237071c61..e84dde31b1 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -77,6 +77,11 @@ class CapabilitiesRestServlet(RestServlet): "enabled": True, } + if self.config.experimental.msc3664_enabled: + response["capabilities"]["im.nheko.msc3664.related_event_match"] = { + "enabled": self.config.experimental.msc3664_enabled, + } + return HTTPStatus.OK, response diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 51416b2236..b6c15f29f8 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,6 +29,7 @@ from typing import ( ) from synapse.api.errors import StoreError +from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -62,7 +63,9 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], enabled_map: Dict[str, bool] + rawrules: List[JsonDict], + enabled_map: Dict[str, bool], + experimental_config: ExperimentalConfig, ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -80,7 +83,9 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules(push_rules, enabled_map) + filtered_rules = FilteredPushRules( + push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled + ) return filtered_rules @@ -160,7 +165,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map) + return _load_rules(rows, enabled_map, self.hs.config.experimental) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -219,7 +224,9 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental + ) return results diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index decf619466..fe7c145840 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -38,7 +38,9 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: + def _get_evaluator( + self, content: JsonDict, related_events=None + ) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -58,6 +60,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), + {} if related_events is None else related_events, + True, ) def test_display_name(self) -> None: @@ -292,6 +296,215 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) + def test_related_event_match(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.annotation", + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + "m.annotation": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.annotation", + "pattern": "@other_user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.replace", + }, + "@other_user:test", + "display_name", + ) + ) + + def test_related_event_match_with_fallback(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.thread", + "is_falling_back": True, + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + "im.vector.is_falling_back": "", + }, + "m.thread": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": True, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": False, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + + def test_related_event_match_no_related_event(self): + evaluator = self._get_evaluator( + {"msgtype": "m.text", "body": "Message without related event"} + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1 From 9192d74b0bf2f87b00d3e106a18baa9ce27acda1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Oct 2022 16:25:02 +0200 Subject: Refactor OIDC tests to better mimic an actual OIDC provider. (#13910) This implements a fake OIDC server, which intercepts calls to the HTTP client. Improves accuracy of tests by covering more internal methods. One particular example was the ID token validation, which previously mocked. This uncovered an incorrect dependency: Synapse actually requires at least authlib 0.15.1, not 0.14.0. --- changelog.d/13910.misc | 1 + pyproject.toml | 2 +- synapse/handlers/oidc.py | 15 +- tests/federation/test_federation_client.py | 36 +- tests/handlers/test_oidc.py | 580 +++++++++++++---------------- tests/rest/client/test_auth.py | 32 +- tests/rest/client/test_login.py | 40 +- tests/rest/client/utils.py | 136 +++---- tests/test_utils/__init__.py | 40 +- tests/test_utils/oidc.py | 325 ++++++++++++++++ 10 files changed, 747 insertions(+), 460 deletions(-) create mode 100644 changelog.d/13910.misc create mode 100644 tests/test_utils/oidc.py (limited to 'synapse') diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc new file mode 100644 index 0000000000..e906952aab --- /dev/null +++ b/changelog.d/13910.misc @@ -0,0 +1 @@ +Refactor OIDC tests to better mimic an actual OIDC provider. diff --git a/pyproject.toml b/pyproject.toml index 6ebac41ed1..7e0feb75aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true } psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true } pysaml2 = { version = ">=4.5.0", optional = true } -authlib = { version = ">=0.14.0", optional = true } +authlib = { version = ">=0.15.1", optional = true } # systemd-python is necessary for logging to the systemd journal via # `systemd.journal.JournalHandler`, as is documented in # `contrib/systemd/log_config.yaml`. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d7a8226900..9759daf043 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -275,6 +275,7 @@ class OidcProvider: provider: OidcProviderConfig, ): self._store = hs.get_datastores().main + self._clock = hs.get_clock() self._macaroon_generaton = macaroon_generator @@ -673,6 +674,13 @@ class OidcProvider: Returns: The decoded claims in the ID token. """ + id_token = token.get("id_token") + logger.debug("Attempting to decode JWT id_token %r", id_token) + + # That has been theoritically been checked by the caller, so even though + # assertion are not enabled in production, it is mainly here to appease mypy + assert id_token is not None + metadata = await self.load_metadata() claims_params = { "nonce": nonce, @@ -688,9 +696,6 @@ class OidcProvider: claim_options = {"iss": {"values": [metadata["issuer"]]}} - id_token = token["id_token"] - logger.debug("Attempting to decode JWT id_token %r", id_token) - # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() @@ -715,7 +720,9 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) - claims.validate(leeway=120) # allows 2 min of clock skew + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew return claims diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index a538215931..51d3bb8fff 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from unittest import mock import twisted.web.client from twisted.internet import defer -from twisted.internet.protocol import Protocol -from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions @@ -26,10 +23,9 @@ from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import event_injection +from tests.test_utils import FakeResponse, event_injection from tests.unittest import FederatingHomeserverTestCase @@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "pdus": [ create_event_dict, member_event_dict, @@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, "pdus": [ @@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # We expect an outbound request to /backfill, so stub that out self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, # Mimic the other server returning our new `pulled_event` @@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase): # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the # other from "yet.another.server" self.assertEqual(backfill_num_attempts, 2) - - -def _mock_response(resp: JsonDict): - body = json.dumps(resp).encode("utf-8") - - def deliver_body(p: Protocol): - p.dataReceived(body) - p.connectionLost(Failure(twisted.web.client.ResponseDone())) - - response = mock.Mock( - code=200, - phrase=b"OK", - headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), - length=len(body), - deliverBody=deliver_body, - ) - mock.seal(response) - return response diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e6cd3af7b7..5955410524 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os -from typing import Any, Dict +from typing import Any, Dict, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -22,12 +21,15 @@ import pymacaroons from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException +from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util import Clock -from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon +from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.stringutils import random_string from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config try: @@ -46,12 +48,6 @@ BASE_URL = "https://synapse/" CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] -AUTHORIZATION_ENDPOINT = ISSUER + "authorize" -TOKEN_ENDPOINT = ISSUER + "token" -USERINFO_ENDPOINT = ISSUER + "userinfo" -WELL_KNOWN = ISSUER + ".well-known/openid-configuration" -JWKS_URI = ISSUER + ".well-known/jwks.json" - # config for common cases DEFAULT_CONFIG = { "enabled": True, @@ -66,9 +62,9 @@ DEFAULT_CONFIG = { EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, + "authorization_endpoint": ISSUER + "authorize", + "token_endpoint": ISSUER + "token", + "jwks_uri": ISSUER + "jwks", } @@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url: str) -> JsonDict: - # Mock get_json calls to handle jwks & oidc discovery endpoints - if url == WELL_KNOWN: - # Minimal discovery document, as defined in OpenID.Discovery - # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata - return { - "issuer": ISSUER, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, - "userinfo_endpoint": USERINFO_ENDPOINT, - "response_types_supported": ["code"], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"], - } - elif url == JWKS_URI: - return {"keys": []} - - return {} - - def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = b"Synapse Test" + self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER) - hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + hs = self.setup_test_homeserver() + self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) + self.hs_patcher.start() self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase): # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 + auth_handler = hs.get_auth_handler() + # Mock the complete SSO login method. + self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + return hs + def tearDown(self) -> None: + self.hs_patcher.stop() + return super().tearDown() + + def reset_mocks(self): + """Reset all the Mocks.""" + self.fake_server.reset_mocks() + self.render_error.reset_mock() + self.complete_sso_login.reset_mock() + def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - async def patched_get_json(uri): - res = await get_json(uri) - if uri == WELL_KNOWN: - res.update(values) - return res + metadata = self.fake_server.get_metadata() + metadata.update(values) + return patch.object(self.fake_server, "get_metadata", return_value=metadata) - return patch.object(self.http_client, "get_json", patched_get_json) + def start_authorization( + self, + userinfo: dict, + client_redirect_url: str = "http://client/redirect", + scope: str = "openid", + with_sid: bool = False, + ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]: + """Start an authorization request, and get the callback request back.""" + nonce = random_string(10) + state = random_string(10) + + code, grant = self.fake_server.start_authorization( + userinfo=userinfo, + scope=scope, + client_id=self.provider._client_auth.client_id, + redirect_uri=self.provider._callback_url, + nonce=nonce, + with_sid=with_sid, + ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + return _build_callback_request(code, state, session), grant def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.fake_server.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, ISSUER) - self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) - self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) - self.assertEqual(metadata.jwks_uri, JWKS_URI) - # FIXME: it seems like authlib does not have that defined in its metadata models - # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + self.assertEqual(metadata.issuer, self.fake_server.issuer) + self.assertEqual( + metadata.authorization_endpoint, + self.fake_server.authorization_endpoint, + ) + self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_server.userinfo_endpoint, + ) # subsequent calls should be cached - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() - @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - self.assertEqual(jwks, {"keys": []}) + self.fake_server.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_server.get_jwks()) # subsequent calls should be cached… - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_jwks_handler.assert_not_called() # …unless forced - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.fake_server.get_jwks_handler.assert_called_once() - # Throw if the JWKS uri is missing - original = self.provider.load_metadata - - async def patched_load_metadata(): - m = (await original()).copy() - m.update({"jwks_uri": None}) - return m - - with patch.object(self.provider, "load_metadata", patched_load_metadata): + with self.metadata_edit({"jwks_uri": None}): + # If we don't do this, the load_metadata call will throw because of the + # missing jwks_uri + self.provider._user_profile_method = "userinfo_endpoint" + self.get_success(self.provider.load_metadata(force=True)) self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + auth_endpoint = urlparse(self.fake_server.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase): with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } username = "bar" userinfo = { "sub": "foo", "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - code = "code" - state = "state" - nonce = "nonce" client_redirect_url = "http://client/redirect" - ip_address = "10.0.0.1" - session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request(code, state, session, ip_address=ip_address) - + request, _ = self.start_authorization( + userinfo, client_redirect_url=client_redirect_url + ) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, client_redirect_url, None, new_user=True, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_not_called() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors + request, _ = self.start_authorization(userinfo) with patch.object( self.provider, "_remote_id_from_userinfo", @@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - auth_handler.complete_sso_login.reset_mock() - self.provider._exchange_code.reset_mock() - self.provider._parse_id_token.reset_mock() - self.provider._fetch_userinfo.reset_mock() + self.reset_mocks() # With userinfo fetching self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + # Without the "openid" scope, the FakeProvider does not generate an id_token + request, _ = self.start_authorization(userinfo, scope="") self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_not_called() - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() + self.reset_mocks() + # With an ID token, userinfo fetching and sid in the ID token self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() + request, grant = self.start_authorization(userinfo, with_sid=True) + self.assertIsNotNone(grant.sid) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, - auth_provider_session_id=id_token["sid"], + auth_provider_session_id=grant.sid, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(userinfo=True): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") - # Handle code exchange failure - from synapse.handlers.oidc import OidcError - - self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] - raises=OidcError("invalid_request") - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("invalid_request") + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(token=True): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("server_error") @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self) -> None: @@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" - token = {"type": "bearer"} - token_json = json.dumps(token).encode("utf-8") - self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad Request", - body=b'{"error": "foo", "error_description": "bar"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b"Not JSON", - ) + self.fake_server.post_token_handler.return_value = FakeResponse( + code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b'{"error": "internal_server_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=500, payload={"error": "internal_server_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad request", - body=b"{}", - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, - phrase=b"OK", - body=b'{"error": "some_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase): """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" @@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) @@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) @@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase): """ Login while using a mapping provider that implements get_extra_attributes. """ - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } userinfo = { "sub": "foo", "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - - state = "state" - client_redirect_url = "http://client/redirect" - session = self._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ) - request = _build_callback_request("code", state, session) - + request, _ = self.start_authorization(userinfo) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@foo:test", - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, {"phone": "1234567"}, new_user=True, auth_provider_session_id=None, @@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - userinfo: dict = { "sub": "test_user", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Test if the mxid is already taken store = self.hs.get_datastores().main @@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Mapping provider does not support de-duplicating Matrix IDs", @@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Subsequent calls should map to the same mxid. - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( args[2].startswith( @@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, @@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" - self.get_success( - _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) - ) + userinfo = {"sub": "test2", "username": "föö"} + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@test_user1:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register all of the potential mxids for a particular OIDC username. self.get_success( @@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) @@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # userinfo lacking "test": "foobar" attribute should fail. userinfo = { "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": "foobar" attribute should succeed. userinfo = { @@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": "foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. userinfo = { "sub": "tester", "username": "tester", "test": ["foobar", "foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase): Test that auth fails if attributes exist but don't match, or are non-string values. """ - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": ["foo", "bar"] attribute should fail userinfo = { @@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": ["foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": False attribute should fail # this is largely just to ensure we don't crash here @@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": False, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": None attribute should fail # a value of None breaks the OIDC spec, but it's important to not crash here @@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 1 attribute should fail # this is largely just to ensure we don't crash here @@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 1, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 3.14 attribute should fail # this is largely just to ensure we don't crash here @@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 3.14, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() def _generate_oidc_session_token( self, @@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return self.handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", + idp_id=self.provider.idp_id, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -async def _make_callback_with_userinfo( - hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" -) -> None: - """Mock up an OIDC callback with the given userinfo dict - - We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, - and poke in the userinfo dict as if it were the response to an OIDC userinfo call. - - Args: - hs: the HomeServer impl to send the callback to. - userinfo: the OIDC userinfo dict - client_redirect_url: the URL to redirect to on success. - """ - - handler = hs.get_oidc_handler() - provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] - provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - - state = "state" - session = handler._macaroon_generator.generate_oidc_session_token( - state=state, - session_data=OidcSessionData( - idp_id="oidc", - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id="", - ), - ) - request = _build_callback_request("code", state, session) - - await handler.handle_oidc_callback(request) - - def _build_callback_request( code: str, state: str, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 090cef5216..ebf653d018 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase): * checking that the original operation succeeds """ + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase): # run the UIA-via-SSO flow session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - login_resp = self.helper.login_via_oidc("username") + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) channel = self.delete_device( @@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" + + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device @@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message @@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase): """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - login_resp = self.helper.login_via_oidc("username", expected_status=403) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, "username", expected_status=403 + ) self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) self.assertEqual( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e801ba8c8b..ff5baa9f0a 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -36,7 +36,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 -from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): TEST_CLIENT_REDIRECT_URL, ) - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + channel, _ = self.helper.complete_oidc_auth( + fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} + ) # that should serve a confirmation page self.assertEqual(channel.code, 200, channel.result) @@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request("oidc") + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" + fake_oidc_server = self.helper.fake_oidc_server() + # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, + {"sub": "tester", "displayname": "Jonny"}, + TEST_CLIENT_REDIRECT_URL, ) # that should redirect to the username picker diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index c249a42bb6..967d229223 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,7 +31,6 @@ from typing import ( Tuple, overload, ) -from unittest.mock import patch from urllib.parse import urlencode import attr @@ -46,8 +45,19 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_ISSUER = "https://issuer.test/" +TEST_OIDC_CONFIG = { + "enabled": True, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} @attr.s(auto_attribs=True) @@ -543,12 +553,28 @@ class RestHelper: return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: + """Create a ``FakeOidcServer``. + + This can be used in conjuction with ``login_via_oidc``:: + + fake_oidc_server = self.helper.fake_oidc_server() + login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user") + """ + + return FakeOidcServer( + clock=self.hs.get_clock(), + issuer=issuer, + ) + def login_via_oidc( self, + fake_server: FakeOidcServer, remote_user_id: str, + with_sid: bool = False, expected_status: int = 200, - ) -> JsonDict: - """Log in via OIDC + ) -> Tuple[JsonDict, FakeAuthorizationGrant]: + """Log in (as a new user) via OIDC Returns the result of the final token login. @@ -560,7 +586,10 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" - channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) + userinfo = {"sub": remote_user_id} + channel, grant = self.auth_via_oidc( + fake_server, userinfo, client_redirect_url, with_sid=with_sid + ) # expect a confirmation page assert channel.code == HTTPStatus.OK, channel.result @@ -585,14 +614,16 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body + return channel.json_body, grant def auth_via_oidc( self, + fake_server: FakeOidcServer, user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. This can be used for either login or user-interactive auth. @@ -616,6 +647,7 @@ class RestHelper: the login redirect endpoint ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -625,14 +657,15 @@ class RestHelper: cookies: Dict[str, str] = {} - # if we're doing a ui auth, hit the ui auth redirect endpoint - if ui_auth_session_id: - # can't set the client redirect url for UI Auth - assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) - else: - # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + with fake_server.patch_homeserver(hs=self.hs): + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -640,17 +673,21 @@ class RestHelper: # that synapse passes to the client. oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + assert oauth_uri_path == fake_server.authorization_endpoint, ( "unexpected SSO URI " + oauth_uri_path ) - return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + return self.complete_oidc_auth( + fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid + ) def complete_oidc_auth( self, + fake_serer: FakeOidcServer, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Mock out an OIDC authentication flow Assumes that an OIDC auth has been initiated by one of initiate_sso_login or @@ -661,50 +698,37 @@ class RestHelper: Requires the OIDC callback resource to be mounted at the normal place. Args: + fake_server: the fake OIDC server with which the auth should be done oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, from initiate_sso_login or initiate_sso_ui_auth). cookies: the cookies set by synapse's redirect endpoint, which will be sent back to the callback endpoint. user_info_dict: the remote userinfo that the OIDC provider should present. Typically this should be '{"sub": ""}'. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. """ _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) + + code, grant = fake_serer.start_authorization( + scope=params["scope"][0], + userinfo=user_info_dict, + client_id=params["client_id"][0], + redirect_uri=params["redirect_uri"][0], + nonce=params["nonce"][0], + with_sid=with_sid, + ) + state = params["state"][0] + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, - urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + urllib.parse.urlencode({"state": state, "code": code}), ) - # before we hit the callback uri, stub out some methods in the http client so - # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. - expected_requests = [ - # first we get a hit to the token endpoint, which we tell to return - # a dummy OIDC access token - (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), - # and then one to the user_info endpoint, which returns our remote user id. - (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), - ] - - async def mock_req( - method: str, - uri: str, - data: Optional[dict] = None, - headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): - (expected_uri, resp_obj) = expected_requests.pop(0) - assert uri == expected_uri - resp = FakeResponse( - code=HTTPStatus.OK, - phrase=b"OK", - body=json.dumps(resp_obj).encode("utf-8"), - ) - return resp - - with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( self.hs.get_reactor(), @@ -715,7 +739,7 @@ class RestHelper: ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) - return channel + return channel, grant def initiate_sso_login( self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] @@ -806,21 +830,3 @@ class RestHelper: assert len(p.links) == 1, "not exactly one link in confirmation page" oauth_uri = p.links[0] return oauth_uri - - -# an 'oidc_config' suitable for login_via_oidc. -TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 0d0d6faf0d..e62ebcc6a5 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -15,17 +15,24 @@ """ Utilities for running the unit tests """ +import json import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, Tuple, TypeVar from unittest.mock import Mock import attr +import zope.interface from twisted.python.failure import Failure from twisted.web.client import ResponseDone +from twisted.web.http import RESPONSES +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.types import JsonDict TV = TypeVar("TV") @@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock: return Mock(side_effect=cb) -@attr.s -class FakeResponse: +# Type ignore: it does not fully implement IResponse, but is good enough for tests +@zope.interface.implementer(IResponse) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeResponse: # type: ignore[misc] """A fake twisted.web.IResponse object there is a similar class at treq.test.test_response, but it lacks a `phrase` attribute, and didn't support deliverBody until recently. """ - # HTTP response code - code = attr.ib(type=int) + version: Tuple[bytes, int, int] = (b"HTTP", 1, 1) - # HTTP response phrase (eg b'OK' for a 200) - phrase = attr.ib(type=bytes) + # HTTP response code + code: int = 200 # body of the response - body = attr.ib(type=bytes) + body: bytes = b"" + + headers: Headers = attr.Factory(Headers) + + @property + def phrase(self): + return RESPONSES.get(self.code, b"Unknown Status") + + @property + def length(self): + return len(self.body) def deliverBody(self, protocol): protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) + @classmethod + def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + headers = Headers({"Content-Type": ["application/json"]}) + body = json.dumps(payload).encode("utf-8") + return cls(code=code, body=body, headers=headers) + # A small image used in some tests. # diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py new file mode 100644 index 0000000000..de134bbc89 --- /dev/null +++ b/tests/test_utils/oidc.py @@ -0,0 +1,325 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import Mock, patch +from urllib.parse import parse_qs + +import attr + +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.server import HomeServer +from synapse.util import Clock +from synapse.util.stringutils import random_string + +from tests.test_utils import FakeResponse + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeAuthorizationGrant: + userinfo: dict + client_id: str + redirect_uri: str + scope: str + nonce: Optional[str] + sid: Optional[str] + + +class FakeOidcServer: + """A fake OpenID Connect Provider.""" + + # All methods here are mocks, so we can track when they are called, and override + # their values + request: Mock + get_jwks_handler: Mock + get_metadata_handler: Mock + get_userinfo_handler: Mock + post_token_handler: Mock + + def __init__(self, clock: Clock, issuer: str): + from authlib.jose import ECKey, KeySet + + self._clock = clock + self.issuer = issuer + + self.request = Mock(side_effect=self._request) + self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler) + self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler) + self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler) + self.post_token_handler = Mock(side_effect=self._post_token_handler) + + # A code -> grant mapping + self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + # An access token -> grant mapping + self._sessions: Dict[str, FakeAuthorizationGrant] = {} + + # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for + # signing JWTs. ECDSA keys are really quick to generate compared to RSA. + self._key = ECKey.generate_key(crv="P-256", is_private=True) + self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))]) + + self._id_token_overrides: Dict[str, Any] = {} + + def reset_mocks(self): + self.request.reset_mock() + self.get_jwks_handler.reset_mock() + self.get_metadata_handler.reset_mock() + self.get_userinfo_handler.reset_mock() + self.post_token_handler.reset_mock() + + def patch_homeserver(self, hs: HomeServer): + """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. + + This patch should be used whenever the HS is expected to perform request to the + OIDC provider, e.g.:: + + fake_oidc_server = self.helper.fake_oidc_server() + with fake_oidc_server.patch_homeserver(hs): + self.make_request("GET", "/_matrix/client/r0/login/sso/redirect") + """ + return patch.object(hs.get_proxied_http_client(), "request", self.request) + + @property + def authorization_endpoint(self) -> str: + return self.issuer + "authorize" + + @property + def token_endpoint(self) -> str: + return self.issuer + "token" + + @property + def userinfo_endpoint(self) -> str: + return self.issuer + "userinfo" + + @property + def metadata_endpoint(self) -> str: + return self.issuer + ".well-known/openid-configuration" + + @property + def jwks_uri(self) -> str: + return self.issuer + "jwks" + + def get_metadata(self) -> dict: + return { + "issuer": self.issuer, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "jwks_uri": self.jwks_uri, + "userinfo_endpoint": self.userinfo_endpoint, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["ES256"], + } + + def get_jwks(self) -> dict: + return self._jwks.as_dict() + + def get_userinfo(self, access_token: str) -> Optional[dict]: + """Given an access token, get the userinfo of the associated session.""" + session = self._sessions.get(access_token, None) + if session is None: + return None + return session.userinfo + + def _sign(self, payload: dict) -> str: + from authlib.jose import JsonWebSignature + + jws = JsonWebSignature() + kid = self.get_jwks()["keys"][0]["kid"] + protected = {"alg": "ES256", "kid": kid} + json_payload = json.dumps(payload) + return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") + + def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: + now = self._clock.time() + id_token = { + **grant.userinfo, + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "nbf": now, + "exp": now + 600, + } + + if grant.nonce is not None: + id_token["nonce"] = grant.nonce + + if grant.sid is not None: + id_token["sid"] = grant.sid + + id_token.update(self._id_token_overrides) + + return self._sign(id_token) + + def id_token_override(self, overrides: dict): + """Temporarily patch the ID token generated by the token endpoint.""" + return patch.object(self, "_id_token_overrides", overrides) + + def start_authorization( + self, + client_id: str, + scope: str, + redirect_uri: str, + userinfo: dict, + nonce: Optional[str] = None, + with_sid: bool = False, + ) -> Tuple[str, FakeAuthorizationGrant]: + """Start an authorization request, and get back the code to use on the authorization endpoint.""" + code = random_string(10) + sid = None + if with_sid: + sid = random_string(10) + + grant = FakeAuthorizationGrant( + userinfo=userinfo, + scope=scope, + redirect_uri=redirect_uri, + nonce=nonce, + client_id=client_id, + sid=sid, + ) + self._authorization_grants[code] = grant + + return code, grant + + def exchange_code(self, code: str) -> Optional[Dict[str, Any]]: + grant = self._authorization_grants.pop(code, None) + if grant is None: + return None + + access_token = random_string(10) + self._sessions[access_token] = grant + + token = { + "token_type": "Bearer", + "access_token": access_token, + "expires_in": 3600, + "scope": grant.scope, + } + + if "openid" in grant.scope: + token["id_token"] = self.generate_id_token(grant) + + return dict(token) + + def buggy_endpoint( + self, + *, + jwks: bool = False, + metadata: bool = False, + token: bool = False, + userinfo: bool = False, + ): + """A context which makes a set of endpoints return a 500 error. + + Args: + jwks: If True, makes the JWKS endpoint return a 500 error. + metadata: If True, makes the OIDC Discovery endpoint return a 500 error. + token: If True, makes the token endpoint return a 500 error. + userinfo: If True, makes the userinfo endpoint return a 500 error. + """ + buggy = FakeResponse(code=500, body=b"Internal server error") + + patches = {} + if jwks: + patches["get_jwks_handler"] = Mock(return_value=buggy) + if metadata: + patches["get_metadata_handler"] = Mock(return_value=buggy) + if token: + patches["post_token_handler"] = Mock(return_value=buggy) + if userinfo: + patches["get_userinfo_handler"] = Mock(return_value=buggy) + + return patch.multiple(self, **patches) + + async def _request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """The override of the SimpleHttpClient#request() method""" + access_token: Optional[str] = None + + if headers is None: + headers = Headers() + + # Try to find the access token in the headers if any + auth_headers = headers.getRawHeaders(b"Authorization") + if auth_headers: + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + access_token = parts[1].decode("ascii") + + if method == "POST": + # If the method is POST, assume it has an url-encoded body + if data is None or headers.getRawHeaders(b"Content-Type") != [ + b"application/x-www-form-urlencoded" + ]: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + params = parse_qs(data.decode("utf-8")) + + if uri == self.token_endpoint: + # Even though this endpoint should be protected, this does not check + # for client authentication. We're not checking it for simplicity, + # and because client authentication is tested in other standalone tests. + return self.post_token_handler(params) + + elif method == "GET": + if uri == self.jwks_uri: + return self.get_jwks_handler() + elif uri == self.metadata_endpoint: + return self.get_metadata_handler() + elif uri == self.userinfo_endpoint: + return self.get_userinfo_handler(access_token=access_token) + + return FakeResponse(code=404, body=b"404 not found") + + # Request handlers + def _get_jwks_handler(self) -> IResponse: + """Handles requests to the JWKS URI.""" + return FakeResponse.json(payload=self.get_jwks()) + + def _get_metadata_handler(self) -> IResponse: + """Handles requests to the OIDC well-known document.""" + return FakeResponse.json(payload=self.get_metadata()) + + def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse: + """Handles requests to the userinfo endpoint.""" + if access_token is None: + return FakeResponse(code=401) + user_info = self.get_userinfo(access_token) + if user_info is None: + return FakeResponse(code=401) + + return FakeResponse.json(payload=user_info) + + def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse: + """Handles requests to the token endpoint.""" + code = params.get("code", []) + + if len(code) != 1: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + grant = self.exchange_code(code=code[0]) + if grant is None: + return FakeResponse.json(code=400, payload={"error": "invalid_grant"}) + + return FakeResponse.json(payload=grant) -- cgit 1.5.1 From d902181de98399d90c46c4e4e2cf631064757941 Mon Sep 17 00:00:00 2001 From: James Salter Date: Tue, 25 Oct 2022 19:05:22 +0100 Subject: Unified search query syntax using the full-text search capabilities of the underlying DB. (#11635) Support a unified search query syntax which leverages more of the full-text search of each database supported by Synapse. Supports, with the same syntax across Postgresql 11+ and Sqlite: - quoted "search terms" - `AND`, `OR`, `-` (negation) operators - Matching words based on their stem, e.g. searches for "dog" matches documents containing "dogs". This is achieved by - If on postgresql 11+, pass the user input to `websearch_to_tsquery` - If on sqlite, manually parse the query and transform it into the sqlite-specific query syntax. Note that postgresql 10, which is close to end-of-life, falls back to using `phraseto_tsquery`, which only supports a subset of the features. Multiple terms separated by a space are implicitly ANDed. Note that: 1. There is no escaping of full-text syntax that might be supported by the database; e.g. `NOT`, `NEAR`, `*` in sqlite. This runs the risk that people might discover this as accidental functionality and depend on something we don't guarantee. 2. English text is assumed for stemming. To support other languages, either the target language needs to be known at the time of indexing the message (via room metadata, or otherwise), or a separate index for each language supported could be created. Sqlite docs: https://www.sqlite.org/fts3.html#full_text_index_queries Postgres docs: https://www.postgresql.org/docs/11/textsearch-controls.html --- changelog.d/11635.feature | 1 + synapse/storage/databases/main/search.py | 197 +++++++++++++++---- synapse/storage/engines/postgres.py | 16 ++ .../delta/73/10_update_sqlite_fts4_tokenizer.py | 62 ++++++ tests/storage/test_room_search.py | 213 +++++++++++++++++++++ 5 files changed, 454 insertions(+), 35 deletions(-) create mode 100644 changelog.d/11635.feature create mode 100644 synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py (limited to 'synapse') diff --git a/changelog.d/11635.feature b/changelog.d/11635.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/11635.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 1b79acf955..a89fc54c2c 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -11,10 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import re -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple +from collections import deque +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr @@ -27,7 +39,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -421,8 +433,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -444,20 +454,24 @@ class SearchStore(SearchBackgroundUpdateStore): count_clauses = clauses if isinstance(self.database_engine, PostgresEngine): + search_query = search_term + tsquery_func = self.database_engine.tsquery_func sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," + f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank," " room_id, event_id" " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" + f" WHERE vector @@ {tsquery_func}('english', ?)" ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" + f" WHERE vector @@ {tsquery_func}('english', ?)" ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_for_sqlite(search_term) + sql = ( "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" " FROM event_search" @@ -469,7 +483,7 @@ class SearchStore(SearchBackgroundUpdateStore): "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ?" ) - count_args = [search_term] + count_args + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -501,7 +515,9 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres( + search_query, events, tsquery_func + ) count_sql += " GROUP BY room_id" @@ -510,7 +526,6 @@ class SearchStore(SearchBackgroundUpdateStore): ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - return { "results": [ {"event": event_map[r["event_id"]], "rank": r["rank"]} @@ -542,9 +557,6 @@ class SearchStore(SearchBackgroundUpdateStore): Each match as a dictionary. """ clauses = [] - - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -582,20 +594,23 @@ class SearchStore(SearchBackgroundUpdateStore): args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): + search_query = search_term + tsquery_func = self.database_engine.tsquery_func sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," + f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank," " origin_server_ts, stream_ordering, room_id, event_id" " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " + f" WHERE vector @@ {tsquery_func}('english', ?) AND " ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " + f" WHERE vector @@ {tsquery_func}('english', ?) AND " ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # @@ -614,13 +629,14 @@ class SearchStore(SearchBackgroundUpdateStore): " CROSS JOIN events USING (event_id)" " WHERE " ) + search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ? AND " ) - count_args = [search_term] + count_args + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -660,7 +676,9 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres( + search_query, events, tsquery_func + ) count_sql += " GROUP BY room_id" @@ -686,7 +704,7 @@ class SearchStore(SearchBackgroundUpdateStore): } async def _find_highlights_in_postgres( - self, search_query: str, events: List[EventBase] + self, search_query: str, events: List[EventBase], tsquery_func: str ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -697,6 +715,7 @@ class SearchStore(SearchBackgroundUpdateStore): Args: search_query events: A list of events + tsquery_func: The tsquery_* function to use when making queries Returns: A set of strings. @@ -729,7 +748,7 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( + query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % ( _to_postgres_options( { "StartSel": start_sel, @@ -760,20 +779,128 @@ def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. +@dataclass +class Phrase: + phrase: List[str] + + +class SearchToken(enum.Enum): + Not = enum.auto() + Or = enum.auto() + And = enum.auto() + + +Token = Union[str, Phrase, SearchToken] +TokenList = List[Token] + + +def _is_stop_word(word: str) -> bool: + # TODO Pull these out of the dictionary: + # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop + return word in {"the", "a", "you", "me", "and", "but"} + + +def _tokenize_query(query: str) -> TokenList: + """ + Convert the user-supplied `query` into a TokenList, which can be translated into + some DB-specific syntax. + + The following constructs are supported: + + - phrase queries using "double quotes" + - case-insensitive `or` and `and` operators + - negation of a keyword via unary `-` + - unary hyphen to denote NOT e.g. 'include -exclude' + + The following differs from websearch_to_tsquery: + + - Stop words are not removed. + - Unclosed phrases are treated differently. + + """ + tokens: TokenList = [] + + # Find phrases. + in_phrase = False + parts = deque(query.split('"')) + for i, part in enumerate(parts): + # The contents inside double quotes is treated as a phrase, a trailing + # double quote is not implied. + in_phrase = bool(i % 2) and i != (len(parts) - 1) + + # Pull out the individual words, discarding any non-word characters. + words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) + + # Phrases have simplified handling of words. + if in_phrase: + # Skip stop words. + phrase = [word for word in words if not _is_stop_word(word)] + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the phrase. + tokens.append(Phrase(phrase)) + continue + + # Otherwise, not in a phrase. + while words: + word = words.popleft() + + if word.startswith("-"): + tokens.append(SearchToken.Not) + + # If there's more word, put it back to be processed again. + word = word[1:] + if word: + words.appendleft(word) + elif word.lower() == "or": + tokens.append(SearchToken.Or) + else: + # Skip stop words. + if _is_stop_word(word): + continue + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the search term. + tokens.append(word) + + return tokens + + +def _tokens_to_sqlite_match_query(tokens: TokenList) -> str: + """ + Convert the list of tokens to a string suitable for passing to sqlite's MATCH. + Assume sqlite was compiled with enhanced query syntax. + + Ref: https://www.sqlite.org/fts3.html#full_text_index_queries """ + match_query = [] + for token in tokens: + if isinstance(token, str): + match_query.append(token) + elif isinstance(token, Phrase): + match_query.append('"' + " ".join(token.phrase) + '"') + elif token == SearchToken.Not: + # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search + # term has already been added before this. + match_query.append(" NOT ") + elif token == SearchToken.Or: + match_query.append(" OR ") + elif token == SearchToken.And: + match_query.append(" AND ") + else: + raise ValueError(f"unknown token {token}") + + return "".join(match_query) - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - if isinstance(database_engine, PostgresEngine): - return " & ".join(result + ":*" for result in results) - elif isinstance(database_engine, Sqlite3Engine): - return " & ".join(result + "*" for result in results) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") +def _parse_query_for_sqlite(search_term: str) -> str: + """Takes a plain unicode string from the user and converts it into a form + that can be passed to sqllite's matchinfo(). + """ + return _tokens_to_sqlite_match_query(_tokenize_query(search_term)) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index d8c0f64d9a..9bf74bbf59 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -170,6 +170,22 @@ class PostgresEngine( """Do we support the `RETURNING` clause in insert/update/delete?""" return True + @property + def tsquery_func(self) -> str: + """ + Selects a tsquery_* func to use. + + Ref: https://www.postgresql.org/docs/current/textsearch-controls.html + + Returns: + The function name. + """ + # Postgres 11 added support for websearch_to_tsquery. + assert self._version is not None + if self._version >= 110000: + return "websearch_to_tsquery" + return "plainto_tsquery" + def is_deadlock(self, error: Exception) -> bool: if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py new file mode 100644 index 0000000000..3de0a709eb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py @@ -0,0 +1,62 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None: + """ + Upgrade the event_search table to use the porter tokenizer if it isn't already + + Applies only for sqlite. + """ + if not isinstance(database_engine, Sqlite3Engine): + return + + # Rebuild the table event_search table with tokenize=porter configured. + cur.execute("DROP TABLE event_search") + cur.execute( + """ + CREATE VIRTUAL TABLE event_search + USING fts4 (tokenize=porter, event_id, room_id, sender, key, value ) + """ + ) + + # Re-run the background job to re-populate the event_search table. + cur.execute("SELECT MIN(stream_ordering) FROM events") + row = cur.fetchone() + min_stream_id = row[0] + + # If there are not any events, nothing to do. + if min_stream_id is None: + return + + cur.execute("SELECT MAX(stream_ordering) FROM events") + row = cur.fetchone() + max_stream_id = row[0] + + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + } + progress_json = json.dumps(progress) + + sql = """ + INSERT into background_updates (ordering, update_name, progress_json) + VALUES (?, ?, ?) + """ + + cur.execute(sql, (7310, "event_search", progress_json)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index e747c6b50e..9ddc19900a 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple, Union +from unittest.case import SkipTest +from unittest.mock import PropertyMock, patch + +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.databases.main import DataStore +from synapse.storage.databases.main.search import Phrase, SearchToken, _tokenize_query from synapse.storage.engines import PostgresEngine +from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.util import Clock from tests.unittest import HomeserverTestCase, skip_unless from tests.utils import USE_POSTGRES_FOR_TESTS @@ -187,3 +198,205 @@ class EventSearchInsertionTest(HomeserverTestCase): ), ) self.assertCountEqual(values, ["hi", "2"]) + + +class MessageSearchTest(HomeserverTestCase): + """ + Check message search. + + A powerful way to check the behaviour is to run the following in Postgres >= 11: + + # SELECT websearch_to_tsquery('english', ); + + The result can be compared to the tokenized version for SQLite and Postgres < 11. + + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + PHRASE = "the quick brown fox jumps over the lazy dog" + + # Each entry is a search query, followed by either a boolean of whether it is + # in the phrase OR a tuple of booleans: whether it matches using websearch + # and using plain search. + COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + ("nope", False), + ("brown", True), + ("quick brown", True), + ("brown quick", True), + ("quick \t brown", True), + ("jump", True), + ("brown nope", False), + ('"brown quick"', (False, True)), + ('"jumps over"', True), + ('"quick fox"', (False, True)), + ("nope OR doublenope", False), + ("furphy OR fox", (True, False)), + ("fox -nope", (True, False)), + ("fox -brown", (False, True)), + ('"fox" quick', True), + ('"fox quick', True), + ('"quick brown', True), + ('" quick "', True), + ('" nope"', False), + ] + # TODO Test non-ASCII cases. + + # Case that fail on SQLite. + POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + # SQLite treats NOT as a binary operator. + ("- fox", (False, True)), + ("- nope", (True, False)), + ('"-fox quick', (False, True)), + # PostgreSQL skips stop words. + ('"the quick brown"', True), + ('"over lazy"', True), + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Register a user and create a room, create some messages + self.register_user("alice", "password") + self.access_token = self.login("alice", "password") + self.room_id = self.helper.create_room_as("alice", tok=self.access_token) + + # Send the phrase as a message and check it was created + response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token) + self.assertIn("event_id", response) + + def test_tokenize_query(self) -> None: + """Test the custom logic to tokenize a user's query.""" + cases = ( + ("brown", ["brown"]), + ("quick brown", ["quick", SearchToken.And, "brown"]), + ("quick \t brown", ["quick", SearchToken.And, "brown"]), + ('"brown quick"', [Phrase(["brown", "quick"])]), + ("furphy OR fox", ["furphy", SearchToken.Or, "fox"]), + ("fox -brown", ["fox", SearchToken.Not, "brown"]), + ("- fox", [SearchToken.Not, "fox"]), + ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]), + # No trailing double quoe. + ('"fox quick', ["fox", SearchToken.And, "quick"]), + ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]), + ('" quick "', [Phrase(["quick"])]), + ( + 'q"uick brow"n', + [ + "q", + SearchToken.And, + Phrase(["uick", "brow"]), + SearchToken.And, + "n", + ], + ), + ( + '-"quick brown"', + [SearchToken.Not, Phrase(["quick", "brown"])], + ), + ) + + for query, expected in cases: + tokenized = _tokenize_query(query) + self.assertEqual( + tokenized, expected, f"{tokenized} != {expected} for {query}" + ) + + def _check_test_cases( + self, + store: DataStore, + cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]], + index=0, + ) -> None: + # Run all the test cases versus search_msgs + for query, expect_to_contain in cases: + if isinstance(expect_to_contain, tuple): + expect_to_contain = expect_to_contain[index] + + result = self.get_success( + store.search_msgs([self.room_id], query, ["content.body"]) + ) + self.assertEquals( + result["count"], + 1 if expect_to_contain else 0, + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ) + self.assertEquals( + len(result["results"]), + 1 if expect_to_contain else 0, + "results array length should match count", + ) + + # Run them again versus search_rooms + for query, expect_to_contain in cases: + if isinstance(expect_to_contain, tuple): + expect_to_contain = expect_to_contain[index] + + result = self.get_success( + store.search_rooms([self.room_id], query, ["content.body"], 10) + ) + self.assertEquals( + result["count"], + 1 if expect_to_contain else 0, + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ) + self.assertEquals( + len(result["results"]), + 1 if expect_to_contain else 0, + "results array length should match count", + ) + + def test_postgres_web_search_for_phrase(self): + """ + Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. + This test is skipped unless the postgres instance supports websearch_to_tsquery. + """ + + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, PostgresEngine): + raise SkipTest("Test only applies when postgres is used as the database") + + if store.database_engine.tsquery_func != "websearch_to_tsquery": + raise SkipTest( + "Test only applies when postgres supporting websearch_to_tsquery is used as the database" + ) + + self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0) + + def test_postgres_non_web_search_for_phrase(self): + """ + Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't + supported by the current postgres version. + """ + + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, PostgresEngine): + raise SkipTest("Test only applies when postgres is used as the database") + + # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path. + with patch( + "synapse.storage.engines.postgres.PostgresEngine.tsquery_func", + new_callable=PropertyMock, + ) as supports_websearch_to_tsquery: + supports_websearch_to_tsquery.return_value = "plainto_tsquery" + self._check_test_cases( + store, self.COMMON_CASES + self.POSTGRES_CASES, index=1 + ) + + def test_sqlite_search(self): + """ + Test sqlite searching for phrases. + """ + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, Sqlite3Engine): + raise SkipTest("Test only applies when sqlite is used as the database") + + self._check_test_cases(store, self.COMMON_CASES, index=0) -- cgit 1.5.1 From 8756d5c87efc5637da55c9e21d2a4eb2369ba693 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 26 Oct 2022 12:45:41 +0200 Subject: Save login tokens in database (#13844) * Save login tokens in database Signed-off-by: Quentin Gliech * Add upgrade notes * Track login token reuse in a Prometheus metric Signed-off-by: Quentin Gliech --- changelog.d/13844.misc | 1 + docs/upgrade.md | 9 ++ synapse/handlers/auth.py | 64 +++++++-- synapse/module_api/__init__.py | 41 +----- synapse/rest/client/login.py | 3 +- synapse/rest/client/login_token_request.py | 5 +- synapse/storage/databases/main/registration.py | 156 ++++++++++++++++++++- .../schema/main/delta/73/10login_tokens.sql | 35 +++++ synapse/util/macaroons.py | 87 +----------- tests/handlers/test_auth.py | 135 ++++++++++-------- tests/util/test_macaroons.py | 28 ---- 11 files changed, 337 insertions(+), 227 deletions(-) create mode 100644 changelog.d/13844.misc create mode 100644 synapse/storage/schema/main/delta/73/10login_tokens.sql (limited to 'synapse') diff --git a/changelog.d/13844.misc b/changelog.d/13844.misc new file mode 100644 index 0000000000..66f4414df7 --- /dev/null +++ b/changelog.d/13844.misc @@ -0,0 +1 @@ +Save login tokens in database and prevent login token reuse. diff --git a/docs/upgrade.md b/docs/upgrade.md index b81385b191..78c34d0c15 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.71.0 + +## Removal of the `generate_short_term_login_token` module API method + +As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_short_term_login_token-module-api-method), the deprecated `generate_short_term_login_token` module method has been removed. + +Modules relying on it can instead use the `create_login_token` method. + + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f5f0e0e7a7..8b9ef25d29 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -38,6 +38,7 @@ from typing import ( import attr import bcrypt import unpaddedbase64 +from prometheus_client import Counter from twisted.internet.defer import CancelledError from twisted.web.server import Request @@ -48,6 +49,7 @@ from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, LoginError, + NotFoundError, StoreError, SynapseError, UserDeactivatedError, @@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.registration import ( + LoginTokenExpired, + LoginTokenLookupResult, + LoginTokenReused, +) from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import delay_cancellation, maybe_awaitable -from synapse.util.macaroons import LoginTokenAttributes from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email @@ -80,6 +86,12 @@ logger = logging.getLogger(__name__) INVALID_USERNAME_OR_PASSWORD = "Invalid username or password" +invalid_login_token_counter = Counter( + "synapse_user_login_invalid_login_tokens", + "Counts the number of rejected m.login.token on /login", + ["reason"], +) + def convert_client_dict_legacy_fields_to_identifier( submission: JsonDict, @@ -883,6 +895,25 @@ class AuthHandler: return True + async def create_login_token_for_user_id( + self, + user_id: str, + duration_ms: int = (2 * 60 * 1000), + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, + ) -> str: + login_token = self.generate_login_token() + now = self._clock.time_msec() + expiry_ts = now + duration_ms + await self.store.add_login_token_to_user( + user_id=user_id, + token=login_token, + expiry_ts=expiry_ts, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + return login_token + async def create_refresh_token_for_user_id( self, user_id: str, @@ -1401,6 +1432,18 @@ class AuthHandler: return None return user_id + def generate_login_token(self) -> str: + """Generates an opaque string, for use as an short-term login token""" + + # we use the following format for access tokens: + # syl__ + + random_string = stringutils.random_string(20) + base = f"syl_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + def generate_access_token(self, for_user: UserID) -> str: """Generates an opaque string, for use as an access token""" @@ -1427,16 +1470,17 @@ class AuthHandler: crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) return f"{base}_{crc}" - async def validate_short_term_login_token( - self, login_token: str - ) -> LoginTokenAttributes: + async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult: try: - res = self.macaroon_gen.verify_short_term_login_token(login_token) - except Exception: - raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) + return await self.store.consume_login_token(login_token) + except LoginTokenExpired: + invalid_login_token_counter.labels("expired").inc() + except LoginTokenReused: + invalid_login_token_counter.labels("reused").inc() + except NotFoundError: + invalid_login_token_counter.labels("not found").inc() - await self.auth_blocking.check_auth_blocking(res.user_id) - return res + raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) async def delete_access_token(self, access_token: str) -> None: """Invalidate a single access token @@ -1711,7 +1755,7 @@ class AuthHandler: ) # Create a login token - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.create_login_token_for_user_id( registered_user_id, auth_provider_id=auth_provider_id, auth_provider_session_id=auth_provider_session_id, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6a6ae208d1..30e689d00d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -771,50 +771,11 @@ class ModuleApi: auth_provider_session_id: The session ID got during login from the SSO IdP, if any. """ - # The deprecated `generate_short_term_login_token` method defaulted to an empty - # string for the `auth_provider_id` because of how the underlying macaroon was - # generated. This will change to a proper NULL-able field when the tokens get - # moved to the database. - return self._hs.get_macaroon_generator().generate_short_term_login_token( + return await self._hs.get_auth_handler().create_login_token_for_user_id( user_id, - auth_provider_id or "", - auth_provider_session_id, duration_in_ms, - ) - - def generate_short_term_login_token( - self, - user_id: str, - duration_in_ms: int = (2 * 60 * 1000), - auth_provider_id: str = "", - auth_provider_session_id: Optional[str] = None, - ) -> str: - """Generate a login token suitable for m.login.token authentication - - Added in Synapse v1.9.0. - - This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will - be removed in Synapse 1.71.0. - - Args: - user_id: gives the ID of the user that the token is for - - duration_in_ms: the time that the token will be valid for - - auth_provider_id: the ID of the SSO IdP that the user used to authenticate - to get this token, if any. This is encoded in the token so that - /login can report stats on number of successful logins by IdP. - """ - logger.warn( - "A module configured on this server uses ModuleApi.generate_short_term_login_token(), " - "which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in " - "Synapse 1.71.0", - ) - return self._hs.get_macaroon_generator().generate_short_term_login_token( - user_id, auth_provider_id, auth_provider_session_id, - duration_in_ms, ) @defer.inlineCallbacks diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f554586ac3..7774f1967d 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -436,8 +436,7 @@ class LoginRestServlet(RestServlet): The body of the JSON response. """ token = login_submission["token"] - auth_handler = self.auth_handler - res = await auth_handler.validate_short_term_login_token(token) + res = await self.auth_handler.consume_login_token(token) return await self._complete_login( res.user_id, diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index 277b20fb63..43ea21d5e6 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -57,7 +57,6 @@ class LoginTokenRequestServlet(RestServlet): self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name - self.macaroon_gen = hs.get_macaroon_generator() self.auth_handler = hs.get_auth_handler() self.token_timeout = hs.config.experimental.msc3882_token_timeout self.ui_auth = hs.config.experimental.msc3882_ui_auth @@ -76,10 +75,10 @@ class LoginTokenRequestServlet(RestServlet): can_skip_ui_auth=False, # Don't allow skipping of UI auth ) - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.auth_handler.create_login_token_for_user_id( user_id=requester.user.to_string(), auth_provider_id="org.matrix.msc3882.login_token_request", - duration_in_ms=self.token_timeout, + duration_ms=self.token_timeout, ) return ( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 2996d6bb4d..0255295317 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + NotFoundError, + StoreError, + SynapseError, + ThreepidValidationError, +) from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception): because this external id is given to an other user.""" +class LoginTokenExpired(Exception): + """Exception if the login token sent expired""" + + +class LoginTokenReused(Exception): + """Exception if the login token sent was already used""" + + @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: """Result of looking up an access token. @@ -115,6 +129,20 @@ class RefreshTokenLookupResult: If None, the session can be refreshed indefinitely.""" +@attr.s(auto_attribs=True, frozen=True, slots=True) +class LoginTokenLookupResult: + """Result of looking up a login token.""" + + user_id: str + """The user this token belongs to.""" + + auth_provider_id: Optional[str] + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id: Optional[str] + """The session ID advertised by the SSO Identity Provider.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "replace_refresh_token", _replace_refresh_token_txn ) + async def add_login_token_to_user( + self, + user_id: str, + token: str, + expiry_ts: int, + auth_provider_id: Optional[str], + auth_provider_session_id: Optional[str], + ) -> None: + """Adds a short-term login token for the given user. + + Args: + user_id: The user ID. + token: The new login token to add. + expiry_ts (milliseconds since the epoch): Time after which the login token + cannot be used. + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token, if any + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider, if any. + """ + await self.db_pool.simple_insert( + "login_tokens", + { + "token": token, + "user_id": user_id, + "expiry_ts": expiry_ts, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="add_login_token_to_user", + ) + + def _consume_login_token( + self, + txn: LoggingTransaction, + token: str, + ts: int, + ) -> LoginTokenLookupResult: + values = self.db_pool.simple_select_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + retcols=( + "user_id", + "expiry_ts", + "used_ts", + "auth_provider_id", + "auth_provider_session_id", + ), + allow_none=True, + ) + + if values is None: + raise NotFoundError() + + self.db_pool.simple_update_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + updatevalues={"used_ts": ts}, + ) + user_id = values["user_id"] + expiry_ts = values["expiry_ts"] + used_ts = values["used_ts"] + auth_provider_id = values["auth_provider_id"] + auth_provider_session_id = values["auth_provider_session_id"] + + # Token was already used + if used_ts is not None: + raise LoginTokenReused() + + # Token expired + if ts > int(expiry_ts): + raise LoginTokenExpired() + + return LoginTokenLookupResult( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + async def consume_login_token(self, token: str) -> LoginTokenLookupResult: + """Lookup a login token and consume it. + + Args: + token: The login token. + + Returns: + The data stored with that token, including the `user_id`. Returns `None` if + the token does not exist or if it expired. + + Raises: + NotFound if the login token was not found in database + LoginTokenExpired if the login token expired + LoginTokenReused if the login token was already used + """ + return await self.db_pool.runInteraction( + "consume_login_token", + self._consume_login_token, + token, + self._clock.time_msec(), + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( @@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): and hs.config.experimental.msc3866.require_approval_for_new_accounts ) + # Create a background job for removing expired login tokens + if hs.config.worker.run_background_tasks: + self._clock.looping_call( + self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS + ) + async def add_access_token_to_user( self, user_id: str, @@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): approved, ) + @wrap_as_background_process("delete_expired_login_tokens") + async def _delete_expired_login_tokens(self) -> None: + """Remove login tokens with expiry dates that have passed.""" + + def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None: + sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?" + txn.execute(sql, (ts,)) + + # We keep the expired tokens for an extra 5 minutes so we can measure how many + # times a token is being used after its expiry + now = self._clock.time_msec() + await self.db_pool.runInteraction( + "delete_expired_login_tokens", + _delete_expired_login_tokens_txn, + now - (5 * 60 * 1000), + ) + def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/schema/main/delta/73/10login_tokens.sql b/synapse/storage/schema/main/delta/73/10login_tokens.sql new file mode 100644 index 0000000000..a39b7bcece --- /dev/null +++ b/synapse/storage/schema/main/delta/73/10login_tokens.sql @@ -0,0 +1,35 @@ +/* + * Copyright 2022 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Login tokens are short-lived tokens that are used for the m.login.token +-- login method, mainly during SSO logins +CREATE TABLE login_tokens ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + expiry_ts BIGINT NOT NULL, + used_ts BIGINT, + auth_provider_id TEXT, + auth_provider_session_id TEXT +); + +-- We're sometimes querying them by their session ID we got from their IDP +CREATE INDEX login_tokens_auth_provider_idx + ON login_tokens (auth_provider_id, auth_provider_session_id); + +-- We're deleting them by their expiration time +CREATE INDEX login_tokens_expiry_time_idx + ON login_tokens (expiry_ts); + diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index df77edcce2..5df03d3ddc 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -24,7 +24,7 @@ from typing_extensions import Literal from synapse.util import Clock, stringutils -MacaroonType = Literal["access", "delete_pusher", "session", "login"] +MacaroonType = Literal["access", "delete_pusher", "session"] def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: @@ -111,19 +111,6 @@ class OidcSessionData: """The session ID of the ongoing UI Auth ("" if this is a login)""" -@attr.s(slots=True, frozen=True, auto_attribs=True) -class LoginTokenAttributes: - """Data we store in a short-term login token""" - - user_id: str - - auth_provider_id: str - """The SSO Identity Provider that the user authenticated with, to get this token.""" - - auth_provider_session_id: Optional[str] - """The session ID advertised by the SSO Identity Provider.""" - - class MacaroonGenerator: def __init__(self, clock: Clock, location: str, secret_key: bytes): self._clock = clock @@ -165,35 +152,6 @@ class MacaroonGenerator: macaroon.add_first_party_caveat(f"pushkey = {pushkey}") return macaroon.serialize() - def generate_short_term_login_token( - self, - user_id: str, - auth_provider_id: str, - auth_provider_session_id: Optional[str] = None, - duration_in_ms: int = (2 * 60 * 1000), - ) -> str: - """Generate a short-term login token used during SSO logins - - Args: - user_id: The user for which the token is valid. - auth_provider_id: The SSO IdP the user used. - auth_provider_session_id: The session ID got during login from the SSO IdP. - - Returns: - A signed token valid for using as a ``m.login.token`` token. - """ - now = self._clock.time_msec() - expiry = now + duration_in_ms - macaroon = self._generate_base_macaroon("login") - macaroon.add_first_party_caveat(f"user_id = {user_id}") - macaroon.add_first_party_caveat(f"time < {expiry}") - macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}") - if auth_provider_session_id is not None: - macaroon.add_first_party_caveat( - f"auth_provider_session_id = {auth_provider_session_id}" - ) - return macaroon.serialize() - def generate_oidc_session_token( self, state: str, @@ -233,49 +191,6 @@ class MacaroonGenerator: return macaroon.serialize() - def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: - """Verify a short-term-login macaroon - - Checks that the given token is a valid, unexpired short-term-login token - minted by this server. - - Args: - token: The login token to verify. - - Returns: - A set of attributes carried by this token, including the - ``user_id`` and informations about the SSO IDP used during that - login. - - Raises: - MacaroonVerificationFailedException if the verification failed - """ - macaroon = pymacaroons.Macaroon.deserialize(token) - - v = self._base_verifier("login") - v.satisfy_general(lambda c: c.startswith("user_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) - satisfy_expiry(v, self._clock.time_msec) - v.verify(macaroon, self._secret_key) - - user_id = get_value_from_macaroon(macaroon, "user_id") - auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") - - auth_provider_session_id: Optional[str] = None - try: - auth_provider_session_id = get_value_from_macaroon( - macaroon, "auth_provider_session_id" - ) - except MacaroonVerificationFailedException: - pass - - return LoginTokenAttributes( - user_id=user_id, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, - ) - def verify_guest_token(self, token: str) -> str: """Verify a guest access token macaroon diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 7106799d44..036dbbc45b 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from unittest.mock import Mock import pymacaroons @@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import AuthError, ResourceLimitError from synapse.rest import admin +from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable class AuthTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, + login.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1 = self.register_user("a_user", "pass") + def token_login(self, token: str) -> Optional[str]: + body = { + "type": "m.login.token", + "token": token, + } + + channel = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + + if channel.code == 200: + return channel.json_body["user_id"] + + return None + def test_macaroon_caveats(self) -> None: token = self.macaroon_generator.generate_guest_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase): v.satisfy_general(verify_guest) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - def test_short_term_login_token_gives_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + def test_login_token_gives_user_id(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) + + res = self.get_success(self.auth_handler.consume_login_token(token)) self.assertEqual(self.user1, res.user_id) - self.assertEqual("", res.auth_provider_id) + self.assertEqual(None, res.auth_provider_id) - # when we advance the clock, the token should be rejected - self.reactor.advance(6) - self.get_failure( - self.auth_handler.validate_short_term_login_token(token), - AuthError, + def test_login_token_reuse_fails(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - def test_short_term_login_token_gives_auth_provider(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, auth_provider_id="my_idp" - ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) - self.assertEqual(self.user1, res.user_id) - self.assertEqual("my_idp", res.auth_provider_id) + self.get_success(self.auth_handler.consume_login_token(token)) - def test_short_term_login_token_cannot_replace_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.get_failure( + self.auth_handler.consume_login_token(token), + AuthError, ) - macaroon = pymacaroons.Macaroon.deserialize(token) - res = self.get_success( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()) + def test_login_token_expires(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - self.assertEqual(self.user1, res.user_id) - - # add another "user_id" caveat, which might allow us to override the - # user_id. - macaroon.add_first_party_caveat("user_id = b_user") + # when we advance the clock, the token should be rejected + self.reactor.advance(6) self.get_failure( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()), + self.auth_handler.consume_login_token(token), AuthError, ) + def test_login_token_gives_auth_provider(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + auth_provider_id="my_idp", + auth_provider_session_id="11-22-33-44", + duration_ms=(5 * 1000), + ) + ) + res = self.get_success(self.auth_handler.consume_login_token(token)) + self.assertEqual(self.user1, res.user_id) + self.assertEqual("my_idp", res.auth_provider_id) + self.assertEqual("11-22-33-44", res.auth_provider_session_id) + def test_mau_limits_disabled(self) -> None: self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception @@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) + def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( @@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) def test_mau_limits_parity(self) -> None: # Ensure we're not at the unix epoch. @@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase): ), ResourceLimitError, ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) # If in monthly active cohort self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( @@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1, device_id=None, valid_until_ms=None ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True @@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) - ) - - def _get_macaroon(self) -> pymacaroons.Macaroon: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) - return pymacaroons.Macaroon.deserialize(token) + self.assertIsNotNone(self.token_login(token)) diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py index 32125f7bb7..40754a4711 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py @@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase): ) self.assertEqual(user_id, "@user:tesths") - def test_short_term_login_token(self): - """Test the generation and verification of short-term login tokens""" - token = self.macaroon_generator.generate_short_term_login_token( - user_id="@user:tesths", - auth_provider_id="oidc", - auth_provider_session_id="sid", - duration_in_ms=2 * 60 * 1000, - ) - - info = self.macaroon_generator.verify_short_term_login_token(token) - self.assertEqual(info.user_id, "@user:tesths") - self.assertEqual(info.auth_provider_id, "oidc") - self.assertEqual(info.auth_provider_session_id, "sid") - - # Raises with another secret key - with self.assertRaises(MacaroonVerificationFailedException): - self.other_macaroon_generator.verify_short_term_login_token(token) - - # Wait a minute - self.reactor.pump([60]) - # Shouldn't raise - self.macaroon_generator.verify_short_term_login_token(token) - # Wait another minute - self.reactor.pump([60]) - # Should raise since it expired - with self.assertRaises(MacaroonVerificationFailedException): - self.macaroon_generator.verify_short_term_login_token(token) - def test_oidc_session_token(self): """Test the generation and verification of OIDC session cookies""" state = "arandomstate" -- cgit 1.5.1 From 04fd6221de026a74e8a3e896796d39dcf5ac6e3b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 26 Oct 2022 14:00:01 +0100 Subject: Fix incorrectly sending authentication tokens to application service as headers (#14301) --- changelog.d/14301.bugfix | 1 + synapse/appservice/api.py | 12 +++++++----- tests/appservice/test_api.py | 8 +++++--- 3 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14301.bugfix (limited to 'synapse') diff --git a/changelog.d/14301.bugfix b/changelog.d/14301.bugfix new file mode 100644 index 0000000000..668c1f3b9c --- /dev/null +++ b/changelog.d/14301.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where access tokens would be incorrectly sent to application services as headers. Application services which were obtaining access tokens from query parameters were not affected. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index fbac4375b0..60774b240d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -123,7 +123,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if response is not None: # just an empty json object return True @@ -147,7 +147,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if response is not None: # just an empty json object return True @@ -190,7 +190,9 @@ class ApplicationServiceApi(SimpleHttpClient): b"access_token": service.hs_token, } response = await self.get_json( - uri, args=args, headers={"Authorization": f"Bearer {service.hs_token}"} + uri, + args=args, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if not isinstance(response, list): logger.warning( @@ -230,7 +232,7 @@ class ApplicationServiceApi(SimpleHttpClient): info = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if not _is_valid_3pe_metadata(info): @@ -327,7 +329,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri=uri, json_body=body, args={"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if logger.isEnabledFor(logging.DEBUG): logger.debug( diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 11008ac1fb..89ee79396f 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Mapping +from typing import Any, List, Mapping, Sequence, Union from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -70,13 +70,15 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): self.request_url = None async def get_json( - url: str, args: Mapping[Any, Any], headers: Mapping[Any, Any] + url: str, + args: Mapping[Any, Any], + headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], ) -> List[JsonDict]: # Ensure the access token is passed as both a header and query arg. if not headers.get("Authorization") or not args.get(b"access_token"): raise RuntimeError("Access token not provided") - self.assertEqual(headers.get("Authorization"), f"Bearer {TOKEN}") + self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) self.assertEqual(args.get(b"access_token"), TOKEN) self.request_url = url if url == URL_USER: -- cgit 1.5.1 From 0cfbb3513152b8360155c2d75df50e06ea861fa4 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Wed, 26 Oct 2022 18:51:23 +0400 Subject: fix broken avatar checks when server_name contains a port (#13927) Fixes check_avatar_size_and_mime_type() to successfully update avatars on homeservers running on non-default ports which it would mistakenly treat as remote homeserver while validating the avatar's size and mime type. Signed-off-by: Ashish Kumar ashfame@users.noreply.github.com --- changelog.d/13927.bugfix | 1 + synapse/handlers/profile.py | 6 +++++- tests/handlers/test_profile.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13927.bugfix (limited to 'synapse') diff --git a/changelog.d/13927.bugfix b/changelog.d/13927.bugfix new file mode 100644 index 0000000000..119cd128e7 --- /dev/null +++ b/changelog.d/13927.bugfix @@ -0,0 +1 @@ +Fix a bug which prevented setting an avatar on homeservers which have an explicit port in their `server_name` and have `max_avatar_size` and/or `allowed_avatar_mimetypes` configuration. Contributed by @ashfame. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index d8ff5289b5..4bf9a047a3 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -307,7 +307,11 @@ class ProfileHandler: if not self.max_avatar_size and not self.allowed_avatar_mimetypes: return True - server_name, _, media_id = parse_and_validate_mxc_uri(mxc) + host, port, media_id = parse_and_validate_mxc_uri(mxc) + if port is not None: + server_name = host + ":" + str(port) + else: + server_name = host if server_name == self.server_name: media_info = await self.store.get_local_media(media_id) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f88c725a42..675aa023ac 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,6 +14,8 @@ from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor import synapse.types @@ -327,6 +329,53 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertFalse(res) + @unittest.override_config( + {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]} + ) + def test_avatar_constraint_on_local_server_with_port(self): + """Test that avatar metadata is correctly fetched when the media is on a local + server and the server has an explicit port. + + (This was previously a bug) + """ + local_server_name = self.hs.config.server.server_name + media_id = "local" + local_mxc = f"mxc://{local_server_name}/{media_id}" + + # mock up the existence of the avatar file + self._setup_local_files({media_id: {"mimetype": "image/png"}}) + + # and now check that check_avatar_size_and_mime_type is happy + self.assertTrue( + self.get_success(self.handler.check_avatar_size_and_mime_type(local_mxc)) + ) + + @parameterized.expand([("remote",), ("remote:1234",)]) + @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) + def test_check_avatar_on_remote_server(self, remote_server_name: str) -> None: + """Test that avatar metadata is correctly fetched from a remote server""" + media_id = "remote" + remote_mxc = f"mxc://{remote_server_name}/{media_id}" + + # if the media is remote, check_avatar_size_and_mime_type just checks the + # media cache, so we don't need to instantiate a real remote server. It is + # sufficient to poke an entry into the db. + self.get_success( + self.hs.get_datastores().main.store_cached_remote_media( + media_id=media_id, + media_type="image/png", + media_length=50, + origin=remote_server_name, + time_now_ms=self.clock.time_msec(), + upload_name=None, + filesystem_id="xyz", + ) + ) + + self.assertTrue( + self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc)) + ) + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): """Stores metadata about files in the database. -- cgit 1.5.1 From 40fa8294e3096132819287dd0c6d6bd71a408902 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 26 Oct 2022 16:10:55 -0500 Subject: Refactor MSC3030 `/timestamp_to_event` to move away from our snowflake pull from `destination` pattern (#14096) 1. `federation_client.timestamp_to_event(...)` now handles all `destination` looping and uses our generic `_try_destination_list(...)` helper. 2. Consistently handling `NotRetryingDestination` and `FederationDeniedError` across `get_pdu` , backfill, and the generic `_try_destination_list` which is used for many places we use this pattern. 3. `get_pdu(...)` now returns `PulledPduInfo` so we know which `destination` we ended up pulling the PDU from --- changelog.d/14096.misc | 1 + synapse/federation/federation_client.py | 130 ++++++++++++++++++++++++----- synapse/handlers/federation.py | 15 ++-- synapse/handlers/federation_event.py | 31 ++++--- synapse/handlers/room.py | 126 +++++++++++----------------- synapse/util/retryutils.py | 2 +- tests/federation/test_federation_client.py | 12 ++- 7 files changed, 191 insertions(+), 126 deletions(-) create mode 100644 changelog.d/14096.misc (limited to 'synapse') diff --git a/changelog.d/14096.misc b/changelog.d/14096.misc new file mode 100644 index 0000000000..2c07dc673b --- /dev/null +++ b/changelog.d/14096.misc @@ -0,0 +1 @@ +Refactor [MSC3030](https://github.com/matrix-org/matrix-spec-proposals/pull/3030) `/timestamp_to_event` endpoint to loop over federation destinations with standard pattern and error handling. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b220ab43fc..fa225182be 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -80,6 +80,18 @@ PDU_RETRY_TIME_MS = 1 * 60 * 1000 T = TypeVar("T") +@attr.s(frozen=True, slots=True, auto_attribs=True) +class PulledPduInfo: + """ + A result object that stores the PDU and info about it like which homeserver we + pulled it from (`pull_origin`) + """ + + pdu: EventBase + # Which homeserver we pulled the PDU from + pull_origin: str + + class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response we couldn't parse @@ -114,7 +126,9 @@ class FederationClient(FederationBase): self.hostname = hs.hostname self.signing_key = hs.signing_key - self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache( + # Cache mapping `event_id` to a tuple of the event itself and the `pull_origin` + # (which server we pulled the event from) + self._get_pdu_cache: ExpiringCache[str, Tuple[EventBase, str]] = ExpiringCache( cache_name="get_pdu_cache", clock=self._clock, max_len=1000, @@ -352,11 +366,11 @@ class FederationClient(FederationBase): @tag_args async def get_pdu( self, - destinations: Iterable[str], + destinations: Collection[str], event_id: str, room_version: RoomVersion, timeout: Optional[int] = None, - ) -> Optional[EventBase]: + ) -> Optional[PulledPduInfo]: """Requests the PDU with given origin and ID from the remote home servers. @@ -371,11 +385,11 @@ class FederationClient(FederationBase): moving to the next destination. None indicates no timeout. Returns: - The requested PDU, or None if we were unable to find it. + The requested PDU wrapped in `PulledPduInfo`, or None if we were unable to find it. """ logger.debug( - "get_pdu: event_id=%s from destinations=%s", event_id, destinations + "get_pdu(event_id=%s): from destinations=%s", event_id, destinations ) # TODO: Rate limit the number of times we try and get the same event. @@ -384,19 +398,25 @@ class FederationClient(FederationBase): # it gets persisted to the database), so we cache the results of the lookup. # Note that this is separate to the regular get_event cache which caches # events once they have been persisted. - event = self._get_pdu_cache.get(event_id) + get_pdu_cache_entry = self._get_pdu_cache.get(event_id) + event = None + pull_origin = None + if get_pdu_cache_entry: + event, pull_origin = get_pdu_cache_entry # If we don't see the event in the cache, go try to fetch it from the # provided remote federated destinations - if not event: + else: pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + # TODO: We can probably refactor this to use `_try_destination_list` for destination in destinations: now = self._clock.time_msec() last_attempt = pdu_attempts.get(destination, 0) if last_attempt + PDU_RETRY_TIME_MS > now: logger.debug( - "get_pdu: skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)", + "get_pdu(event_id=%s): skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)", + event_id, destination, last_attempt, PDU_RETRY_TIME_MS, @@ -411,43 +431,48 @@ class FederationClient(FederationBase): room_version=room_version, timeout=timeout, ) + pull_origin = destination pdu_attempts[destination] = now if event: # Prime the cache - self._get_pdu_cache[event.event_id] = event + self._get_pdu_cache[event.event_id] = (event, pull_origin) # Now that we have an event, we can break out of this # loop and stop asking other destinations. break + except NotRetryingDestination as e: + logger.info("get_pdu(event_id=%s): %s", event_id, e) + continue + except FederationDeniedError: + logger.info( + "get_pdu(event_id=%s): Not attempting to fetch PDU from %s because the homeserver is not on our federation whitelist", + event_id, + destination, + ) + continue except SynapseError as e: logger.info( - "Failed to get PDU %s from %s because %s", + "get_pdu(event_id=%s): Failed to get PDU from %s because %s", event_id, destination, e, ) continue - except NotRetryingDestination as e: - logger.info(str(e)) - continue - except FederationDeniedError as e: - logger.info(str(e)) - continue except Exception as e: pdu_attempts[destination] = now logger.info( - "Failed to get PDU %s from %s because %s", + "get_pdu(event_id=): Failed to get PDU from %s because %s", event_id, destination, e, ) continue - if not event: + if not event or not pull_origin: return None # `event` now refers to an object stored in `get_pdu_cache`. Our @@ -459,7 +484,7 @@ class FederationClient(FederationBase): event.room_version, ) - return event_copy + return PulledPduInfo(event_copy, pull_origin) @trace @tag_args @@ -699,12 +724,14 @@ class FederationClient(FederationBase): pdu_origin = get_domain_from_id(pdu.sender) if not res and pdu_origin != origin: try: - res = await self.get_pdu( + pulled_pdu_info = await self.get_pdu( destinations=[pdu_origin], event_id=pdu.event_id, room_version=room_version, timeout=10000, ) + if pulled_pdu_info is not None: + res = pulled_pdu_info.pdu except SynapseError: pass @@ -806,6 +833,7 @@ class FederationClient(FederationBase): ) for destination in destinations: + # We don't want to ask our own server for information we don't have if destination == self.server_name: continue @@ -814,9 +842,21 @@ class FederationClient(FederationBase): except ( RequestSendFailed, InvalidResponseError, - NotRetryingDestination, ) as e: logger.warning("Failed to %s via %s: %s", description, destination, e) + # Skip to the next homeserver in the list to try. + continue + except NotRetryingDestination as e: + logger.info("%s: %s", description, e) + continue + except FederationDeniedError: + logger.info( + "%s: Not attempting to %s from %s because the homeserver is not on our federation whitelist", + description, + description, + destination, + ) + continue except UnsupportedRoomVersionError: raise except HttpResponseException as e: @@ -1609,6 +1649,54 @@ class FederationClient(FederationBase): return result async def timestamp_to_event( + self, *, destinations: List[str], room_id: str, timestamp: int, direction: str + ) -> Optional["TimestampToEventResponse"]: + """ + Calls each remote federating server from `destinations` asking for their closest + event to the given timestamp in the given direction until we get a response. + Also validates the response to always return the expected keys or raises an + error. + + Args: + destinations: The domains of homeservers to try fetching from + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + A parsed TimestampToEventResponse including the closest event_id + and origin_server_ts or None if no destination has a response. + """ + + async def _timestamp_to_event_from_destination( + destination: str, + ) -> TimestampToEventResponse: + return await self._timestamp_to_event_from_destination( + destination, room_id, timestamp, direction + ) + + try: + # Loop through each homeserver candidate until we get a succesful response + timestamp_to_event_response = await self._try_destination_list( + "timestamp_to_event", + destinations, + # TODO: The requested timestamp may lie in a part of the + # event graph that the remote server *also* didn't have, + # in which case they will have returned another event + # which may be nowhere near the requested timestamp. In + # the future, we may need to reconcile that gap and ask + # other homeservers, and/or extend `/timestamp_to_event` + # to return events on *both* sides of the timestamp to + # help reconcile the gap faster. + _timestamp_to_event_from_destination, + ) + return timestamp_to_event_response + except SynapseError: + return None + + async def _timestamp_to_event_from_destination( self, destination: str, room_id: str, timestamp: int, direction: str ) -> "TimestampToEventResponse": """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4fbc79a6cb..5fc3b8bc8c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -442,6 +442,15 @@ class FederationHandler: # appropriate stuff. # TODO: We can probably do something more intelligent here. return True + except NotRetryingDestination as e: + logger.info("_maybe_backfill_inner: %s", e) + continue + except FederationDeniedError: + logger.info( + "_maybe_backfill_inner: Not attempting to backfill from %s because the homeserver is not on our federation whitelist", + dom, + ) + continue except (SynapseError, InvalidResponseError) as e: logger.info("Failed to backfill from %s because %s", dom, e) continue @@ -477,15 +486,9 @@ class FederationHandler: logger.info("Failed to backfill from %s because %s", dom, e) continue - except NotRetryingDestination as e: - logger.info(str(e)) - continue except RequestSendFailed as e: logger.info("Failed to get backfill from %s because %s", dom, e) continue - except FederationDeniedError as e: - logger.info(e) - continue except Exception as e: logger.exception("Failed to backfill from %s because %s", dom, e) continue diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7da6316a82..9ca5df7c78 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -58,7 +58,7 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.federation.federation_client import InvalidResponseError +from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import ( SynapseTags, @@ -1517,8 +1517,8 @@ class FederationEventHandler: ) async def backfill_event_id( - self, destination: str, room_id: str, event_id: str - ) -> EventBase: + self, destinations: List[str], room_id: str, event_id: str + ) -> PulledPduInfo: """Backfill a single event and persist it as a non-outlier which means we also pull in all of the state and auth events necessary for it. @@ -1530,24 +1530,21 @@ class FederationEventHandler: Raises: FederationError if we are unable to find the event from the destination """ - logger.info( - "backfill_event_id: event_id=%s from destination=%s", event_id, destination - ) + logger.info("backfill_event_id: event_id=%s", event_id) room_version = await self._store.get_room_version(room_id) - event_from_response = await self._federation_client.get_pdu( - [destination], + pulled_pdu_info = await self._federation_client.get_pdu( + destinations, event_id, room_version, ) - if not event_from_response: + if not pulled_pdu_info: raise FederationError( "ERROR", 404, - "Unable to find event_id=%s from destination=%s to backfill." - % (event_id, destination), + f"Unable to find event_id={event_id} from remote servers to backfill.", affected=event_id, ) @@ -1555,13 +1552,13 @@ class FederationEventHandler: # and auth events to de-outlier it. This also sets up the necessary # `state_groups` for the event. await self._process_pulled_events( - destination, - [event_from_response], + pulled_pdu_info.pull_origin, + [pulled_pdu_info.pdu], # Prevent notifications going to clients backfilled=True, ) - return event_from_response + return pulled_pdu_info @trace @tag_args @@ -1584,19 +1581,19 @@ class FederationEventHandler: async def get_event(event_id: str) -> None: with nested_logging_context(event_id): try: - event = await self._federation_client.get_pdu( + pulled_pdu_info = await self._federation_client.get_pdu( [destination], event_id, room_version, ) - if event is None: + if pulled_pdu_info is None: logger.warning( "Server %s didn't return event %s", destination, event_id, ) return - events.append(event) + events.append(pulled_pdu_info.pdu) except Exception as e: logger.warning( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index cc1e5c8f97..de97886ea9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -49,7 +49,6 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, - HttpResponseException, LimitExceededError, NotFoundError, StoreError, @@ -60,7 +59,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_and_fixup_power_levels_contents -from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin @@ -1472,7 +1470,12 @@ class TimestampLookupHandler: Raises: SynapseError if unable to find any event locally in the given direction """ - + logger.debug( + "get_event_for_timestamp(room_id=%s, timestamp=%s, direction=%s) Finding closest event...", + room_id, + timestamp, + direction, + ) local_event_id = await self.store.get_event_id_for_timestamp( room_id, timestamp, direction ) @@ -1524,85 +1527,54 @@ class TimestampLookupHandler: ) ) - # Loop through each homeserver candidate until we get a succesful response - for domain in likely_domains: - # We don't want to ask our own server for information we don't have - if domain == self.server_name: - continue + remote_response = await self.federation_client.timestamp_to_event( + destinations=likely_domains, + room_id=room_id, + timestamp=timestamp, + direction=direction, + ) + if remote_response is not None: + logger.debug( + "get_event_for_timestamp: remote_response=%s", + remote_response, + ) - try: - remote_response = await self.federation_client.timestamp_to_event( - domain, room_id, timestamp, direction - ) - logger.debug( - "get_event_for_timestamp: response from domain(%s)=%s", - domain, - remote_response, - ) + remote_event_id = remote_response.event_id + remote_origin_server_ts = remote_response.origin_server_ts - remote_event_id = remote_response.event_id - remote_origin_server_ts = remote_response.origin_server_ts - - # Backfill this event so we can get a pagination token for - # it with `/context` and paginate `/messages` from this - # point. - # - # TODO: The requested timestamp may lie in a part of the - # event graph that the remote server *also* didn't have, - # in which case they will have returned another event - # which may be nowhere near the requested timestamp. In - # the future, we may need to reconcile that gap and ask - # other homeservers, and/or extend `/timestamp_to_event` - # to return events on *both* sides of the timestamp to - # help reconcile the gap faster. - remote_event = ( - await self.federation_event_handler.backfill_event_id( - domain, room_id, remote_event_id - ) - ) + # Backfill this event so we can get a pagination token for + # it with `/context` and paginate `/messages` from this + # point. + pulled_pdu_info = await self.federation_event_handler.backfill_event_id( + likely_domains, room_id, remote_event_id + ) + remote_event = pulled_pdu_info.pdu - # XXX: When we see that the remote server is not trustworthy, - # maybe we should not ask them first in the future. - if remote_origin_server_ts != remote_event.origin_server_ts: - logger.info( - "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.", - domain, - remote_event_id, - remote_origin_server_ts, - remote_event.origin_server_ts, - ) - - # Only return the remote event if it's closer than the local event - if not local_event or ( - abs(remote_event.origin_server_ts - timestamp) - < abs(local_event.origin_server_ts - timestamp) - ): - logger.info( - "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)", - remote_event_id, - remote_event.origin_server_ts, - timestamp, - local_event.event_id if local_event else None, - local_event.origin_server_ts if local_event else None, - ) - return remote_event_id, remote_origin_server_ts - except (HttpResponseException, InvalidResponseError) as ex: - # Let's not put a high priority on some other homeserver - # failing to respond or giving a random response - logger.debug( - "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", - domain, - type(ex).__name__, - ex, - ex.args, + # XXX: When we see that the remote server is not trustworthy, + # maybe we should not ask them first in the future. + if remote_origin_server_ts != remote_event.origin_server_ts: + logger.info( + "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.", + pulled_pdu_info.pull_origin, + remote_event_id, + remote_origin_server_ts, + remote_event.origin_server_ts, ) - except Exception: - # But we do want to see some exceptions in our code - logger.warning( - "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception", - domain, - exc_info=True, + + # Only return the remote event if it's closer than the local event + if not local_event or ( + abs(remote_event.origin_server_ts - timestamp) + < abs(local_event.origin_server_ts - timestamp) + ): + logger.info( + "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)", + remote_event_id, + remote_event.origin_server_ts, + timestamp, + local_event.event_id if local_event else None, + local_event.origin_server_ts if local_event else None, ) + return remote_event_id, remote_origin_server_ts # To appease mypy, we have to add both of these conditions to check for # `None`. We only expect `local_event` to be `None` when diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index d0a69ff843..dcc037b982 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -51,7 +51,7 @@ class NotRetryingDestination(Exception): destination: the domain in question """ - msg = "Not retrying server %s." % (destination,) + msg = f"Not retrying server {destination} because we tried it recently retry_last_ts={retry_last_ts} and we won't check for another retry_interval={retry_interval}ms." super().__init__(msg) self.retry_last_ts = retry_last_ts diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 51d3bb8fff..e67f405826 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -142,14 +142,14 @@ class FederationClientTest(FederatingHomeserverTestCase): def test_get_pdu_returns_nothing_when_event_does_not_exist(self): """No event should be returned when the event does not exist""" - remote_pdu = self.get_success( + pulled_pdu_info = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], "event_should_not_exist", RoomVersions.V9, ) ) - self.assertEqual(remote_pdu, None) + self.assertEqual(pulled_pdu_info, None) def test_get_pdu(self): """Test to make sure an event is returned by `get_pdu()`""" @@ -169,13 +169,15 @@ class FederationClientTest(FederatingHomeserverTestCase): remote_pdu.internal_metadata.outlier = True # Get the event again. This time it should read it from cache. - remote_pdu2 = self.get_success( + pulled_pdu_info2 = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], remote_pdu.event_id, RoomVersions.V9, ) ) + self.assertIsNotNone(pulled_pdu_info2) + remote_pdu2 = pulled_pdu_info2.pdu # Sanity check that we are working against the same event self.assertEqual(remote_pdu.event_id, remote_pdu2.event_id) @@ -215,13 +217,15 @@ class FederationClientTest(FederatingHomeserverTestCase): ) ) - remote_pdu = self.get_success( + pulled_pdu_info = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], "event_id", RoomVersions.V9, ) ) + self.assertIsNotNone(pulled_pdu_info) + remote_pdu = pulled_pdu_info.pdu # check the right call got made to the agent self._mock_agent.request.assert_called_once_with( -- cgit 1.5.1 From cbe01ccc3f9c09a0a7233f90200fbcb8ae5245cf Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 27 Oct 2022 10:52:23 +0100 Subject: Reject history insertion during partial joins (#14291) --- changelog.d/14291.bugfix | 1 + synapse/rest/client/room_batch.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changelog.d/14291.bugfix (limited to 'synapse') diff --git a/changelog.d/14291.bugfix b/changelog.d/14291.bugfix new file mode 100644 index 0000000000..bac5065e94 --- /dev/null +++ b/changelog.d/14291.bugfix @@ -0,0 +1 @@ +Prevent history insertion ([MSC2716](https://github.com/matrix-org/matrix-spec-proposals/pull/2716)) during an partial join ([MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706)). diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index dd91dabedd..10be4a781b 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -108,6 +108,13 @@ class RoomBatchSendEventRestServlet(RestServlet): errcode=Codes.MISSING_PARAM, ) + if await self.store.is_partial_state_room(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Cannot insert history batches until we have fully joined the room", + errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE, + ) + # Verify the batch_id_from_query corresponds to an actual insertion event # and have the batch connected. if batch_id_from_query: -- cgit 1.5.1 From 4dc05f30193935224103e8772b1bbc15293e5cb6 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 27 Oct 2022 14:16:00 +0200 Subject: Fix presence bug introduced in 1.64 by #13313 (#14243) * Fix presence bug introduced in 1.64 by #13313 Signed-off-by: Mathieu Velten * Add changelog * Add DISTINCT * Apply suggestions from code review Signed-off-by: Mathieu Velten --- changelog.d/14243.bugfix | 1 + synapse/storage/databases/main/roommember.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14243.bugfix (limited to 'synapse') diff --git a/changelog.d/14243.bugfix b/changelog.d/14243.bugfix new file mode 100644 index 0000000000..ac0b21c2c5 --- /dev/null +++ b/changelog.d/14243.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.64.0 where presence updates could be missing from `/sync` responses. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 32e1e983a5..ab708b0ba5 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -742,7 +742,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # user and the set of other users, and then checking if there is any # overlap. sql = f""" - SELECT b.state_key + SELECT DISTINCT b.state_key FROM ( SELECT room_id FROM current_state_events WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ? @@ -751,7 +751,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): SELECT room_id, state_key FROM current_state_events WHERE type = 'm.room.member' AND membership = 'join' AND {clause} ) AS b using (room_id) - LIMIT 1 """ txn.execute(sql, (user_id, *args)) -- cgit 1.5.1 From 1357ae869f279a3f0855c1b1c2750eca2887928e Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 27 Oct 2022 15:39:47 +0200 Subject: Add workers settings to configuration manual (#14086) * Add workers settings to configuration manual * Update `pusher_instances` * update url to python logger * update headlines * update links after headline change * remove link from `daemon process` There is no docs in Synapse for this * extend example for `federation_sender_instances` and `pusher_instances` * more infos about stream writers * add link to DAG * update `pusher_instances` * update `worker_listeners` * update `stream_writers` * Update `worker_name` Co-authored-by: David Robertson --- changelog.d/14086.doc | 1 + docs/sample_log_config.yaml | 2 +- docs/usage/configuration/config_documentation.md | 268 +++++++++++++++++++---- docs/workers.md | 100 ++++++--- synapse/config/logger.py | 2 +- 5 files changed, 291 insertions(+), 82 deletions(-) create mode 100644 changelog.d/14086.doc (limited to 'synapse') diff --git a/changelog.d/14086.doc b/changelog.d/14086.doc new file mode 100644 index 0000000000..5b4b938759 --- /dev/null +++ b/changelog.d/14086.doc @@ -0,0 +1 @@ +Add workers settings to [configuration manual](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#individual-worker-configuration). \ No newline at end of file diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 3065a0e2d9..6339160d00 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -6,7 +6,7 @@ # Synapse also supports structured logging for machine readable logs which can # be ingested by ELK stacks. See [2] for details. # -# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema +# [1]: https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema # [2]: https://matrix-org.github.io/synapse/latest/structured_logging.html version: 1 diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index d81eda52c1..fb5eb42c52 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -99,7 +99,7 @@ modules: config: {} ``` --- -## Server ## +## Server Define your homeserver name and other base options. @@ -159,7 +159,7 @@ including _matrix/...). This is the same URL a user might enter into the 'Custom Homeserver URL' field on their client. If you use Synapse with a reverse proxy, this should be the URL to reach Synapse via the proxy. Otherwise, it should be the URL to reach Synapse's client HTTP listener (see -'listeners' below). +['listeners'](#listeners) below). Defaults to `https:///`. @@ -570,7 +570,7 @@ Example configuration: delete_stale_devices_after: 1y ``` -## Homeserver blocking ## +## Homeserver blocking Useful options for Synapse admins. --- @@ -922,7 +922,7 @@ retention: interval: 1d ``` --- -## TLS ## +## TLS Options related to TLS. @@ -1012,7 +1012,7 @@ federation_custom_ca_list: - myCA3.pem ``` --- -## Federation ## +## Federation Options related to federation. @@ -1071,7 +1071,7 @@ Example configuration: allow_device_name_lookup_over_federation: true ``` --- -## Caching ## +## Caching Options related to caching. @@ -1185,7 +1185,7 @@ file in Synapse's `contrib` directory, you can send a `SIGHUP` signal by using `systemctl reload matrix-synapse`. --- -## Database ## +## Database Config options related to database settings. --- @@ -1332,20 +1332,21 @@ databases: cp_max: 10 ``` --- -## Logging ## +## Logging Config options related to logging. --- ### `log_config` -This option specifies a yaml python logging config file as described [here](https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema). +This option specifies a yaml python logging config file as described +[here](https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema). Example configuration: ```yaml log_config: "CONFDIR/SERVERNAME.log.config" ``` --- -## Ratelimiting ## +## Ratelimiting Options related to ratelimiting in Synapse. Each ratelimiting configuration is made of two parameters: @@ -1576,7 +1577,7 @@ Example configuration: federation_rr_transactions_per_room_per_second: 40 ``` --- -## Media Store ## +## Media Store Config options related to Synapse's media store. --- @@ -1766,7 +1767,7 @@ url_preview_ip_range_blacklist: - 'ff00::/8' - 'fec0::/10' ``` ----- +--- ### `url_preview_ip_range_whitelist` This option sets a list of IP address CIDR ranges that the URL preview spider is allowed @@ -1860,7 +1861,7 @@ Example configuration: - 'fr;q=0.8' - '*;q=0.7' ``` ----- +--- ### `oembed` oEmbed allows for easier embedding content from a website. It can be @@ -1877,7 +1878,7 @@ oembed: - oembed/my_providers.json ``` --- -## Captcha ## +## Captcha See [here](../../CAPTCHA_SETUP.md) for full details on setting up captcha. @@ -1926,7 +1927,7 @@ Example configuration: recaptcha_siteverify_api: "https://my.recaptcha.site" ``` --- -## TURN ## +## TURN Options related to adding a TURN server to Synapse. --- @@ -1947,7 +1948,7 @@ Example configuration: ```yaml turn_shared_secret: "YOUR_SHARED_SECRET" ``` ----- +--- ### `turn_username` and `turn_password` The Username and password if the TURN server needs them and does not use a token. @@ -2366,7 +2367,7 @@ Example configuration: ```yaml session_lifetime: 24h ``` ----- +--- ### `refresh_access_token_lifetime` Time that an access token remains valid for, if the session is using refresh tokens. @@ -2422,7 +2423,7 @@ nonrefreshable_access_token_lifetime: 24h ``` --- -## Metrics ### +## Metrics Config options related to metrics. --- @@ -2519,7 +2520,7 @@ Example configuration: report_stats_endpoint: https://example.com/report-usage-stats/push ``` --- -## API Configuration ## +## API Configuration Config settings related to the client/server API --- @@ -2619,7 +2620,7 @@ Example configuration: form_secret: ``` --- -## Signing Keys ## +## Signing Keys Config options relating to signing keys --- @@ -2728,7 +2729,7 @@ Example configuration: key_server_signing_keys_path: "key_server_signing_keys.key" ``` --- -## Single sign-on integration ## +## Single sign-on integration The following settings can be used to make Synapse use a single sign-on provider for authentication, instead of its internal password database. @@ -3348,7 +3349,7 @@ email: email_validation: "[%(server_name)s] Validate your email" ``` --- -## Push ## +## Push Configuration settings related to push notifications --- @@ -3381,7 +3382,7 @@ push: group_unread_count_by_room: false ``` --- -## Rooms ## +## Rooms Config options relating to rooms. --- @@ -3627,7 +3628,7 @@ default_power_level_content_override: ``` --- -## Opentracing ## +## Opentracing Configuration options related to Opentracing support. --- @@ -3670,14 +3671,71 @@ opentracing: false ``` --- -## Workers ## -Configuration options related to workers. +## Coordinating workers +Configuration options related to workers which belong in the main config file +(usually called `homeserver.yaml`). +A Synapse deployment can scale horizontally by running multiple Synapse processes +called _workers_. Incoming requests are distributed between workers to handle higher +loads. Some workers are privileged and can accept requests from other workers. + +As a result, the worker configuration is divided into two parts. + +1. The first part (in this section of the manual) defines which shardable tasks + are delegated to privileged workers. This allows unprivileged workers to make + request a privileged worker to act on their behalf. +1. [The second part](#individual-worker-configuration) + controls the behaviour of individual workers in isolation. + +For guidance on setting up workers, see the [worker documentation](../../workers.md). + +--- +### `worker_replication_secret` + +A shared secret used by the replication APIs on the main process to authenticate +HTTP requests from workers. + +The default, this value is omitted (equivalently `null`), which means that +traffic between the workers and the main process is not authenticated. + +Example configuration: +```yaml +worker_replication_secret: "secret_secret" +``` +--- +### `start_pushers` + +Controls sending of push notifications on the main process. Set to `false` +if using a [pusher worker](../../workers.md#synapseapppusher). Defaults to `true`. + +Example configuration: +```yaml +start_pushers: false +``` +--- +### `pusher_instances` + +It is possible to run multiple [pusher workers](../../workers.md#synapseapppusher), +in which case the work is balanced across them. Use this setting to list the pushers by +[`worker_name`](#worker_name). Ensure the main process and all pusher workers are +restarted after changing this option. +If no or only one pusher worker is configured, this setting is not necessary. +The main process will send out push notifications by default if you do not disable +it by setting [`start_pushers: false`](#start_pushers). + +Example configuration: +```yaml +start_pushers: false +pusher_instances: + - pusher_worker1 + - pusher_worker2 +``` --- ### `send_federation` Controls sending of outbound federation transactions on the main process. -Set to false if using a federation sender worker. Defaults to true. +Set to `false` if using a [federation sender worker](../../workers.md#synapseappfederation_sender). +Defaults to `true`. Example configuration: ```yaml @@ -3686,8 +3744,9 @@ send_federation: false --- ### `federation_sender_instances` -It is possible to run multiple federation sender workers, in which case the -work is balanced across them. Use this setting to list the senders. +It is possible to run multiple +[federation sender worker](../../workers.md#synapseappfederation_sender), in which +case the work is balanced across them. Use this setting to list the senders. This configuration setting must be shared between all federation sender workers, and if changed all federation sender workers must be stopped at the same time and then @@ -3696,14 +3755,19 @@ events may be dropped). Example configuration: ```yaml +send_federation: false federation_sender_instances: - federation_sender1 ``` --- ### `instance_map` -When using workers this should be a map from worker name to the +When using workers this should be a map from [`worker_name`](#worker_name) to the HTTP replication listener of the worker, if configured. +Each worker declared under [`stream_writers`](../../workers.md#stream-writers) needs +a HTTP replication listener, and that listener should be included in the `instance_map`. +(The main process also needs an HTTP replication listener, but it should not be +listed in the `instance_map`.) Example configuration: ```yaml @@ -3716,8 +3780,11 @@ instance_map: ### `stream_writers` Experimental: When using workers you can define which workers should -handle event persistence and typing notifications. Any worker -specified here must also be in the `instance_map`. +handle writing to streams such as event persistence and typing notifications. +Any worker specified here must also be in the [`instance_map`](#instance_map). + +See the list of available streams in the +[worker documentation](../../workers.md#stream-writers). Example configuration: ```yaml @@ -3728,29 +3795,18 @@ stream_writers: --- ### `run_background_tasks_on` -The worker that is used to run background tasks (e.g. cleaning up expired -data). If not provided this defaults to the main process. +The [worker](../../workers.md#background-tasks) that is used to run +background tasks (e.g. cleaning up expired data). If not provided this +defaults to the main process. Example configuration: ```yaml run_background_tasks_on: worker1 ``` --- -### `worker_replication_secret` - -A shared secret used by the replication APIs to authenticate HTTP requests -from workers. - -By default this is unused and traffic is not authenticated. - -Example configuration: -```yaml -worker_replication_secret: "secret_secret" -``` ### `redis` -Configuration for Redis when using workers. This *must* be enabled when -using workers (unless using old style direct TCP configuration). +Configuration for Redis when using workers. This *must* be enabled when using workers. This setting has the following sub-options: * `enabled`: whether to use Redis support. Defaults to false. * `host` and `port`: Optional host and port to use to connect to redis. Defaults to @@ -3765,7 +3821,123 @@ redis: port: 6379 password: ``` -## Background Updates ## +--- +## Individual worker configuration +These options configure an individual worker, in its worker configuration file. +They should be not be provided when configuring the main process. + +Note also the configuration above for +[coordinating a cluster of workers](#coordinating-workers). + +For guidance on setting up workers, see the [worker documentation](../../workers.md). + +--- +### `worker_app` + +The type of worker. The currently available worker applications are listed +in [worker documentation](../../workers.md#available-worker-applications). + +The most common worker is the +[`synapse.app.generic_worker`](../../workers.md#synapseappgeneric_worker). + +Example configuration: +```yaml +worker_app: synapse.app.generic_worker +``` +--- +### `worker_name` + +A unique name for the worker. The worker needs a name to be addressed in +further parameters and identification in log files. We strongly recommend +giving each worker a unique `worker_name`. + +Example configuration: +```yaml +worker_name: generic_worker1 +``` +--- +### `worker_replication_host` + +The HTTP replication endpoint that it should talk to on the main Synapse process. +The main Synapse process defines this with a `replication` resource in +[`listeners` option](#listeners). + +Example configuration: +```yaml +worker_replication_host: 127.0.0.1 +``` +--- +### `worker_replication_http_port` + +The HTTP replication port that it should talk to on the main Synapse process. +The main Synapse process defines this with a `replication` resource in +[`listeners` option](#listeners). + +Example configuration: +```yaml +worker_replication_http_port: 9093 +``` +--- +### `worker_listeners` + +A worker can handle HTTP requests. To do so, a `worker_listeners` option +must be declared, in the same way as the [`listeners` option](#listeners) +in the shared config. + +Workers declared in [`stream_writers`](#stream_writers) will need to include a +`replication` listener here, in order to accept internal HTTP requests from +other workers. + +Example configuration: +```yaml +worker_listeners: + - type: http + port: 8083 + resources: + - names: [client, federation] +``` +--- +### `worker_daemonize` + +Specifies whether the worker should be started as a daemon process. +If Synapse is being managed by [systemd](../../systemd-with-workers/README.md), this option +must be omitted or set to `false`. + +Defaults to `false`. + +Example configuration: +```yaml +worker_daemonize: true +``` +--- +### `worker_pid_file` + +When running a worker as a daemon, we need a place to store the +[PID](https://en.wikipedia.org/wiki/Process_identifier) of the worker. +This option defines the location of that "pid file". + +This option is required if `worker_daemonize` is `true` and ignored +otherwise. It has no default. + +See also the [`pid_file` option](#pid_file) option for the main Synapse process. + +Example configuration: +```yaml +worker_pid_file: DATADIR/generic_worker1.pid +``` +--- +### `worker_log_config` + +This option specifies a yaml python logging config file as described +[here](https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema). +See also the [`log_config` option](#log_config) option for the main Synapse process. + +Example configuration: +```yaml +worker_log_config: /etc/matrix-synapse/generic-worker-log.yaml +``` +--- +## Background Updates Configuration settings related to background updates. --- diff --git a/docs/workers.md b/docs/workers.md index c27b3f8bd5..5e1b9ba220 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -88,10 +88,12 @@ shared configuration file. ### Shared configuration Normally, only a couple of changes are needed to make an existing configuration -file suitable for use with workers. First, you need to enable an "HTTP replication -listener" for the main process; and secondly, you need to enable redis-based -replication. Optionally, a shared secret can be used to authenticate HTTP -traffic between workers. For example: +file suitable for use with workers. First, you need to enable an +["HTTP replication listener"](usage/configuration/config_documentation.md#listeners) +for the main process; and secondly, you need to enable +[redis-based replication](usage/configuration/config_documentation.md#redis). +Optionally, a [shared secret](usage/configuration/config_documentation.md#worker_replication_secret) +can be used to authenticate HTTP traffic between workers. For example: ```yaml # extend the existing `listeners` section. This defines the ports that the @@ -111,25 +113,28 @@ redis: enabled: true ``` -See the [configuration manual](usage/configuration/config_documentation.html) for the full documentation of each option. +See the [configuration manual](usage/configuration/config_documentation.md) +for the full documentation of each option. Under **no circumstances** should the replication listener be exposed to the public internet; replication traffic is: * always unencrypted -* unauthenticated, unless `worker_replication_secret` is configured +* unauthenticated, unless [`worker_replication_secret`](usage/configuration/config_documentation.md#worker_replication_secret) + is configured ### Worker configuration In the config file for each worker, you must specify: - * The type of worker (`worker_app`). The currently available worker applications are listed below. - * A unique name for the worker (`worker_name`). + * The type of worker ([`worker_app`](usage/configuration/config_documentation.md#worker_app)). + The currently available worker applications are listed [below](#available-worker-applications). + * A unique name for the worker ([`worker_name`](usage/configuration/config_documentation.md#worker_name)). * The HTTP replication endpoint that it should talk to on the main synapse process - (`worker_replication_host` and `worker_replication_http_port`) - * If handling HTTP requests, a `worker_listeners` option with an `http` - listener, in the same way as the [`listeners`](usage/configuration/config_documentation.md#listeners) - option in the shared config. + ([`worker_replication_host`](usage/configuration/config_documentation.md#worker_replication_host) and + [`worker_replication_http_port`](usage/configuration/config_documentation.md#worker_replication_http_port)). + * If handling HTTP requests, a [`worker_listeners`](usage/configuration/config_documentation.md#worker_listeners) option + with an `http` listener. * If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for the main process (`worker_main_http_uri`). @@ -146,7 +151,6 @@ plain HTTP endpoint on port 8083 separately serving various endpoints, e.g. Obviously you should configure your reverse-proxy to route the relevant endpoints to the worker (`localhost:8083` in the above example). - ### Running Synapse with workers Finally, you need to start your worker processes. This can be done with either @@ -288,7 +292,8 @@ For multiple workers not handling the SSO endpoints properly, see [#9427](https://github.com/matrix-org/synapse/issues/9427). Note that a [HTTP listener](usage/configuration/config_documentation.md#listeners) -with `client` and `federation` `resources` must be configured in the `worker_listeners` +with `client` and `federation` `resources` must be configured in the +[`worker_listeners`](usage/configuration/config_documentation.md#worker_listeners) option in the worker config. #### Load balancing @@ -331,9 +336,10 @@ of the main process to a particular worker. To enable this, the worker must have a [HTTP `replication` listener](usage/configuration/config_documentation.md#listeners) configured, -have a `worker_name` and be listed in the `instance_map` config. The same worker -can handle multiple streams, but unless otherwise documented, each stream can only -have a single writer. +have a [`worker_name`](usage/configuration/config_documentation.md#worker_name) +and be listed in the [`instance_map`](usage/configuration/config_documentation.md#instance_map) +config. The same worker can handle multiple streams, but unless otherwise documented, +each stream can only have a single writer. For example, to move event persistence off to a dedicated worker, the shared configuration would include: @@ -360,9 +366,26 @@ streams and the endpoints associated with them: ##### The `events` stream -The `events` stream experimentally supports having multiple writers, where work -is sharded between them by room ID. Note that you *must* restart all worker -instances when adding or removing event persisters. An example `stream_writers` +The `events` stream experimentally supports having multiple writer workers, where load +is sharded between them by room ID. Each writer is called an _event persister_. They are +responsible for +- receiving new events, +- linking them to those already in the room [DAG](development/room-dag-concepts.md), +- persisting them to the DB, and finally +- updating the events stream. + +Because load is sharded in this way, you *must* restart all worker instances when +adding or removing event persisters. + +An `event_persister` should not be mistaken for an `event_creator`. +An `event_creator` listens for requests from clients to create new events and does +so. It will then pass those events over HTTP replication to any configured event +persisters (or the main process if none are configured). + +Note that `event_creator`s and `event_persister`s are implemented using the same +[`synapse.app.generic_worker`](#synapse.app.generic_worker). + +An example [`stream_writers`](usage/configuration/config_documentation.md#stream_writers) configuration with multiple writers: ```yaml @@ -416,16 +439,18 @@ worker. Background tasks are run periodically or started via replication. Exactl which tasks are configured to run depends on your Synapse configuration (e.g. if stats is enabled). This worker doesn't handle any REST endpoints itself. -To enable this, the worker must have a `worker_name` and can be configured to run -background tasks. For example, to move background tasks to a dedicated worker, -the shared configuration would include: +To enable this, the worker must have a unique +[`worker_name`](usage/configuration/config_documentation.md#worker_name) +and can be configured to run background tasks. For example, to move background tasks +to a dedicated worker, the shared configuration would include: ```yaml run_background_tasks_on: background_worker ``` -You might also wish to investigate the `update_user_directory_from_worker` and -`media_instance_running_background_jobs` settings. +You might also wish to investigate the +[`update_user_directory_from_worker`](#updating-the-user-directory) and +[`media_instance_running_background_jobs`](#synapseappmedia_repository) settings. An example for a dedicated background worker instance: @@ -478,13 +503,17 @@ worker application type. ### `synapse.app.pusher` Handles sending push notifications to sygnal and email. Doesn't handle any -REST endpoints itself, but you should set `start_pushers: False` in the +REST endpoints itself, but you should set +[`start_pushers: false`](usage/configuration/config_documentation.md#start_pushers) in the shared configuration file to stop the main synapse sending push notifications. -To run multiple instances at once the `pusher_instances` option should list all -pusher instances by their worker name, e.g.: +To run multiple instances at once the +[`pusher_instances`](usage/configuration/config_documentation.md#pusher_instances) +option should list all pusher instances by their +[`worker_name`](usage/configuration/config_documentation.md#worker_name), e.g.: ```yaml +start_pushers: false pusher_instances: - pusher_worker1 - pusher_worker2 @@ -512,15 +541,20 @@ Note this worker cannot be load-balanced: only one instance should be active. ### `synapse.app.federation_sender` Handles sending federation traffic to other servers. Doesn't handle any -REST endpoints itself, but you should set `send_federation: False` in the -shared configuration file to stop the main synapse sending this traffic. +REST endpoints itself, but you should set +[`send_federation: false`](usage/configuration/config_documentation.md#send_federation) +in the shared configuration file to stop the main synapse sending this traffic. If running multiple federation senders then you must list each -instance in the `federation_sender_instances` option by their `worker_name`. +instance in the +[`federation_sender_instances`](usage/configuration/config_documentation.md#federation_sender_instances) +option by their +[`worker_name`](usage/configuration/config_documentation.md#worker_name). All instances must be stopped and started when adding or removing instances. For example: ```yaml +send_federation: false federation_sender_instances: - federation_sender1 - federation_sender2 @@ -547,7 +581,9 @@ Handles the media repository. It can handle all endpoints starting with: ^/_synapse/admin/v1/quarantine_media/.*$ ^/_synapse/admin/v1/users/.*/media$ -You should also set `enable_media_repo: False` in the shared configuration +You should also set +[`enable_media_repo: False`](usage/configuration/config_documentation.md#enable_media_repo) +in the shared configuration file to stop the main synapse running background jobs related to managing the media repository. Note that doing so will prevent the main process from being able to handle the above endpoints. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index b62b3b9205..94d1150415 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -53,7 +53,7 @@ DEFAULT_LOG_CONFIG = Template( # Synapse also supports structured logging for machine readable logs which can # be ingested by ELK stacks. See [2] for details. # -# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema +# [1]: https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema # [2]: https://matrix-org.github.io/synapse/latest/structured_logging.html version: 1 -- cgit 1.5.1 From 67583281e3f8ea923eedbc56a4c85c7ba75d1582 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Oct 2022 09:58:12 -0400 Subject: Fix tests for change in PostgreSQL 14 behavior change. (#14310) PostgreSQL 14 changed the behavior of `websearch_to_tsquery` to improve some behaviour. The tests were hitting those edge-cases about handling of hanging double quotes. This fixes the tests to take into account the PostgreSQL version. --- changelog.d/14310.feature | 1 + synapse/storage/databases/main/search.py | 5 ++--- tests/storage/test_room_search.py | 16 ++++++++++++---- 3 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14310.feature (limited to 'synapse') diff --git a/changelog.d/14310.feature b/changelog.d/14310.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/14310.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index a89fc54c2c..594b935614 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -824,9 +824,8 @@ def _tokenize_query(query: str) -> TokenList: in_phrase = False parts = deque(query.split('"')) for i, part in enumerate(parts): - # The contents inside double quotes is treated as a phrase, a trailing - # double quote is not implied. - in_phrase = bool(i % 2) and i != (len(parts) - 1) + # The contents inside double quotes is treated as a phrase. + in_phrase = bool(i % 2) # Pull out the individual words, discarding any non-word characters. words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 9ddc19900a..868b5bee84 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -239,7 +239,6 @@ class MessageSearchTest(HomeserverTestCase): ("fox -nope", (True, False)), ("fox -brown", (False, True)), ('"fox" quick', True), - ('"fox quick', True), ('"quick brown', True), ('" quick "', True), ('" nope"', False), @@ -269,6 +268,15 @@ class MessageSearchTest(HomeserverTestCase): response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token) self.assertIn("event_id", response) + # The behaviour of a missing trailing double quote changed in PostgreSQL 14 + # from ignoring the initial double quote to treating it as a phrase. + main_store = homeserver.get_datastores().main + found = False + if isinstance(main_store.database_engine, PostgresEngine): + assert main_store.database_engine._version is not None + found = main_store.database_engine._version < 140000 + self.COMMON_CASES.append(('"fox quick', (found, True))) + def test_tokenize_query(self) -> None: """Test the custom logic to tokenize a user's query.""" cases = ( @@ -280,9 +288,9 @@ class MessageSearchTest(HomeserverTestCase): ("fox -brown", ["fox", SearchToken.Not, "brown"]), ("- fox", [SearchToken.Not, "fox"]), ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]), - # No trailing double quoe. - ('"fox quick', ["fox", SearchToken.And, "quick"]), - ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]), + # No trailing double quote. + ('"fox quick', [Phrase(["fox", "quick"])]), + ('"-fox quick', [Phrase(["-fox", "quick"])]), ('" quick "', [Phrase(["quick"])]), ( 'q"uick brow"n', -- cgit 1.5.1 From aa70556699e649f46f51a198fb104eecdc0d311b Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 27 Oct 2022 13:29:23 -0500 Subject: Check appservice user interest against the local users instead of all users (`get_users_in_room` mis-use) (#13958) --- changelog.d/13958.bugfix | 1 + docs/upgrade.md | 19 ++++ synapse/appservice/__init__.py | 16 ++- synapse/storage/databases/main/appservice.py | 17 ++- synapse/storage/databases/main/roommember.py | 3 + tests/appservice/test_appservice.py | 10 +- tests/handlers/test_appservice.py | 162 ++++++++++++++++++++++++++- 7 files changed, 214 insertions(+), 14 deletions(-) create mode 100644 changelog.d/13958.bugfix (limited to 'synapse') diff --git a/changelog.d/13958.bugfix b/changelog.d/13958.bugfix new file mode 100644 index 0000000000..f9f651bfdc --- /dev/null +++ b/changelog.d/13958.bugfix @@ -0,0 +1 @@ +Check appservice user interest against the local users instead of all users in the room to align with [MSC3905](https://github.com/matrix-org/matrix-spec-proposals/pull/3905). diff --git a/docs/upgrade.md b/docs/upgrade.md index 78c34d0c15..f095bbc3a6 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -97,6 +97,25 @@ As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_s Modules relying on it can instead use the `create_login_token` method. +## Changes to the events received by application services (interest) + +To align with spec (changed in +[MSC3905](https://github.com/matrix-org/matrix-spec-proposals/pull/3905)), Synapse now +only considers local users to be interesting. In other words, the `users` namespace +regex is only be applied against local users of the homeserver. + +Please note, this probably doesn't affect the expected behavior of your application +service, since an interesting local user in a room still means all messages in the room +(from local or remote users) will still be considered interesting. And matching a room +with the `rooms` or `aliases` namespace regex will still consider all events sent in the +room to be interesting to the application service. + +If one of your application service's `users` regex was intending to match a remote user, +this will no longer match as you expect. The behavioral mismatch between matching all +local users and some remote users is why the spec was changed/clarified and this +caveat is no longer supported. + + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 0dfa00df44..500bdde3a9 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -172,12 +172,24 @@ class ApplicationService: Returns: True if this service would like to know about this room. """ - member_list = await store.get_users_in_room( + # We can use `get_local_users_in_room(...)` here because an application service + # can only be interested in local users of the server it's on (ignore any remote + # users that might match the user namespace regex). + # + # In the future, we can consider re-using + # `store.get_app_service_users_in_room` which is very similar to this + # function but has a slightly worse performance than this because we + # have an early escape-hatch if we find a single user that the + # appservice is interested in. The juice would be worth the squeeze if + # `store.get_app_service_users_in_room` was used in more places besides + # an experimental MSC. But for now we can avoid doing more work and + # barely using it later. + local_user_ids = await store.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) # check joined member events - for user_id in member_list: + for user_id in local_user_ids: if self.is_interested_in_user(user_id): return True return False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 64b70a7b28..63046c0527 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -157,10 +157,23 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): app_service: "ApplicationService", cache_context: _CacheContext, ) -> List[str]: - users_in_room = await self.get_users_in_room( + """ + Get all users in a room that the appservice controls. + + Args: + room_id: The room to check in. + app_service: The application service to check interest/control against + + Returns: + List of user IDs that the appservice controls. + """ + # We can use `get_local_users_in_room(...)` here because an application service + # can only be interested in local users of the server it's on (ignore any remote + # users that might match the user namespace regex). + local_users_in_room = await self.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) - return list(filter(app_service.is_interested_in_user, users_in_room)) + return list(filter(app_service.is_interested_in_user, local_users_in_room)) class ApplicationServiceStore(ApplicationServiceWorkerStore): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ab708b0ba5..e56a13f21e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -152,6 +152,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): the forward extremities of those rooms will exclude most members. We may also calculate room state incorrectly for such rooms and believe that a member is or is not in the room when the opposite is true. + + Note: If you only care about users in the room local to the homeserver, use + `get_local_users_in_room(...)` instead which will be more performant. """ return await self.db_pool.simple_select_onecol( table="current_state_events", diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 3018d3fc6f..d4dccfc2f0 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store = Mock() self.store.get_aliases_for_room = simple_async_mock([]) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): @@ -129,7 +129,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room = simple_async_mock( ["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -184,7 +184,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room = simple_async_mock( ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertFalse( ( yield defer.ensureDeferred( @@ -203,7 +203,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -236,7 +236,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. - self.store.get_users_in_room = simple_async_mock( + self.store.get_local_users_in_room = simple_async_mock( ["@alice:here", "@irc_fo:here", "@bob:here"] ) self.store.get_aliases_for_room = simple_async_mock([]) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 7e4570f990..144e49d0fd 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -22,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.api.constants import EduTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.appservice import ( ApplicationService, TransactionOneTimeKeyCounts, @@ -36,7 +36,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import make_awaitable, simple_async_mock +from tests.test_utils import event_injection, make_awaitable, simple_async_mock from tests.unittest import override_config from tests.utils import MockClock @@ -390,15 +390,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events self.send_mock = simple_async_mock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] return_value=self._services ) @@ -416,6 +417,157 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) + def _notify_interested_services(self): + # This is normally set in `notify_interested_services` but we need to call the + # internal async version so the reactor gets pushed to completion. + self.hs.get_application_service_handler().current_max += 1 + self.get_success( + self.hs.get_application_service_handler()._notify_interested_services( + RoomStreamToken( + None, self.hs.get_application_service_handler().current_max + ) + ) + ) + + @parameterized.expand( + [ + ("@local_as_user:test", True), + # Defining remote users in an application service user namespace regex is a + # footgun since the appservice might assume that it'll receive all events + # sent by that remote user, but it will only receive events in rooms that + # are shared with a local user. So we just remove this footgun possibility + # entirely and we won't notify the application service based on remote + # users. + ("@remote_as_user:remote", False), + ] + ) + def test_match_interesting_room_members( + self, interesting_user: str, should_notify: bool + ): + """ + Test to make sure that a interesting user (local or remote) in the room is + notified as expected when someone else in the room sends a message. + """ + # Register an application service that's interested in the `interesting_user` + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": interesting_user, + "exclusive": False, + }, + ], + }, + ) + + # Create a room + alice = self.register_user("alice", "pass") + alice_access_token = self.login("alice", "pass") + room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token) + + # Join the interesting user to the room + self.get_success( + event_injection.inject_member_event( + self.hs, room_id, interesting_user, "join" + ) + ) + # Kick the appservice into checking this membership event to get the event out + # of the way + self._notify_interested_services() + # We don't care about the interesting user join event (this test is making sure + # the next thing works) + self.send_mock.reset_mock() + + # Send a message from an uninteresting user + self.helper.send_event( + room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message from uninteresting user", + }, + tok=alice_access_token, + ) + # Kick the appservice into checking this new event + self._notify_interested_services() + + if should_notify: + self.send_mock.assert_called_once() + ( + service, + events, + _ephemeral, + _to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Even though the message came from an uninteresting user, it should still + # notify us because the interesting user is joined to the room where the + # message was sent. + self.assertEqual(service, interested_appservice) + self.assertEqual(events[0]["type"], "m.room.message") + self.assertEqual(events[0]["sender"], alice) + else: + self.send_mock.assert_not_called() + + def test_application_services_receive_events_sent_by_interesting_local_user(self): + """ + Test to make sure that a messages sent from a local user can be interesting and + picked up by the appservice. + """ + # Register an application service that's interested in all local users + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": ".*", + "exclusive": False, + }, + ], + }, + ) + + # Create a room + alice = self.register_user("alice", "pass") + alice_access_token = self.login("alice", "pass") + room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token) + + # We don't care about interesting events before this (this test is making sure + # the next thing works) + self.send_mock.reset_mock() + + # Send a message from the interesting local user + self.helper.send_event( + room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message from interesting local user", + }, + tok=alice_access_token, + ) + # Kick the appservice into checking this new event + self._notify_interested_services() + + self.send_mock.assert_called_once() + ( + service, + events, + _ephemeral, + _to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Events sent from an interesting local user should also be picked up as + # interesting to the appservice. + self.assertEqual(service, interested_appservice) + self.assertEqual(events[0]["type"], "m.room.message") + self.assertEqual(events[0]["sender"], alice) + def test_sending_read_receipt_batches_to_application_services(self): """Tests that a large batch of read receipts are sent correctly to interested application services. -- cgit 1.5.1 From 6a6e1e8c0711939338f25d8d41d1e4d33d984949 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 28 Oct 2022 10:53:34 +0000 Subject: Fix room creation being rate limited too aggressively since Synapse v1.69.0. (#14314) * Introduce a test for the old behaviour which we want to restore * Reintroduce the old behaviour in a simpler way * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) * Use 1 credit instead of 2 for creating a room: be more lenient than before Notably, the UI in Element Web was still broken after restoring to prior behaviour. After discussion, we agreed that it would be sensible to increase the limit. Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/14314.bugfix | 1 + synapse/api/ratelimiting.py | 8 +++++- synapse/handlers/room.py | 16 ++++++++---- tests/rest/client/test_rooms.py | 54 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14314.bugfix (limited to 'synapse') diff --git a/changelog.d/14314.bugfix b/changelog.d/14314.bugfix new file mode 100644 index 0000000000..8be47ee083 --- /dev/null +++ b/changelog.d/14314.bugfix @@ -0,0 +1 @@ +Fix room creation being rate limited too aggressively since Synapse v1.69.0. \ No newline at end of file diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 044c7d4926..511790c7c5 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -343,6 +343,7 @@ class RequestRatelimiter: requester: Requester, update: bool = True, is_admin_redaction: bool = False, + n_actions: int = 1, ) -> None: """Ratelimits requests. @@ -355,6 +356,8 @@ class RequestRatelimiter: is_admin_redaction: Whether this is a room admin/moderator redacting an event. If so then we may apply different ratelimits depending on config. + n_actions: Multiplier for the number of actions to apply to the + rate limiter at once. Raises: LimitExceededError if the request should be ratelimited @@ -383,7 +386,9 @@ class RequestRatelimiter: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) + await self.admin_redaction_ratelimiter.ratelimit( + requester, update=update, n_actions=n_actions + ) else: # Override rate and burst count per-user await self.request_ratelimiter.ratelimit( @@ -391,4 +396,5 @@ class RequestRatelimiter: rate_hz=messages_per_second, burst_count=burst_count, update=update, + n_actions=n_actions, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..d74b675adc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -559,7 +559,6 @@ class RoomCreationHandler: invite_list=[], initial_state=initial_state, creation_content=creation_content, - ratelimit=False, ) # Transfer membership events @@ -753,6 +752,10 @@ class RoomCreationHandler: ) if ratelimit: + # Rate limit once in advance, but don't rate limit the individual + # events in the room — room creation isn't atomic and it's very + # janky if half the events in the initial state don't make it because + # of rate limiting. await self.request_ratelimiter.ratelimit(requester) room_version_id = config.get( @@ -913,7 +916,6 @@ class RoomCreationHandler: room_alias=room_alias, power_level_content_override=power_level_content_override, creator_join_profile=creator_join_profile, - ratelimit=ratelimit, ) if "name" in config: @@ -1037,7 +1039,6 @@ class RoomCreationHandler: room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, - ratelimit: bool = True, ) -> Tuple[int, str, int]: """Sends the initial events into a new room. Sends the room creation, membership, and power level events into the room sequentially, then creates and batches up the @@ -1046,6 +1047,8 @@ class RoomCreationHandler: `power_level_content_override` doesn't apply when initial state has power level state event content. + Rate limiting should already have been applied by this point. + Returns: A tuple containing the stream ID, event ID and depth of the last event sent to the room. @@ -1144,7 +1147,7 @@ class RoomCreationHandler: creator.user, room_id, "join", - ratelimit=ratelimit, + ratelimit=False, content=creator_join_profile, new_room=True, prev_event_ids=[last_sent_event_id], @@ -1269,7 +1272,10 @@ class RoomCreationHandler: events_to_send.append((encryption_event, encryption_context)) last_event = await self.event_creation_handler.handle_new_client_event( - creator, events_to_send, ignore_shadow_ban=True + creator, + events_to_send, + ignore_shadow_ban=True, + ratelimit=False, ) assert last_event.internal_metadata.stream_ordering is not None return last_event.internal_metadata.stream_ordering, last_event.event_id, depth diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 716366eb90..1084d4ad9d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -54,6 +54,7 @@ from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable from tests.test_utils.event_injection import create_event +from tests.unittest import override_config PATH_PREFIX = b"/_matrix/client/api/v1" @@ -871,6 +872,41 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) + def _create_basic_room(self) -> Tuple[int, object]: + """ + Tries to create a basic room and returns the response code. + """ + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + return channel.code, channel.json_body + + @override_config( + { + "rc_message": {"per_second": 0.2, "burst_count": 10}, + } + ) + def test_room_creation_ratelimiting(self) -> None: + """ + Regression test for #14312, where ratelimiting was made too strict. + Clients should be able to create 10 rooms in a row + without hitting rate limits, using default rate limit config. + (We override rate limiting config back to its default value.) + + To ensure we don't make ratelimiting too generous accidentally, + also check that we can't create an 11th room. + """ + + for _ in range(10): + code, json_body = self._create_basic_room() + self.assertEqual(code, HTTPStatus.OK, json_body) + + # The 6th room hits the rate limit. + code, json_body = self._create_basic_room() + self.assertEqual(code, HTTPStatus.TOO_MANY_REQUESTS, json_body) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -1390,10 +1426,22 @@ class RoomJoinRatelimitTestCase(RoomBase): ) def test_join_local_ratelimit(self) -> None: """Tests that local joins are actually rate-limited.""" - for _ in range(3): - self.helper.create_room_as(self.user_id) + # Create 4 rooms + room_ids = [ + self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4) + ] + + joiner_user_id = self.register_user("joiner", "secret") + # Now make a new user try to join some of them. - self.helper.create_room_as(self.user_id, expect_code=429) + # The user can join 3 rooms + for room_id in room_ids[0:3]: + self.helper.join(room_id, joiner_user_id) + + # But the user cannot join a 4th room + self.helper.join( + room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS + ) @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} -- cgit 1.5.1 From 81815e0561eea91dbf0c29731589fac2e6f98a40 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Oct 2022 11:44:10 -0400 Subject: Switch search SQL to triple-quote strings. (#14311) For ease of reading we switch from concatenated strings to triple quote strings. --- changelog.d/14311.feature | 1 + synapse/storage/databases/main/search.py | 188 ++++++++++++++++--------------- 2 files changed, 100 insertions(+), 89 deletions(-) create mode 100644 changelog.d/14311.feature (limited to 'synapse') diff --git a/changelog.d/14311.feature b/changelog.d/14311.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/14311.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 594b935614..e9588d1755 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -80,11 +80,11 @@ class SearchWorkerStore(SQLBaseStore): if not self.hs.config.server.enable_search: return if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) + sql = """ + INSERT INTO event_search + (event_id, room_id, key, vector, stream_ordering, origin_server_ts) + VALUES (?,?,?,to_tsvector('english', ?),?,?) + """ args1 = ( ( @@ -101,20 +101,20 @@ class SearchWorkerStore(SQLBaseStore): txn.execute_batch(sql, args1) elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args2 = ( - ( - entry.event_id, - entry.room_id, - entry.key, - _clean_value_for_search(entry.value), - ) - for entry in entries + self.db_pool.simple_insert_many_txn( + txn, + table="event_search", + keys=("event_id", "room_id", "key", "value"), + values=( + ( + entry.event_id, + entry.room_id, + entry.key, + _clean_value_for_search(entry.value), + ) + for entry in entries + ), ) - txn.execute_batch(sql, args2) else: # This should be unreachable. @@ -162,15 +162,17 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): TYPES = ["m.room.name", "m.room.message", "m.room.topic"] def reindex_search_txn(txn: LoggingTransaction) -> int: - sql = ( - "SELECT stream_ordering, event_id, room_id, type, json, " - " origin_server_ts FROM events" - " JOIN event_json USING (room_id, event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " AND (%s)" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) + sql = """ + SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts + FROM events + JOIN event_json USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND (%s) + ORDER BY stream_ordering DESC + LIMIT ? + """ % ( + " OR ".join("type = '%s'" % (t,) for t in TYPES), + ) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) @@ -284,8 +286,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): try: c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx" - " ON event_search USING GIN (vector)" + """ + CREATE INDEX CONCURRENTLY event_search_fts_idx + ON event_search USING GIN (vector) + """ ) except psycopg2.ProgrammingError as e: logger.warning( @@ -323,12 +327,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # We create with NULLS FIRST so that when we search *backwards* # we get the ones with non null origin_server_ts *first* c.execute( - "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" - "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + """ + CREATE INDEX CONCURRENTLY event_search_room_order + ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) + """ ) c.execute( - "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" - "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + """ + CREATE INDEX CONCURRENTLY event_search_order + ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) + """ ) conn.set_session(autocommit=False) @@ -345,14 +353,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ) def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]: - sql = ( - "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," - " origin_server_ts = e.origin_server_ts" - " FROM events AS e" - " WHERE e.event_id = es.event_id" - " AND ? <= e.stream_ordering AND e.stream_ordering < ?" - " RETURNING es.stream_ordering" - ) + sql = """ + UPDATE event_search AS es + SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts + FROM events AS e + WHERE e.event_id = es.event_id + AND ? <= e.stream_ordering AND e.stream_ordering < ? + RETURNING es.stream_ordering + """ min_stream_id = max_stream_id - batch_size txn.execute(sql, (min_stream_id, max_stream_id)) @@ -456,33 +464,33 @@ class SearchStore(SearchBackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): search_query = search_term tsquery_func = self.database_engine.tsquery_func - sql = ( - f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank," - " room_id, event_id" - " FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?)" - ) + sql = f""" + SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank, + room_id, event_id + FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) + """ args = [search_query, search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?)" - ) + count_sql = f""" + SELECT room_id, count(*) as count FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) + """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_for_sqlite(search_term) - sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" - " FROM event_search" - " WHERE value MATCH ?" - ) + sql = """ + SELECT rank(matchinfo(event_search)) as rank, room_id, event_id + FROM event_search + WHERE value MATCH ? + """ args = [search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ?" - ) + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE value MATCH ? + """ count_args = [search_query] + count_args else: # This should be unreachable. @@ -588,26 +596,27 @@ class SearchStore(SearchBackgroundUpdateStore): raise SynapseError(400, "Invalid pagination token") clauses.append( - "(origin_server_ts < ?" - " OR (origin_server_ts = ? AND stream_ordering < ?))" + """ + (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?)) + """ ) args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): search_query = search_term tsquery_func = self.database_engine.tsquery_func - sql = ( - f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank," - " origin_server_ts, stream_ordering, room_id, event_id" - " FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?) AND " - ) + sql = f""" + SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank, + origin_server_ts, stream_ordering, room_id, event_id + FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) AND + """ args = [search_query, search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?) AND " - ) + count_sql = f""" + SELECT room_id, count(*) as count FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) AND + """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -619,23 +628,24 @@ class SearchStore(SearchBackgroundUpdateStore): # in the events table to get the topological ordering. We need # to use the indexes in this order because sqlite refuses to # MATCH unless it uses the full text search index - sql = ( - "SELECT rank(matchinfo) as rank, room_id, event_id," - " origin_server_ts, stream_ordering" - " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" - " FROM event_search" - " WHERE value MATCH ?" - " )" - " CROSS JOIN events USING (event_id)" - " WHERE " + sql = """ + SELECT + rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering + FROM ( + SELECT key, event_id, matchinfo(event_search) as matchinfo + FROM event_search + WHERE value MATCH ? ) + CROSS JOIN events USING (event_id) + WHERE + """ search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ? AND " - ) + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE value MATCH ? AND + """ count_args = [search_query] + count_args else: # This should be unreachable. @@ -647,10 +657,10 @@ class SearchStore(SearchBackgroundUpdateStore): # We add an arbitrary limit here to ensure we don't try to pull the # entire table from the database. if isinstance(self.database_engine, PostgresEngine): - sql += ( - " ORDER BY origin_server_ts DESC NULLS LAST," - " stream_ordering DESC NULLS LAST LIMIT ?" - ) + sql += """ + ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST + LIMIT ? + """ elif isinstance(self.database_engine, Sqlite3Engine): sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" else: -- cgit 1.5.1 From 730b13dbc9e48181b1aaf38be870ec21364b1e9c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 28 Oct 2022 17:04:02 +0100 Subject: Improve `RawHeaders` type hints (#14303) --- changelog.d/14303.misc | 1 + synapse/app/generic_worker.py | 8 ++++---- synapse/http/client.py | 24 +++++++++++++++++++----- 3 files changed, 24 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14303.misc (limited to 'synapse') diff --git a/changelog.d/14303.misc b/changelog.d/14303.misc new file mode 100644 index 0000000000..24ce238223 --- /dev/null +++ b/changelog.d/14303.misc @@ -0,0 +1 @@ +Improve type hinting of `RawHeaders`. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 2a9f039367..cb5892f041 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -178,13 +178,13 @@ class KeyUploadServlet(RestServlet): # Proxy headers from the original request, such as the auth headers # (in case the access token is there) and the original IP / # User-Agent of the request. - headers = { - header: request.requestHeaders.getRawHeaders(header, []) + headers: Dict[bytes, List[bytes]] = { + header: list(request.requestHeaders.getRawHeaders(header, [])) for header in (b"Authorization", b"User-Agent") } # Add the previous hop to the X-Forwarded-For header. - x_forwarded_for = request.requestHeaders.getRawHeaders( - b"X-Forwarded-For", [] + x_forwarded_for = list( + request.requestHeaders.getRawHeaders(b"X-Forwarded-For", []) ) # we use request.client here, since we want the previous hop, not the # original client (as returned by request.getClientAddress()). diff --git a/synapse/http/client.py b/synapse/http/client.py index 084d0a5b84..4eb740c040 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -25,7 +25,6 @@ from typing import ( List, Mapping, Optional, - Sequence, Tuple, Union, ) @@ -90,14 +89,29 @@ incoming_responses_counter = Counter( "synapse_http_client_responses", "", ["method", "code"] ) -# the type of the headers list, to be passed to the t.w.h.Headers. -# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so -# we simplify. +# the type of the headers map, to be passed to the t.w.h.Headers. +# +# The actual type accepted by Twisted is +# Mapping[Union[str, bytes], Sequence[Union[str, bytes]] , +# allowing us to mix and match str and bytes freely. However: any str is also a +# Sequence[str]; passing a header string value which is a +# standalone str is interpreted as a sequence of 1-codepoint strings. This is a disastrous footgun. +# We use a narrower value type (RawHeaderValue) to avoid this footgun. +# +# We also simplify the keys to be either all str or all bytes. This helps because +# Dict[K, V] is invariant in K (and indeed V). RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]] # the value actually has to be a List, but List is invariant so we can't specify that # the entries can either be Lists or bytes. -RawHeaderValue = Sequence[Union[str, bytes]] +RawHeaderValue = Union[ + List[str], + List[bytes], + List[Union[str, bytes]], + Tuple[str, ...], + Tuple[bytes, ...], + Tuple[Union[str, bytes], ...], +] def check_against_blacklist( -- cgit 1.5.1 From 7911e2835df7b4bf1dec98b09da89beda65e2ab2 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 28 Oct 2022 18:06:02 +0100 Subject: Prevent federation user keys query from returning device names if disallowed (#14304) --- changelog.d/14304.bugfix | 1 + synapse/handlers/e2e_keys.py | 37 ++++++++++++++++++++--- synapse/storage/databases/main/end_to_end_keys.py | 17 ++++++++--- 3 files changed, 46 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14304.bugfix (limited to 'synapse') diff --git a/changelog.d/14304.bugfix b/changelog.d/14304.bugfix new file mode 100644 index 0000000000..b8d4d91034 --- /dev/null +++ b/changelog.d/14304.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.34.0 where device names would be returned via a federation user key query request when `allow_device_name_lookup_over_federation` was set to `false`. \ No newline at end of file diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 09a2492afc..a9912c467d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -49,6 +49,7 @@ logger = logging.getLogger(__name__) class E2eKeysHandler: def __init__(self, hs: "HomeServer"): + self.config = hs.config self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() @@ -431,13 +432,17 @@ class E2eKeysHandler: @trace @cancellable async def query_local_devices( - self, query: Mapping[str, Optional[List[str]]] + self, + query: Mapping[str, Optional[List[str]]], + include_displaynames: bool = True, ) -> Dict[str, Dict[str, dict]]: """Get E2E device keys for local users Args: query: map from user_id to a list of devices to query (None for all devices) + include_displaynames: Whether to include device displaynames in the returned + device details. Returns: A map from user_id -> device_id -> device details @@ -469,7 +474,9 @@ class E2eKeysHandler: # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = await self.store.get_e2e_device_keys_for_cs_api(local_query) + results = await self.store.get_e2e_device_keys_for_cs_api( + local_query, include_displaynames + ) # Build the result structure for user_id, device_keys in results.items(): @@ -482,11 +489,33 @@ class E2eKeysHandler: async def on_federation_query_client_keys( self, query_body: Dict[str, Dict[str, Optional[List[str]]]] ) -> JsonDict: - """Handle a device key query from a federated server""" + """Handle a device key query from a federated server: + + Handles the path: GET /_matrix/federation/v1/users/keys/query + + Args: + query_body: The body of the query request. Should contain a key + "device_keys" that map to a dictionary of user ID's -> list of + device IDs. If the list of device IDs is empty, all devices of + that user will be queried. + + Returns: + A json dictionary containing the following: + - device_keys: A dictionary containing the requested device information. + - master_keys: An optional dictionary of user ID -> master cross-signing + key info. + - self_signing_key: An optional dictionary of user ID -> self-signing + key info. + """ device_keys_query: Dict[str, Optional[List[str]]] = query_body.get( "device_keys", {} ) - res = await self.query_local_devices(device_keys_query) + res = await self.query_local_devices( + device_keys_query, + include_displaynames=( + self.config.federation.allow_device_name_lookup_over_federation + ), + ) ret = {"device_keys": res} # add in the cross-signing keys diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 8a10ae800c..2a4f58ed92 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -139,11 +139,15 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @trace @cancellable async def get_e2e_device_keys_for_cs_api( - self, query_list: List[Tuple[str, Optional[str]]] + self, + query_list: List[Tuple[str, Optional[str]]], + include_displaynames: bool = True, ) -> Dict[str, Dict[str, JsonDict]]: """Fetch a list of device keys, formatted suitably for the C/S API. Args: - query_list(list): List of pairs of user_ids and device_ids. + query_list: List of pairs of user_ids and device_ids. + include_displaynames: Whether to include the displayname of returned devices + (if one exists). Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -166,9 +170,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker continue r["unsigned"] = {} - display_name = device_info.display_name - if display_name is not None: - r["unsigned"]["device_display_name"] = display_name + if include_displaynames: + # Include the device's display name in the "unsigned" dictionary + display_name = device_info.display_name + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + rv[user_id][device_id] = r return rv -- cgit 1.5.1 From 2bb2c32e8ed5642a5bf3ba1e8c49e10cecc88905 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 13:02:07 +0000 Subject: Avoid incrementing bg process utime/stime counters by negative durations (#14323) --- changelog.d/14323.bugfix | 1 + mypy.ini | 4 +- synapse/metrics/background_process_metrics.py | 6 +- tests/metrics/__init__.py | 0 tests/metrics/test_background_process_metrics.py | 19 +++ tests/metrics/test_metrics.py | 206 +++++++++++++++++++++++ tests/test_metrics.py | 200 ---------------------- 7 files changed, 233 insertions(+), 203 deletions(-) create mode 100644 changelog.d/14323.bugfix create mode 100644 tests/metrics/__init__.py create mode 100644 tests/metrics/test_background_process_metrics.py create mode 100644 tests/metrics/test_metrics.py delete mode 100644 tests/test_metrics.py (limited to 'synapse') diff --git a/changelog.d/14323.bugfix b/changelog.d/14323.bugfix new file mode 100644 index 0000000000..da39bc020c --- /dev/null +++ b/changelog.d/14323.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 0.34.0rc2 where logs could include error spam when background processes are measured as taking a negative amount of time. diff --git a/mypy.ini b/mypy.ini index 34b4523e00..8f1141a239 100644 --- a/mypy.ini +++ b/mypy.ini @@ -56,7 +56,6 @@ exclude = (?x) |tests/rest/media/v1/test_media_storage.py |tests/server.py |tests/server_notices/test_resource_limits_server_notices.py - |tests/test_metrics.py |tests/test_state.py |tests/test_terms_auth.py |tests/util/caches/test_cached_call.py @@ -106,6 +105,9 @@ disallow_untyped_defs = False [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True +[mypy-tests.metrics.test_background_process_metrics] +disallow_untyped_defs = True + [mypy-tests.push.test_bulk_push_rule_evaluator] disallow_untyped_defs = True diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 7a1516d3a8..9ea4e23b31 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -174,8 +174,10 @@ class _BackgroundProcess: diff = new_stats - self._reported_stats self._reported_stats = new_stats - _background_process_ru_utime.labels(self.desc).inc(diff.ru_utime) - _background_process_ru_stime.labels(self.desc).inc(diff.ru_stime) + # For unknown reasons, the difference in times can be negative. See comment in + # synapse.http.request_metrics.RequestMetrics.update_metrics. + _background_process_ru_utime.labels(self.desc).inc(max(diff.ru_utime, 0)) + _background_process_ru_stime.labels(self.desc).inc(max(diff.ru_stime, 0)) _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count) _background_process_db_txn_duration.labels(self.desc).inc( diff.db_txn_duration_sec diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/metrics/test_background_process_metrics.py b/tests/metrics/test_background_process_metrics.py new file mode 100644 index 0000000000..f0f6cb2912 --- /dev/null +++ b/tests/metrics/test_background_process_metrics.py @@ -0,0 +1,19 @@ +from unittest import TestCase as StdlibTestCase +from unittest.mock import Mock + +from synapse.logging.context import ContextResourceUsage, LoggingContext +from synapse.metrics.background_process_metrics import _BackgroundProcess + + +class TestBackgroundProcessMetrics(StdlibTestCase): + def test_update_metrics_with_negative_time_diff(self) -> None: + """We should ignore negative reported utime and stime differences""" + usage = ContextResourceUsage() + usage.ru_stime = usage.ru_utime = -1.0 + + mock_logging_context = Mock(spec=LoggingContext) + mock_logging_context.get_resource_usage.return_value = usage + + process = _BackgroundProcess("test process", mock_logging_context) + # Should not raise + process.update_metrics() diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py new file mode 100644 index 0000000000..bddc4228bc --- /dev/null +++ b/tests/metrics/test_metrics.py @@ -0,0 +1,206 @@ +# Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing_extensions import Protocol + +try: + from importlib import metadata +except ImportError: + import importlib_metadata as metadata # type: ignore[no-redef] + +from unittest.mock import patch + +from pkg_resources import parse_version + +from synapse.app._base import _set_prometheus_client_use_created_metrics +from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.deferred_cache import DeferredCache + +from tests import unittest + + +def get_sample_labels_value(sample): + """Extract the labels and values of a sample. + + prometheus_client 0.5 changed the sample type to a named tuple with more + members than the plain tuple had in 0.4 and earlier. This function can + extract the labels and value from the sample for both sample types. + + Args: + sample: The sample to get the labels and value from. + Returns: + A tuple of (labels, value) from the sample. + """ + + # If the sample has a labels and value attribute, use those. + if hasattr(sample, "labels") and hasattr(sample, "value"): + return sample.labels, sample.value + # Otherwise fall back to treating it as a plain 3 tuple. + else: + _, labels, value = sample + return labels, value + + +class TestMauLimit(unittest.TestCase): + def test_basic(self): + class MetricEntry(Protocol): + foo: int + bar: int + + gauge: InFlightGauge[MetricEntry] = InFlightGauge( + "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] + ) + + def handle1(metrics): + metrics.foo += 2 + metrics.bar = max(metrics.bar, 5) + + def handle2(metrics): + metrics.foo += 3 + metrics.bar = max(metrics.bar, 7) + + gauge.register(("key1",), handle1) + + self.assert_dict( + { + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.unregister(("key1",), handle1) + + self.assert_dict( + { + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.register(("key1",), handle1) + gauge.register(("key2",), handle2) + + self.assert_dict( + { + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.unregister(("key2",), handle2) + gauge.register(("key1",), handle2) + + self.assert_dict( + { + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) + + def get_metrics_from_gauge(self, gauge): + results = {} + + for r in gauge.collect(): + results[r.name] = { + tuple(labels[x] for x in gauge.labels): value + for labels, value in map(get_sample_labels_value, r.samples) + } + + return results + + +class BuildInfoTests(unittest.TestCase): + def test_get_build(self): + """ + The synapse_build_info metric reports the OS version, Python version, + and Synapse version. + """ + items = list( + filter( + lambda x: b"synapse_build_info{" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + ) + self.assertEqual(len(items), 1) + self.assertTrue(b"osversion=" in items[0]) + self.assertTrue(b"pythonversion=" in items[0]) + self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache: DeferredCache[str, str] = DeferredCache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + +class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase): + if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"): + skip = "prometheus-client too old" + + def test_created_metrics_disabled(self) -> None: + """ + Tests that a brittle hack, to disable `_created` metrics, works. + This involves poking at the internals of prometheus-client. + It's not the end of the world if this doesn't work. + + This test gives us a way to notice if prometheus-client changes + their internals. + """ + import prometheus_client.metrics + + PRIVATE_FLAG_NAME = "_use_created" + + # By default, the pesky `_created` metrics are enabled. + # Check this assumption is still valid. + self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME)) + + with patch("prometheus_client.metrics") as mock: + setattr(mock, PRIVATE_FLAG_NAME, True) + _set_prometheus_client_use_created_metrics(False) + self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False)) diff --git a/tests/test_metrics.py b/tests/test_metrics.py deleted file mode 100644 index 1a70eddc9b..0000000000 --- a/tests/test_metrics.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2018 New Vector Ltd -# Copyright 2019 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -try: - from importlib import metadata -except ImportError: - import importlib_metadata as metadata # type: ignore[no-redef] - -from unittest.mock import patch - -from pkg_resources import parse_version - -from synapse.app._base import _set_prometheus_client_use_created_metrics -from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.deferred_cache import DeferredCache - -from tests import unittest - - -def get_sample_labels_value(sample): - """Extract the labels and values of a sample. - - prometheus_client 0.5 changed the sample type to a named tuple with more - members than the plain tuple had in 0.4 and earlier. This function can - extract the labels and value from the sample for both sample types. - - Args: - sample: The sample to get the labels and value from. - Returns: - A tuple of (labels, value) from the sample. - """ - - # If the sample has a labels and value attribute, use those. - if hasattr(sample, "labels") and hasattr(sample, "value"): - return sample.labels, sample.value - # Otherwise fall back to treating it as a plain 3 tuple. - else: - _, labels, value = sample - return labels, value - - -class TestMauLimit(unittest.TestCase): - def test_basic(self): - gauge = InFlightGauge( - "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] - ) - - def handle1(metrics): - metrics.foo += 2 - metrics.bar = max(metrics.bar, 5) - - def handle2(metrics): - metrics.foo += 3 - metrics.bar = max(metrics.bar, 7) - - gauge.register(("key1",), handle1) - - self.assert_dict( - { - "test1_total": {("key1",): 1}, - "test1_foo": {("key1",): 2}, - "test1_bar": {("key1",): 5}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.unregister(("key1",), handle1) - - self.assert_dict( - { - "test1_total": {("key1",): 0}, - "test1_foo": {("key1",): 0}, - "test1_bar": {("key1",): 0}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.register(("key1",), handle1) - gauge.register(("key2",), handle2) - - self.assert_dict( - { - "test1_total": {("key1",): 1, ("key2",): 1}, - "test1_foo": {("key1",): 2, ("key2",): 3}, - "test1_bar": {("key1",): 5, ("key2",): 7}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.unregister(("key2",), handle2) - gauge.register(("key1",), handle2) - - self.assert_dict( - { - "test1_total": {("key1",): 2, ("key2",): 0}, - "test1_foo": {("key1",): 5, ("key2",): 0}, - "test1_bar": {("key1",): 7, ("key2",): 0}, - }, - self.get_metrics_from_gauge(gauge), - ) - - def get_metrics_from_gauge(self, gauge): - results = {} - - for r in gauge.collect(): - results[r.name] = { - tuple(labels[x] for x in gauge.labels): value - for labels, value in map(get_sample_labels_value, r.samples) - } - - return results - - -class BuildInfoTests(unittest.TestCase): - def test_get_build(self): - """ - The synapse_build_info metric reports the OS version, Python version, - and Synapse version. - """ - items = list( - filter( - lambda x: b"synapse_build_info{" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - ) - self.assertEqual(len(items), 1) - self.assertTrue(b"osversion=" in items[0]) - self.assertTrue(b"pythonversion=" in items[0]) - self.assertTrue(b"version=" in items[0]) - - -class CacheMetricsTests(unittest.HomeserverTestCase): - def test_cache_metric(self): - """ - Caches produce metrics reflecting their state when scraped. - """ - CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = DeferredCache(CACHE_NAME, max_entries=777) - - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } - - self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") - - cache.prefill("1", "hi") - - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } - - self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") - - -class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase): - if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"): - skip = "prometheus-client too old" - - def test_created_metrics_disabled(self) -> None: - """ - Tests that a brittle hack, to disable `_created` metrics, works. - This involves poking at the internals of prometheus-client. - It's not the end of the world if this doesn't work. - - This test gives us a way to notice if prometheus-client changes - their internals. - """ - import prometheus_client.metrics - - PRIVATE_FLAG_NAME = "_use_created" - - # By default, the pesky `_created` metrics are enabled. - # Check this assumption is still valid. - self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME)) - - with patch("prometheus_client.metrics") as mock: - setattr(mock, PRIVATE_FLAG_NAME, True) - _set_prometheus_client_use_created_metrics(False) - self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False)) -- cgit 1.5.1 From cc3a52b33df72bb4230367536b924a6d1f510d36 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 31 Oct 2022 18:07:30 +0100 Subject: Support OIDC backchannel logouts (#11414) If configured an OIDC IdP can log a user's session out of Synapse when they log out of the identity provider. The IdP sends a request directly to Synapse (and must be configured with an endpoint) when a user logs out. --- changelog.d/11414.feature | 1 + docs/openid.md | 14 + docs/usage/configuration/config_documentation.md | 9 + synapse/config/oidc.py | 12 + synapse/handlers/oidc.py | 381 ++++++++++++++++++-- synapse/handlers/sso.py | 71 ++++ synapse/rest/synapse/client/oidc/__init__.py | 4 + .../client/oidc/backchannel_logout_resource.py | 35 ++ synapse/storage/databases/main/registration.py | 21 ++ tests/rest/client/test_auth.py | 390 +++++++++++++++++++-- tests/rest/client/utils.py | 55 ++- tests/server.py | 6 + tests/test_utils/oidc.py | 27 +- 13 files changed, 960 insertions(+), 66 deletions(-) create mode 100644 changelog.d/11414.feature create mode 100644 synapse/rest/synapse/client/oidc/backchannel_logout_resource.py (limited to 'synapse') diff --git a/changelog.d/11414.feature b/changelog.d/11414.feature new file mode 100644 index 0000000000..fc035e50a7 --- /dev/null +++ b/changelog.d/11414.feature @@ -0,0 +1 @@ +Support back-channel logouts from OpenID Connect providers. diff --git a/docs/openid.md b/docs/openid.md index 87ebea4c29..37c5eb244d 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -49,6 +49,13 @@ setting in your configuration file. See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as the text below for example configurations for specific providers. +## OIDC Back-Channel Logout + +Synapse supports receiving [OpenID Connect Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) notifications. + +This lets the OpenID Connect Provider notify Synapse when a user logs out, so that Synapse can end that user session. +This feature can be enabled by setting the `backchannel_logout_enabled` property to `true` in the provider configuration, and setting the following URL as destination for Back-Channel Logout notifications in your OpenID Connect Provider: `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` + ## Sample configs Here are a few configs for providers that should work with Synapse. @@ -123,6 +130,9 @@ oidc_providers: [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. +Keycloak supports OIDC Back-Channel Logout, which sends logout notification to Synapse, so that Synapse users get logged out when they log out from Keycloak. +This can be optionally enabled by setting `backchannel_logout_enabled` to `true` in the Synapse configuration, and by setting the "Backchannel Logout URL" in Keycloak. + Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm. 1. Click `Clients` in the sidebar and click `Create` @@ -144,6 +154,8 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to | Client Protocol | `openid-connect` | | Access Type | `confidential` | | Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` | +| Backchannel Logout URL (optional) | `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` | +| Backchannel Logout Session Required (optional) | `On` | 5. Click `Save` 6. On the Credentials tab, update the fields: @@ -167,7 +179,9 @@ oidc_providers: config: localpart_template: "{{ user.preferred_username }}" display_name_template: "{{ user.name }}" + backchannel_logout_enabled: true # Optional ``` + ### Auth0 [Auth0][auth0] is a hosted SaaS IdP solution. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 97fb505a5f..44358faf59 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3021,6 +3021,15 @@ Options for each entry include: which is set to the claims returned by the UserInfo Endpoint and/or in the ID Token. +* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications. + Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`. + Defaults to `false`. + +* `backchannel_logout_ignore_sub`: by default, the OIDC Back-Channel Logout feature checks that the + `sub` claim matches the subject claim received during login. This check can be disabled by setting + this to `true`. Defaults to `false`. + + You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`. It is possible to configure Synapse to only allow logins if certain attributes match particular values in the OIDC userinfo. The requirements can be listed under diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 5418a332da..0bd83f4010 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -123,6 +123,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "userinfo_endpoint": {"type": "string"}, "jwks_uri": {"type": "string"}, "skip_verification": {"type": "boolean"}, + "backchannel_logout_enabled": {"type": "boolean"}, + "backchannel_logout_ignore_sub": {"type": "boolean"}, "user_profile_method": { "type": "string", "enum": ["auto", "userinfo_endpoint"], @@ -292,6 +294,10 @@ def _parse_oidc_config_dict( token_endpoint=oidc_config.get("token_endpoint"), userinfo_endpoint=oidc_config.get("userinfo_endpoint"), jwks_uri=oidc_config.get("jwks_uri"), + backchannel_logout_enabled=oidc_config.get("backchannel_logout_enabled", False), + backchannel_logout_ignore_sub=oidc_config.get( + "backchannel_logout_ignore_sub", False + ), skip_verification=oidc_config.get("skip_verification", False), user_profile_method=oidc_config.get("user_profile_method", "auto"), allow_existing_users=oidc_config.get("allow_existing_users", False), @@ -368,6 +374,12 @@ class OidcProviderConfig: # "openid" scope is used. jwks_uri: Optional[str] + # Whether Synapse should react to backchannel logouts + backchannel_logout_enabled: bool + + # Whether Synapse should ignore the `sub` claim in backchannel logouts or not. + backchannel_logout_ignore_sub: bool + # Whether to skip metadata verification skip_verification: bool diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 9759daf043..867973dcca 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -12,14 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import binascii import inspect +import json import logging -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Type, + TypeVar, + Union, +) from urllib.parse import urlencode, urlparse import attr +import unpaddedbase64 from authlib.common.security import generate_token -from authlib.jose import JsonWebToken, jwt +from authlib.jose import JsonWebToken, JWTClaims +from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oidc.core import CodeIDToken, UserInfo @@ -35,9 +49,12 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from twisted.web.http_headers import Headers +from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes +from synapse.http.server import finish_request +from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart @@ -88,6 +105,8 @@ class Token(TypedDict): #: there is no real point of doing this in our case. JWK = Dict[str, str] +C = TypeVar("C") + #: A JWK Set, as per RFC7517 sec 5. class JWKS(TypedDict): @@ -247,6 +266,80 @@ class OidcHandler: await oidc_provider.handle_oidc_callback(request, session_data, code) + async def handle_backchannel_logout(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + This extracts the logout_token from the request and tries to figure out + which OpenID Provider it is comming from. This works by matching the iss claim + with the issuer and the aud claim with the client_id. + + Since at this point we don't know who signed the JWT, we can't just + decode it using authlib since it will always verifies the signature. We + have to decode it manually without validating the signature. The actual JWT + verification is done in the `OidcProvider.handler_backchannel_logout` method, + once we figured out which provider sent the request. + + Args: + request: the incoming request from the browser. + """ + logout_token = parse_string(request, "logout_token") + if logout_token is None: + raise SynapseError(400, "Missing logout_token in request") + + # A JWT looks like this: + # header.payload.signature + # where all parts are encoded with urlsafe base64. + # The aud and iss claims we care about are in the payload part, which + # is a JSON object. + try: + # By destructuring the list after splitting, we ensure that we have + # exactly 3 segments + _, payload, _ = logout_token.split(".") + except ValueError: + raise SynapseError(400, "Invalid logout_token in request") + + try: + payload_bytes = unpaddedbase64.decode_base64(payload) + claims = json_decoder.decode(payload_bytes.decode("utf-8")) + except (json.JSONDecodeError, binascii.Error, UnicodeError): + raise SynapseError(400, "Invalid logout_token payload in request") + + try: + # Let's extract the iss and aud claims + iss = claims["iss"] + aud = claims["aud"] + # The aud claim can be either a string or a list of string. Here we + # normalize it as a list of strings. + if isinstance(aud, str): + aud = [aud] + + # Check that we have the right types for the aud and the iss claims + if not isinstance(iss, str) or not isinstance(aud, list): + raise TypeError() + for a in aud: + if not isinstance(a, str): + raise TypeError() + + # At this point we properly checked both claims types + issuer: str = iss + audience: List[str] = aud + except (TypeError, KeyError): + raise SynapseError(400, "Invalid issuer/audience in logout_token") + + # Now that we know the audience and the issuer, we can figure out from + # what provider it is coming from + oidc_provider: Optional[OidcProvider] = None + for provider in self._providers.values(): + if provider.issuer == issuer and provider.client_id in audience: + oidc_provider = provider + break + + if oidc_provider is None: + raise SynapseError(400, "Could not find the OP that issued this event") + + # Ask the provider to handle the logout request. + await oidc_provider.handle_backchannel_logout(request, logout_token) + class OidcError(Exception): """Used to catch errors when calling the token_endpoint""" @@ -342,6 +435,7 @@ class OidcProvider: self.idp_brand = provider.idp_brand self._sso_handler = hs.get_sso_handler() + self._device_handler = hs.get_device_handler() self._sso_handler.register_identity_provider(self) @@ -400,6 +494,41 @@ class OidcProvider: # If we're not using userinfo, we need a valid jwks to validate the ID token m.validate_jwks_uri() + if self._config.backchannel_logout_enabled: + if not m.get("backchannel_logout_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled for issuer %r" + "but it does not advertise support for it", + self.issuer, + ) + + elif not m.get("backchannel_logout_session_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled and supported " + "by issuer %r but it might not send a session ID with " + "logout tokens, which is required for the logouts to work", + self.issuer, + ) + + if not self._config.backchannel_logout_ignore_sub: + # If OIDC backchannel logouts are enabled, the provider mapping provider + # should use the `sub` claim. We verify that by mapping a dumb user and + # see if we get back the sub claim + user = UserInfo({"sub": "thisisasubject"}) + try: + subject = self._user_mapping_provider.get_remote_user_id(user) + if subject != user["sub"]: + raise ValueError("Unexpected subject") + except Exception: + logger.warning( + f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} " + "but it looks like the configured `user_mapping_provider` " + "does not use the `sub` claim as subject. If it is the case, " + "and you want Synapse to ignore the `sub` claim in OIDC " + "Back-Channel Logouts, set `backchannel_logout_ignore_sub` " + "to `true` in the issuer config." + ) + @property def _uses_userinfo(self) -> bool: """Returns True if the ``userinfo_endpoint`` should be used. @@ -415,6 +544,16 @@ class OidcProvider: or self._user_profile_method == "userinfo_endpoint" ) + @property + def issuer(self) -> str: + """The issuer identifying this provider.""" + return self._config.issuer + + @property + def client_id(self) -> str: + """The client_id used when interacting with this provider.""" + return self._config.client_id + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: """Return the provider metadata. @@ -662,6 +801,59 @@ class OidcProvider: return UserInfo(resp) + async def _verify_jwt( + self, + alg_values: List[str], + token: str, + claims_cls: Type[C], + claims_options: Optional[dict] = None, + claims_params: Optional[dict] = None, + ) -> C: + """Decode and validate a JWT, re-fetching the JWKS as needed. + + Args: + alg_values: list of `alg` values allowed when verifying the JWT. + token: the JWT. + claims_cls: the JWTClaims class to use to validate the claims. + claims_options: dict of options passed to the `claims_cls` constructor. + claims_params: dict of params passed to the `claims_cls` constructor. + + Returns: + The decoded claims in the JWT. + """ + jwt = JsonWebToken(alg_values) + + logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token) + + # Try to decode the keys in cache first, then retry by forcing the keys + # to be reloaded + jwk_set = await self.load_jwks() + try: + claims = jwt.decode( + token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + except ValueError: + logger.info("Reloading JWKS after decode error") + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + claims = jwt.decode( + token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + + logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims) + + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew + return claims + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. @@ -675,13 +867,13 @@ class OidcProvider: The decoded claims in the ID token. """ id_token = token.get("id_token") - logger.debug("Attempting to decode JWT id_token %r", id_token) # That has been theoritically been checked by the caller, so even though # assertion are not enabled in production, it is mainly here to appease mypy assert id_token is not None metadata = await self.load_metadata() + claims_params = { "nonce": nonce, "client_id": self._client_auth.client_id, @@ -691,38 +883,17 @@ class OidcProvider: # in the `id_token` that we can check against. claims_params["access_token"] = token["access_token"] - alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - jwt = JsonWebToken(alg_values) - - claim_options = {"iss": {"values": [metadata["issuer"]]}} + claims_options = {"iss": {"values": [metadata["issuer"]]}} - # Try to decode the keys in cache first, then retry by forcing the keys - # to be reloaded - jwk_set = await self.load_jwks() - try: - claims = jwt.decode( - id_token, - key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, - claims_params=claims_params, - ) - except ValueError: - logger.info("Reloading JWKS after decode error") - jwk_set = await self.load_jwks(force=True) # try reloading the jwks - claims = jwt.decode( - id_token, - key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, - claims_params=claims_params, - ) - - logger.debug("Decoded id_token JWT %r; validating", claims) + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - claims.validate( - now=self._clock.time(), leeway=120 - ) # allows 2 min of clock skew + claims = await self._verify_jwt( + alg_values=alg_values, + token=id_token, + claims_cls=CodeIDToken, + claims_options=claims_options, + claims_params=claims_params, + ) return claims @@ -1043,6 +1214,146 @@ class OidcProvider: # to be strings. return str(remote_user_id) + async def handle_backchannel_logout( + self, request: SynapseRequest, logout_token: str + ) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + The OIDC Provider posts a logout token to this endpoint when a user + session ends. That token is a JWT signed with the same keys as + ID tokens. The OpenID Connect Back-Channel Logout draft explains how to + validate the JWT and figure out what session to end. + + Args: + request: The request to respond to + logout_token: The logout token (a JWT) extracted from the request body + """ + # Back-Channel Logout can be disabled in the config, hence this check. + # This is not that important for now since Synapse is registered + # manually to the OP, so not specifying the backchannel-logout URI is + # as effective than disabling it here. It might make more sense if we + # support dynamic registration in Synapse at some point. + if not self._config.backchannel_logout_enabled: + logger.warning( + f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config" + ) + + # TODO: this responds with a 400 status code, which is what the OIDC + # Back-Channel Logout spec expects, but spec also suggests answering with + # a JSON object, with the `error` and `error_description` fields set, which + # we are not doing here. + # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse + raise SynapseError( + 400, "OpenID Connect Back-Channel Logout is disabled for this provider" + ) + + metadata = await self.load_metadata() + + # As per OIDC Back-Channel Logout 1.0 sec. 2.4: + # A Logout Token MUST be signed and MAY also be encrypted. The same + # keys are used to sign and encrypt Logout Tokens as are used for ID + # Tokens. If the Logout Token is encrypted, it SHOULD replicate the + # iss (issuer) claim in the JWT Header Parameters, as specified in + # Section 5.3 of [JWT]. + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + # As per sec. 2.6: + # 3. Validate the iss, aud, and iat Claims in the same way they are + # validated in ID Tokens. + # Which means the audience should contain Synapse's client_id and the + # issuer should be the IdP issuer + claims_options = { + "iss": {"values": [metadata["issuer"]]}, + "aud": {"values": [self.client_id]}, + } + + try: + claims = await self._verify_jwt( + alg_values=alg_values, + token=logout_token, + claims_cls=LogoutToken, + claims_options=claims_options, + ) + except JoseError: + logger.exception("Invalid logout_token") + raise SynapseError(400, "Invalid logout_token") + + # As per sec. 2.6: + # 4. Verify that the Logout Token contains a sub Claim, a sid Claim, + # or both. + # 5. Verify that the Logout Token contains an events Claim whose + # value is JSON object containing the member name + # http://schemas.openid.net/event/backchannel-logout. + # 6. Verify that the Logout Token does not contain a nonce Claim. + # This is all verified by the LogoutToken claims class, so at this + # point the `sid` claim exists and is a string. + sid: str = claims.get("sid") + + # If the `sub` claim was included in the logout token, we check that it matches + # that it matches the right user. We can have cases where the `sub` claim is not + # the ID saved in database, so we let admins disable this check in config. + sub: Optional[str] = claims.get("sub") + expected_user_id: Optional[str] = None + if sub is not None and not self._config.backchannel_logout_ignore_sub: + expected_user_id = await self._store.get_user_by_external_id( + self.idp_id, sub + ) + + # Invalidate any running user-mapping sessions, in-flight login tokens and + # active devices + await self._sso_handler.revoke_sessions_for_provider_session_id( + auth_provider_id=self.idp_id, + auth_provider_session_id=sid, + expected_user_id=expected_user_id, + ) + + request.setResponseCode(200) + request.setHeader(b"Cache-Control", b"no-cache, no-store") + request.setHeader(b"Pragma", b"no-cache") + finish_request(request) + + +class LogoutToken(JWTClaims): + """ + Holds and verify claims of a logout token, as per + https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken + """ + + REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"] + + def validate(self, now: Optional[int] = None, leeway: int = 0) -> None: + """Validate everything in claims payload.""" + super().validate(now, leeway) + self.validate_sid() + self.validate_events() + self.validate_nonce() + + def validate_sid(self) -> None: + """Ensure the sid claim is present""" + sid = self.get("sid") + if not sid: + raise MissingClaimError("sid") + + if not isinstance(sid, str): + raise InvalidClaimError("sid") + + def validate_nonce(self) -> None: + """Ensure the nonce claim is absent""" + if "nonce" in self: + raise InvalidClaimError("nonce") + + def validate_events(self) -> None: + """Ensure the events claim is present and with the right value""" + events = self.get("events") + if not events: + raise MissingClaimError("events") + + if not isinstance(events, dict): + raise InvalidClaimError("events") + + if "http://schemas.openid.net/event/backchannel-logout" not in events: + raise InvalidClaimError("events") + # number of seconds a newly-generated client secret should be valid for CLIENT_SECRET_VALIDITY_SECONDS = 3600 @@ -1112,6 +1423,7 @@ class JwtClientSecret: logger.info( "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload ) + jwt = JsonWebToken(header["alg"]) self._cached_secret = jwt.encode(header, payload, self._key.key) self._cached_secret_replacement_time = ( expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS @@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict): emails: List[str] -C = TypeVar("C") - - class OidcMappingProvider(Generic[C]): """A mapping provider maps a UserInfo object to user attributes. diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 5943f08e91..749d7e93b0 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -191,6 +191,7 @@ class SsoHandler: self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() self._error_template = hs.config.sso.sso_error_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() @@ -1026,6 +1027,76 @@ class SsoHandler: return True + async def revoke_sessions_for_provider_session_id( + self, + auth_provider_id: str, + auth_provider_session_id: str, + expected_user_id: Optional[str] = None, + ) -> None: + """Revoke any devices and in-flight logins tied to a provider session. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + auth_provider_session_id: The session ID from the provider to logout + expected_user_id: The user we're expecting to logout. If set, it will ignore + sessions belonging to other users and log an error. + """ + # Invalidate any running user-mapping sessions + to_delete = [] + for session_id, session in self._username_mapping_sessions.items(): + if ( + session.auth_provider_id == auth_provider_id + and session.auth_provider_session_id == auth_provider_session_id + ): + to_delete.append(session_id) + + for session_id in to_delete: + logger.info("Revoking mapping session %s", session_id) + del self._username_mapping_sessions[session_id] + + # Invalidate any in-flight login tokens + await self._store.invalidate_login_tokens_by_session_id( + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + # Fetch any device(s) in the store associated with the session ID. + devices = await self._store.get_devices_by_auth_provider_session_id( + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + # We have no guarantee that all the devices of that session are for the same + # `user_id`. Hence, we have to iterate over the list of devices and log them out + # one by one. + for device in devices: + user_id = device["user_id"] + device_id = device["device_id"] + + # If the user_id associated with that device/session is not the one we got + # out of the `sub` claim, skip that device and show log an error. + if expected_user_id is not None and user_id != expected_user_id: + logger.error( + "Received a logout notification from SSO provider " + f"{auth_provider_id!r} for the user {expected_user_id!r}, but with " + f"a session ID ({auth_provider_session_id!r}) which belongs to " + f"{user_id!r}. This may happen when the SSO provider user mapper " + "uses something else than the standard attribute as mapping ID. " + "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` " + "in the provider config if that is the case." + ) + continue + + logger.info( + "Logging out %r (device %r) via SSO (%r) logout notification (session %r).", + user_id, + device_id, + auth_provider_id, + auth_provider_session_id, + ) + await self._device_handler.delete_devices(user_id, [device_id]) + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py index 81fec39659..e4b28ce3df 100644 --- a/synapse/rest/synapse/client/oidc/__init__.py +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -17,6 +17,9 @@ from typing import TYPE_CHECKING from twisted.web.resource import Resource +from synapse.rest.synapse.client.oidc.backchannel_logout_resource import ( + OIDCBackchannelLogoutResource, +) from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource if TYPE_CHECKING: @@ -29,6 +32,7 @@ class OIDCResource(Resource): def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.putChild(b"callback", OIDCCallbackResource(hs)) + self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs)) __all__ = ["OIDCResource"] diff --git a/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py new file mode 100644 index 0000000000..e07e76855a --- /dev/null +++ b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py @@ -0,0 +1,35 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.server import DirectServeJsonResource +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class OIDCBackchannelLogoutResource(DirectServeJsonResource): + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + async def _async_render_POST(self, request: SynapseRequest) -> None: + await self._oidc_handler.handle_backchannel_logout(request) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 0255295317..5167089e03 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1920,6 +1920,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): self._clock.time_msec(), ) + async def invalidate_login_tokens_by_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> None: + """Invalidate login tokens with the given IdP session ID. + + Args: + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider + """ + await self.db_pool.simple_update( + table="login_tokens", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + updatevalues={"used_ts": self._clock.time_msec()}, + desc="invalidate_login_tokens_by_session_id", + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index ebf653d018..847294dc8e 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple, Union @@ -21,7 +22,7 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType -from synapse.api.errors import Codes +from synapse.api.errors import Codes, SynapseError from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -32,8 +33,8 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC -from tests.rest.client.utils import TEST_OIDC_CONFIG -from tests.server import FakeChannel +from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER +from tests.server import FakeChannel, make_request from tests.unittest import override_config, skip_unless @@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token: str) -> bool: - """ - Checks whether an access token is valid, returning whether it is or not. - """ - code = self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code - - # Either 200 or 401 is what we get back; anything else is a bug. - assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} - - return code == HTTPStatus.OK - def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. @@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.reactor.advance(59.0) # Both tokens should still be valid. - self.assertTrue(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 61 s (just past 1 minute, the time of expiry) self.reactor.advance(2.0) # Only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 599 s (just shy of 10 minutes, the time of expiry) self.reactor.advance(599.0 - 61.0) # It's still the case that only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 601 s (just past 10 minutes, the time of expiry) self.reactor.advance(2.0) # Now neither token is valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami( + nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} @@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # and no refresh token self.assertEqual(_table_length("access_tokens"), 0) self.assertEqual(_table_length("refresh_tokens"), 0) + + +def oidc_config( + id: str, with_localpart_template: bool, **kwargs: Any +) -> Dict[str, Any]: + """Sample OIDC provider config used in backchannel logout tests. + + Args: + id: IDP ID for this provider + with_localpart_template: Set to `true` to have a default localpart_template in + the `user_mapping_provider` config and skip the user mapping session + **kwargs: rest of the config + + Returns: + A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of + the HS config + """ + config: Dict[str, Any] = { + "idp_id": id, + "idp_name": id, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + } + + if with_localpart_template: + config["user_mapping_provider"] = { + "config": {"localpart_template": "{{ user.sub }}"} + } + else: + config["user_mapping_provider"] = {"config": {}} + + config.update(kwargs) + + return config + + +@skip_unless(HAS_OIDC, "Requires OIDC") +class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" + + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + resource_dict = super().create_resource_dict() + resource_dict.update(build_synapse_client_resource_tree(self.hs)) + return resource_dict + + def submit_logout_token(self, logout_token: str) -> FakeChannel: + return self.make_request( + "POST", + "/_synapse/client/oidc/backchannel_logout", + content=f"logout_token={logout_token}", + content_is_form=True, + ) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_simple_logout(self) -> None: + """ + Receiving a logout token should logout the user + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + self.assertNotEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = fake_oidc_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = fake_oidc_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_login(self) -> None: + """ + It should revoke login tokens when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # expect a confirmation page + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now try to exchange the login token + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + # It should have failed + self.assertEqual(channel.code, 403) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=False, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_mapping(self) -> None: + """ + It should stop ongoing user mapping session when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # Expect a user mapping page + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + + # We should have a user_mapping_session cookie + cookie_headers = channel.headers.getRawHeaders("Set-Cookie") + assert cookie_headers + cookies: Dict[str, str] = {} + for h in cookie_headers: + key, value = h.split(";")[0].split("=", maxsplit=1) + cookies[key] = value + + user_mapping_session_id = cookies["username_mapping_session"] + + # Getting that session should not raise + session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + self.assertIsNotNone(session) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now it should raise + with self.assertRaises(SynapseError): + self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=False, + ) + ] + } + ) + def test_disabled(self) -> None: + """ + Receiving a logout token should do nothing if it is disabled in the config + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_no_sid(self) -> None: + """ + Receiving a logout token without `sid` during the login should do nothing + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=False + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + "first", + issuer="https://first-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + oidc_config( + "second", + issuer="https://second-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + ] + } + ) + def test_multiple_providers(self) -> None: + """ + It should be able to distinguish login tokens from two different IdPs + """ + first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/") + second_server = self.helper.fake_oidc_server( + issuer="https://second-issuer.com/" + ) + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + first_server, user, with_sid=True, idp_id="oidc-first" + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + second_server, user, with_sid=True, idp_id="oidc-second" + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # `sid` in the fake providers are generated by a counter, so the first grant of + # each provider should give the same SID + self.assertEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = first_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = second_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 967d229223..706399fae5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -553,6 +553,34 @@ class RestHelper: return channel.json_body + def whoami( + self, + access_token: str, + expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK, + ) -> JsonDict: + """Perform a 'whoami' request, which can be a quick way to check for access + token validity + + Args: + access_token: The user token to use during the request + expect_code: The return code to expect from attempting the whoami request + """ + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "account/whoami", + access_token=access_token, + ) + + assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: """Create a ``FakeOidcServer``. @@ -572,6 +600,7 @@ class RestHelper: fake_server: FakeOidcServer, remote_user_id: str, with_sid: bool = False, + idp_id: Optional[str] = None, expected_status: int = 200, ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC @@ -588,7 +617,11 @@ class RestHelper: client_redirect_url = "https://x" userinfo = {"sub": remote_user_id} channel, grant = self.auth_via_oidc( - fake_server, userinfo, client_redirect_url, with_sid=with_sid + fake_server, + userinfo, + client_redirect_url, + with_sid=with_sid, + idp_id=idp_id, ) # expect a confirmation page @@ -623,6 +656,7 @@ class RestHelper: client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, with_sid: bool = False, + idp_id: Optional[str] = None, ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. @@ -648,6 +682,7 @@ class RestHelper: ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. with_sid: if True, generates a random `sid` (OIDC session ID) + idp_id: if set, explicitely chooses one specific IDP Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -665,7 +700,9 @@ class RestHelper: oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) else: # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + oauth_uri = self.initiate_sso_login( + client_redirect_url, cookies, idp_id=idp_id + ) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -742,7 +779,10 @@ class RestHelper: return channel, grant def initiate_sso_login( - self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + self, + client_redirect_url: Optional[str], + cookies: MutableMapping[str, str], + idp_id: Optional[str] = None, ) -> str: """Make a request to the login-via-sso redirect endpoint, and return the target @@ -753,6 +793,7 @@ class RestHelper: client_redirect_url: the client redirect URL to pass to the login redirect endpoint cookies: any cookies returned will be added to this dict + idp_id: if set, explicitely chooses one specific IDP Returns: the URI that the client gets redirected to (ie, the SSO server) @@ -761,6 +802,12 @@ class RestHelper: if client_redirect_url: params["redirectUrl"] = client_redirect_url + uri = "/_matrix/client/r0/login/sso/redirect" + if idp_id is not None: + uri = f"{uri}/{idp_id}" + + uri = f"{uri}?{urllib.parse.urlencode(params)}" + # hit the redirect url (which should redirect back to the redirect url. This # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. @@ -768,7 +815,7 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", - "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), + uri, ) assert channel.code == 302 diff --git a/tests/server.py b/tests/server.py index 8b1d186219..b1730fcc8d 100644 --- a/tests/server.py +++ b/tests/server.py @@ -362,6 +362,12 @@ def make_request( # Twisted expects to be at the end of the content when parsing the request. req.content.seek(0, SEEK_END) + # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded + # bodies if the Content-Length header is missing + req.requestHeaders.addRawHeader( + b"Content-Length", str(len(content)).encode("ascii") + ) + if access_token: req.requestHeaders.addRawHeader( b"Authorization", b"Bearer " + access_token.encode("ascii") diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index de134bbc89..1461d23ee8 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -51,6 +51,8 @@ class FakeOidcServer: get_userinfo_handler: Mock post_token_handler: Mock + sid_counter: int = 0 + def __init__(self, clock: Clock, issuer: str): from authlib.jose import ECKey, KeySet @@ -146,7 +148,7 @@ class FakeOidcServer: return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: - now = self._clock.time() + now = int(self._clock.time()) id_token = { **grant.userinfo, "iss": self.issuer, @@ -166,6 +168,26 @@ class FakeOidcServer: return self._sign(id_token) + def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str: + now = int(self._clock.time()) + logout_token = { + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "jti": random_string(10), + "events": { + "http://schemas.openid.net/event/backchannel-logout": {}, + }, + } + + if grant.sid is not None: + logout_token["sid"] = grant.sid + + if "sub" in grant.userinfo: + logout_token["sub"] = grant.userinfo["sub"] + + return self._sign(logout_token) + def id_token_override(self, overrides: dict): """Temporarily patch the ID token generated by the token endpoint.""" return patch.object(self, "_id_token_overrides", overrides) @@ -183,7 +205,8 @@ class FakeOidcServer: code = random_string(10) sid = None if with_sid: - sid = random_string(10) + sid = str(self.sid_counter) + self.sid_counter += 1 grant = FakeAuthorizationGrant( userinfo=userinfo, -- cgit 1.5.1 From dbfc9b803ee32f7b31c2b5ccbc53a1bfcaa95983 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 20:31:43 +0000 Subject: Fix dehydrated device REST checks (#14336) --- changelog.d/14336.bugfix | 1 + synapse/rest/client/devices.py | 5 ++--- tests/rest/client/test_devices.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14336.bugfix (limited to 'synapse') diff --git a/changelog.d/14336.bugfix b/changelog.d/14336.bugfix new file mode 100644 index 0000000000..d44ff1bbc7 --- /dev/null +++ b/changelog.d/14336.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70 where clients were unable to PUT new [dehydrated devices](https://github.com/matrix-org/matrix-spec-proposals/pull/2697). diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 90828c95c4..8f3cbd4ea2 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -231,7 +231,7 @@ class DehydratedDeviceServlet(RestServlet): } } - PUT /org.matrix.msc2697/dehydrated_device + PUT /org.matrix.msc2697.v2/dehydrated_device Content-Type: application/json { @@ -271,7 +271,6 @@ class DehydratedDeviceServlet(RestServlet): raise errors.NotFoundError("No dehydrated device available") class PutBody(RequestBodyModel): - device_id: StrictStr device_data: DehydratedDeviceDataModel initial_device_display_name: Optional[StrictStr] @@ -281,7 +280,7 @@ class DehydratedDeviceServlet(RestServlet): device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), - submission.device_data, + submission.device_data.dict(), submission.initial_device_display_name, ) return 200, {"device_id": device_id} diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index aa98222434..d80eea17d3 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -200,3 +200,37 @@ class DevicesTestCase(unittest.HomeserverTestCase): self.reactor.advance(43200) self.get_success(self.handler.get_device(user_id, "abc")) self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError) + + +class DehydratedDeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + devices.register_servlets, + ] + + def test_PUT(self) -> None: + """Sanity-check that we can PUT a dehydrated device. + + Detects https://github.com/matrix-org/synapse/issues/14334. + """ + alice = self.register_user("alice", "correcthorse") + token = self.login(alice, "correcthorse") + + # Have alice update their device list + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device", + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device", + } + }, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + device_id = channel.json_body.get("device_id") + self.assertIsInstance(device_id, str) -- cgit 1.5.1 From b922b54b6143f13c0786a18fcbb5f55724ea72fc Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 1 Nov 2022 10:30:43 +0000 Subject: Fix type annotation causing import time error in the Complement forking launcher. (#14084) Co-authored-by: David Robertson --- changelog.d/14084.misc | 1 + synapse/app/complement_fork_starter.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14084.misc (limited to 'synapse') diff --git a/changelog.d/14084.misc b/changelog.d/14084.misc new file mode 100644 index 0000000000..988e55f437 --- /dev/null +++ b/changelog.d/14084.misc @@ -0,0 +1 @@ +Fix type annotation causing import time error in the Complement forking launcher. \ No newline at end of file diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index b22f315453..8c0f4a57e7 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -55,13 +55,13 @@ import os import signal import sys from types import FrameType -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional from twisted.internet.main import installReactor # a list of the original signal handlers, before we installed our custom ones. # We restore these in our child processes. -_original_signal_handlers: dict[int, Any] = {} +_original_signal_handlers: Dict[int, Any] = {} class ProxiedReactor: -- cgit 1.5.1 From 9473ebb9e7db9e3f71b341f72ae004db3a0144b8 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Nov 2022 11:47:09 +0000 Subject: Revert "Fix event size checks (#13710)" This reverts commit fab495a9e1442d99e922367f65f41de5eaa488eb. As noted in https://github.com/matrix-org/synapse/pull/13710#issuecomment-1298396007: > We want to see this change land for the protocol's sake (and plan to un-revert it) but want to give this a little more time before releasing this. --- changelog.d/13710.bugfix | 1 - synapse/event_auth.py | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) delete mode 100644 changelog.d/13710.bugfix (limited to 'synapse') diff --git a/changelog.d/13710.bugfix b/changelog.d/13710.bugfix deleted file mode 100644 index 4c318d15f5..0000000000 --- a/changelog.d/13710.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where Synapse would count codepoints instead of bytes when validating the size of some fields. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 5036604036..bab31e33c5 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -342,15 +342,15 @@ def check_state_dependent_auth_rules( def _check_size_limits(event: "EventBase") -> None: - if len(event.user_id.encode("utf-8")) > 255: + if len(event.user_id) > 255: raise EventSizeError("'user_id' too large") - if len(event.room_id.encode("utf-8")) > 255: + if len(event.room_id) > 255: raise EventSizeError("'room_id' too large") - if event.is_state() and len(event.state_key.encode("utf-8")) > 255: + if event.is_state() and len(event.state_key) > 255: raise EventSizeError("'state_key' too large") - if len(event.type.encode("utf-8")) > 255: + if len(event.type) > 255: raise EventSizeError("'type' too large") - if len(event.event_id.encode("utf-8")) > 255: + if len(event.event_id) > 255: raise EventSizeError("'event_id' too large") if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: raise EventSizeError("event too large") -- cgit 1.5.1 From 2bd7f3eeab1a4818359c9f585b660ff3f3d8bc6c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Nov 2022 15:02:39 +0000 Subject: Allow PUT/GET of aliases during faster join (#14292) without blocking on full state. --- changelog.d/14292.bugfix | 1 + synapse/handlers/directory.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14292.bugfix (limited to 'synapse') diff --git a/changelog.d/14292.bugfix b/changelog.d/14292.bugfix new file mode 100644 index 0000000000..4ed92f5cf2 --- /dev/null +++ b/changelog.d/14292.bugfix @@ -0,0 +1 @@ +Faster joins: do not block creation of or queries for room aliases during the resync. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index d52ebada6b..2ea52257cb 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -85,7 +85,7 @@ class DirectoryHandler: # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - servers = await self._storage_controllers.state.get_current_hosts_in_room( + servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) @@ -290,7 +290,7 @@ class DirectoryHandler: Codes.NOT_FOUND, ) - extra_servers = await self._storage_controllers.state.get_current_hosts_in_room( + extra_servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) servers_set = set(extra_servers) | set(servers) -- cgit 1.5.1 From d4fac8a3e27ab3e133c5e5ac603c8d937a1fd86c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Nov 2022 19:20:35 +0000 Subject: Fix typo in #13320 which could cause log spam (#14347) --- changelog.d/14347.bugfix | 1 + synapse/federation/federation_client.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14347.bugfix (limited to 'synapse') diff --git a/changelog.d/14347.bugfix b/changelog.d/14347.bugfix new file mode 100644 index 0000000000..91975757ae --- /dev/null +++ b/changelog.d/14347.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.64.0rc1 which could cause log spam when fetching events from other homeservers. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index fa225182be..c4c0bc7315 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -465,7 +465,7 @@ class FederationClient(FederationBase): pdu_attempts[destination] = now logger.info( - "get_pdu(event_id=): Failed to get PDU from %s because %s", + "get_pdu(event_id=%s): Failed to get PDU from %s because %s", event_id, destination, e, -- cgit 1.5.1 From 6546308c1e7d3eff316631a5909151dc6c7a9e1e Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 2 Nov 2022 17:33:45 +0000 Subject: Disable legacy Prometheus metric names by default. They can still be re-enabled for now, but they will be removed altogether in Synapse 1.73.0. (#14353) --- CHANGES.md | 9 +++++++++ changelog.d/14353.removal | 1 + docs/upgrade.md | 16 ++++++++++++++++ docs/usage/configuration/config_documentation.md | 4 ++-- synapse/config/metrics.py | 2 +- 5 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14353.removal (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index 113ad0d1ee..6bafdd3fad 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,12 @@ +Synapse (Next) (2022-11-01) +========================= + +Please note that, as announced in the release notes for Synapse 1.69.0, legacy Prometheus metric names are now disabled by default. +They will be removed altogether in Synapse 1.73.0. +If not already done, server administrators should update their dashboards and alerting rules to avoid using the deprecated metric names. +See the [upgrade notes](https://matrix-org.github.io/synapse/v1.71/upgrade.html#upgrading-to-v1710) for more details. + + Synapse 1.71.0rc1 (2022-11-01) ============================== diff --git a/changelog.d/14353.removal b/changelog.d/14353.removal new file mode 100644 index 0000000000..fc42aa9106 --- /dev/null +++ b/changelog.d/14353.removal @@ -0,0 +1 @@ +Disable legacy Prometheus metric names by default. They can still be re-enabled for now, but they will be removed altogether in Synapse 1.73.0. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index f095bbc3a6..41b06cc253 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -116,6 +116,22 @@ local users and some remote users is why the spec was changed/clarified and this caveat is no longer supported. +## Legacy Prometheus metric names are now disabled by default + +Synapse v1.71.0 disables legacy Prometheus metric names by default. +For administrators that still rely on them and have not yet had chance to update their +uses of the metrics, it's still possible to specify `enable_legacy_metrics: true` in +the configuration to re-enable them temporarily. + +Synapse v1.73.0 will **remove legacy metric names altogether** and at that point, +it will no longer be possible to re-enable them. + +If you do not use metrics or you have already updated your Grafana dashboard(s), +Prometheus console(s) and alerting rule(s), there is no action needed. + +See [v1.69.0: Deprecation of legacy Prometheus metric names](#deprecation-of-legacy-prometheus-metric-names). + + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 44358faf59..9a6bd08d01 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2441,8 +2441,8 @@ enable_metrics: true Set to `true` to publish both legacy and non-legacy Prometheus metric names, or to `false` to only publish non-legacy Prometheus metric names. -Defaults to `true`. Has no effect if `enable_metrics` is `false`. -**In Synapse v1.71.0, this will default to `false` before being removed in Synapse v1.73.0.** +Defaults to `false`. Has no effect if `enable_metrics` is `false`. +**In Synapse v1.67.0 up to and including Synapse v1.70.1, this defaulted to `true`.** Legacy metric names include: - metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules; diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index bb065f9f2f..6034a0346e 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -43,7 +43,7 @@ class MetricsConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_metrics = config.get("enable_metrics", False) - self.enable_legacy_metrics = config.get("enable_legacy_metrics", True) + self.enable_legacy_metrics = config.get("enable_legacy_metrics", False) self.report_stats = config.get("report_stats", None) self.report_stats_endpoint = config.get( -- cgit 1.5.1 From 86c5a710d8b4212f8a8a668d7d4a79c0bb371508 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Nov 2022 16:21:31 +0000 Subject: Implement MSC3912: Relation-based redactions (#14260) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14260.feature | 1 + synapse/api/constants.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/message.py | 47 ++++- synapse/handlers/relations.py | 56 +++++- synapse/rest/client/room.py | 57 ++++-- synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/relations.py | 36 ++++ tests/rest/client/test_redactions.py | 273 +++++++++++++++++++++++++++- tests/rest/client/utils.py | 37 ++++ 10 files changed, 486 insertions(+), 28 deletions(-) create mode 100644 changelog.d/14260.feature (limited to 'synapse') diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature new file mode 100644 index 0000000000..102dc7b3e0 --- /dev/null +++ b/changelog.d/14260.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 44c5ffc6a5..bc04a0755b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,6 +125,8 @@ class EventTypes: MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + Reaction: Final = "m.reaction" + class ToDeviceEventTypes: RoomKeyRequest: Final = "m.room_key_request" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d9bdd66d55..d4b71d1673 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): self.msc3886_endpoint: Optional[str] = experimental.get( "msc3886_endpoint", None ) + + # MSC3912: Relation-based redactions. + self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 468900a07f..4cf593cfdc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -877,6 +877,36 @@ class EventCreationHandler: return prev_event return None + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + if requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + return None + async def create_and_send_nonmember_event( self, requester: Requester, @@ -956,18 +986,17 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. async with self.limiter.queue(event_dict["room_id"]): - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - event_dict["room_id"], - requester.user.to_string(), - requester.access_token_id, - txn_id, + if txn_id: + event = await self.get_event_from_transaction( + requester, txn_id, event_dict["room_id"] ) - if existing_event_id: - event = await self.store.get_event(existing_event_id) + if event: # we know it was persisted, so must have a stream ordering assert event.internal_metadata.stream_ordering - return event, event.internal_metadata.stream_ordering + return ( + event, + event.internal_metadata.stream_ordering, + ) event, context = await self.create_event( requester, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0a0c6d938e..8e71dda970 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace @@ -75,6 +75,7 @@ class RelationsHandler: self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._event_creation_handler = hs.get_event_creation_handler() async def get_relations( self, @@ -205,6 +206,59 @@ class RelationsHandler: return related_events, next_token + async def redact_events_related_to( + self, + requester: Requester, + event_id: str, + initial_redaction_event: EventBase, + relation_types: List[str], + ) -> None: + """Redacts all events related to the given event ID with one of the given + relation types. + + This method is expected to be called when redacting the event referred to by + the given event ID. + + If an event cannot be redacted (e.g. because of insufficient permissions), log + the error and try to redact the next one. + + Args: + requester: The requester to redact events on behalf of. + event_id: The event IDs to look and redact relations of. + initial_redaction_event: The redaction for the event referred to by + event_id. + relation_types: The types of relations to look for. + + Raises: + ShadowBanError if the requester is shadow-banned + """ + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) + ) + + for related_event_id in related_event_ids: + try: + await self._event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": initial_redaction_event.content, + "room_id": initial_redaction_event.room_id, + "sender": requester.user.to_string(), + "redacts": related_event_id, + }, + ratelimit=False, + ) + except SynapseError as e: + logger.warning( + "Failed to redact event %s (related to event %s): %s", + related_event_id, + event_id, + e.msg, + ) + async def get_annotations_for_event( self, event_id: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 01e5079963..91cb791139 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -52,6 +52,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.state import StateFilter @@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() + self._relation_handler = hs.get_relations_handler() + self._msc3912_enabled = hs.config.experimental.msc3912_enabled def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" @@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + with_relations = None + if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content: + with_relations = content["org.matrix.msc3912.with_relations"] + del content["org.matrix.msc3912.with_relations"] + + # Check if there's an existing event for this transaction now (even though + # create_and_send_nonmember_event also does it) because, if there's one, + # then we want to skip the call to redact_events_related_to. + event = None + if txn_id: + event = await self.event_creation_handler.get_event_from_transaction( + requester, txn_id, room_id + ) + + if event is None: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + + if with_relations: + run_as_background_process( + "redact_related_events", + self._relation_handler.redact_events_related_to, + requester=requester, + event_id=event_id, + initial_redaction_event=event, + relation_types=with_relations, + ) + event_id = event.event_id except ShadowBanError: event_id = "$" + random_string(43) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 9b1b72c68a..180a11ef88 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet): # Adds support for simple HTTP rendezvous as per MSC3886 "org.matrix.msc3886": self.config.experimental.msc3886_endpoint is not None, + # Adds support for relation-based redactions as per MSC3912. + "org.matrix.msc3912": self.config.experimental.msc3912_enabled, }, }, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c022510e76..ca431002c8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index be4c67d68e..5dfe44defb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RedactionsTestCase(HomeserverTestCase): @@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase): ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = 200, + with_relations: Optional[List[str]] = None, ) -> JsonDict: """Helper function to send a redaction event. @@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - channel = self.make_request("POST", path, content={}, access_token=access_token) + request_content = {} + if with_relations: + request_content["org.matrix.msc3912.with_relations"] = with_relations + + channel = self.make_request( + "POST", path, request_content, access_token=access_token + ) self.assertEqual(channel.code, expect_code) return channel.json_body @@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase): # These should all succeed, even though this would be denied by # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations(self) -> None: + """Tests that we can redact the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete events that + # relate to it with m.replace. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=[ + RelationTypes.REPLACE, + RelationTypes.THREAD, + ], + ) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction did not get redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertNotIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_no_perms(self) -> None: + """Tests that, when redacting a message along with its relations, if not all + the related messages can be redacted because of insufficient permissions, the + server still redacts all the ones that can be. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.other_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message, this one from the moderator. We do this for the + # first message with the m.thread relation (and not the last one) to ensure + # that, when the server fails to redact it, it doesn't stop there, and it + # instead goes on to redact the other one. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + first_threaded_event_id = res["event_id"] + + # Send a second threaded message, this time from the user who'll perform the + # redaction. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 2", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.other_access_token, + ) + second_threaded_event_id = res["event_id"] + + # Redact the thread's root, and request that all threaded messages are also + # redacted. Send that request from the non-mod user, so that the first threaded + # event cannot be redacted. + self._redact_event( + self.other_access_token, + self.room_id, + root_event_id, + with_relations=[RelationTypes.THREAD], + ) + + # Check that the thread root got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the last message in the thread got redacted, despite failing to + # redact the one before it. + event_dict = self.helper.get_event( + self.room_id, second_threaded_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the message that was sent into the tread by the mod user is not + # redacted. + event_dict = self.helper.get_event( + self.room_id, first_threaded_event_id, self.other_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("message 1", event_dict["content"]["body"]) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_txn_id_reuse(self) -> None: + """Tests that redacting a message using a transaction ID, then reusing the same + transaction ID but providing an additional list of relations to redact, is + effectively a no-op. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "I'm in a thread!", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Send a first redaction request which redacts only the root event. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Send a second redaction request which redacts the root event as well as + # threaded messages. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict) + + # Check that the threaded message didn't get redacted (since that wasn't part of + # the original redaction). + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 706399fae5..8d6f2b6ff9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -410,6 +410,43 @@ class RestHelper: return channel.json_body + def get_event( + self, + room_id: str, + event_id: str, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Request a specific event from the server. + + Args: + room_id: the room in which the event was sent. + event_id: the event's ID. + tok: the token to request the event with. + expect_code: the expected HTTP status for the response. + + Returns: + The event as a dict. + """ + path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}" + if tok: + path = path + f"?access_token={tok}" + + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + path, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def _read_write_state( self, room_id: str, -- cgit 1.5.1 From a4b1f6456276e62b3f4d6b060c289b6413b8a5c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 4 Nov 2022 18:43:51 +0200 Subject: Fix /refresh endpoint version (#14364) --- changelog.d/14364.bugfix | 1 + synapse/rest/client/login.py | 2 +- tests/rest/client/test_auth.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14364.bugfix (limited to 'synapse') diff --git a/changelog.d/14364.bugfix b/changelog.d/14364.bugfix new file mode 100644 index 0000000000..514bf859bb --- /dev/null +++ b/changelog.d/14364.bugfix @@ -0,0 +1 @@ +Fix refresh token endpoint to be under /r0 and /v3 instead of /v1. Contributed by Tulir @ Beeper. diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 7774f1967d..05706b598c 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -536,7 +536,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: class RefreshTokenServlet(RestServlet): - PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),) + PATTERNS = client_patterns("/refresh$") def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 847294dc8e..208ec44829 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -635,7 +635,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): """ return self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": refresh_token}, ) @@ -724,7 +724,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) @@ -765,7 +765,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) @@ -1002,7 +1002,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This first refresh should work properly first_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1012,7 +1012,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one as well, since the token in the first one was never used second_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1022,7 +1022,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one should not, since the token from the first refresh is not valid anymore third_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1056,7 +1056,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1068,7 +1068,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # But refreshing from the last valid refresh token still works fifth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( -- cgit 1.5.1 From 8bcdd712b8ba471b3489d41e569276677cf6c2bd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 4 Nov 2022 18:43:14 +0000 Subject: Bump flake8-bugbear from 22.9.23 to 22.10.27 (#14329) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: GitHub Actions Co-authored-by: Olivier Wilkinson (reivilibre) --- changelog.d/14329.misc | 1 + poetry.lock | 10 +++++----- synapse/handlers/presence.py | 6 ++++-- synapse/server.py | 2 +- synapse/storage/_base.py | 2 +- 5 files changed, 12 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14329.misc (limited to 'synapse') diff --git a/changelog.d/14329.misc b/changelog.d/14329.misc new file mode 100644 index 0000000000..2f6bbd3af7 --- /dev/null +++ b/changelog.d/14329.misc @@ -0,0 +1 @@ +Bump flake8-bugbear from 22.9.23 to 22.10.27. diff --git a/poetry.lock b/poetry.lock index b945463299..f6e462e6ae 100644 --- a/poetry.lock +++ b/poetry.lock @@ -260,18 +260,18 @@ pyflakes = ">=2.4.0,<2.5.0" [[package]] name = "flake8-bugbear" -version = "22.9.23" +version = "22.10.27" description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] attrs = ">=19.2.0" flake8 = ">=3.0.0" [package.extras] -dev = ["coverage", "hypothesis", "hypothesmith (>=0.2)", "pre-commit"] +dev = ["coverage", "hypothesis", "hypothesmith (>=0.2)", "pre-commit", "tox"] [[package]] name = "flake8-comprehensions" @@ -1829,8 +1829,8 @@ flake8 = [ {file = "flake8-4.0.1.tar.gz", hash = "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d"}, ] flake8-bugbear = [ - {file = "flake8-bugbear-22.9.23.tar.gz", hash = "sha256:17b9623325e6e0dcdcc80ed9e4aa811287fcc81d7e03313b8736ea5733759937"}, - {file = "flake8_bugbear-22.9.23-py3-none-any.whl", hash = "sha256:cd2779b2b7ada212d7a322814a1e5651f1868ab0d3f24cc9da66169ab8fda474"}, + {file = "flake8-bugbear-22.10.27.tar.gz", hash = "sha256:a6708608965c9e0de5fff13904fed82e0ba21ac929fe4896459226a797e11cd5"}, + {file = "flake8_bugbear-22.10.27-py3-none-any.whl", hash = "sha256:6ad0ab754507319060695e2f2be80e6d8977cfcea082293089a9226276bd825d"}, ] flake8-comprehensions = [ {file = "flake8-comprehensions-3.8.0.tar.gz", hash = "sha256:8e108707637b1d13734f38e03435984f6b7854fa6b5a4e34f93e69534be8e521"}, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2670e561d7..0066d63987 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -256,7 +256,7 @@ class BasePresenceHandler(abc.ABC): with the app. """ - async def update_external_syncs_row( + async def update_external_syncs_row( # noqa: B027 (no-op by design) self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int ) -> None: """Update the syncing users for an external process as a delta. @@ -272,7 +272,9 @@ class BasePresenceHandler(abc.ABC): sync_time_msec: Time in ms when the user was last syncing """ - async def update_external_syncs_clear(self, process_id: str) -> None: + async def update_external_syncs_clear( # noqa: B027 (no-op by design) + self, process_id: str + ) -> None: """Marks all users that had been marked as syncing by a given process as offline. diff --git a/synapse/server.py b/synapse/server.py index df3a1cb405..c4e025af22 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -315,7 +315,7 @@ class HomeServer(metaclass=abc.ABCMeta): if self.config.worker.run_background_tasks: self.setup_background_tasks() - def start_listening(self) -> None: + def start_listening(self) -> None: # noqa: B027 (no-op by design) """Start the HTTP, manhole, metrics, etc listeners Does nothing in this base class; overridden in derived classes to start the diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index bf42aeb8d1..69abf6fa87 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -50,7 +50,7 @@ class SQLBaseStore(metaclass=ABCMeta): self.external_cached_functions: Dict[str, CachedFunction] = {} - def process_replication_rows( + def process_replication_rows( # noqa: B027 (no-op by design) self, stream_name: str, instance_name: str, -- cgit 1.5.1 From e980982b59dea38ec10a5c58993d09e02f845d28 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 7 Nov 2022 13:49:31 +0000 Subject: Do not reject `/sync` requests with unrecognised filter fields (#14369) For forward compatibility, Synapse needs to ignore fields it does not recognise instead of raising an error. Fixes #14365. Signed-off-by: Sean Quah --- changelog.d/14369.bugfix | 1 + synapse/api/filtering.py | 8 ++++---- tests/api/test_filtering.py | 21 +++++++++++++++++++-- 3 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14369.bugfix (limited to 'synapse') diff --git a/changelog.d/14369.bugfix b/changelog.d/14369.bugfix new file mode 100644 index 0000000000..e6709f4eec --- /dev/null +++ b/changelog.d/14369.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would raise an error when encountering an unrecognised field in a `/sync` filter, instead of ignoring it for forward compatibility. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 26be377d03..a9888381b4 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -43,7 +43,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer FILTER_SCHEMA = { - "additionalProperties": False, + "additionalProperties": True, # Allow new fields for forward compatibility "type": "object", "properties": { "limit": {"type": "number"}, @@ -63,7 +63,7 @@ FILTER_SCHEMA = { } ROOM_FILTER_SCHEMA = { - "additionalProperties": False, + "additionalProperties": True, # Allow new fields for forward compatibility "type": "object", "properties": { "not_rooms": {"$ref": "#/definitions/room_id_array"}, @@ -77,7 +77,7 @@ ROOM_FILTER_SCHEMA = { } ROOM_EVENT_FILTER_SCHEMA = { - "additionalProperties": False, + "additionalProperties": True, # Allow new fields for forward compatibility "type": "object", "properties": { "limit": {"type": "number"}, @@ -143,7 +143,7 @@ USER_FILTER_SCHEMA = { }, }, }, - "additionalProperties": False, + "additionalProperties": True, # Allow new fields for forward compatibility } diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index a82c4eed86..d5524d296e 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -46,19 +46,36 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.datastore = hs.get_datastores().main def test_errors_on_invalid_filters(self): + # See USER_FILTER_SCHEMA for the filter schema. invalid_filters = [ - {"boom": {}}, + # `account_data` must be a dictionary {"account_data": "Hello World"}, + # `event_fields` entries must not contain backslashes {"event_fields": [r"\\foo"]}, - {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, + # `event_format` must be "client" or "federation" {"event_format": "other"}, + # `not_rooms` must contain valid room IDs {"room": {"not_rooms": ["#foo:pik-test"]}}, + # `senders` must contain valid user IDs {"presence": {"senders": ["@bar;pik.test.com"]}}, ] for filter in invalid_filters: with self.assertRaises(SynapseError): self.filtering.check_valid_filter(filter) + def test_ignores_unknown_filter_fields(self): + # For forward compatibility, we must ignore unknown filter fields. + # See USER_FILTER_SCHEMA for the filter schema. + filters = [ + {"org.matrix.msc9999.future_option": True}, + {"presence": {"org.matrix.msc9999.future_option": True}}, + {"room": {"org.matrix.msc9999.future_option": True}}, + {"room": {"timeline": {"org.matrix.msc9999.future_option": True}}}, + ] + for filter in filters: + self.filtering.check_valid_filter(filter) + # Must not raise. + def test_valid_filters(self): valid_filters = [ { -- cgit 1.5.1 From 2193513346054769080dd8a07586bed652acae60 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 7 Nov 2022 14:28:00 +0000 Subject: Fix background update table-scanning `events` (#14374) When this background update did its last batch, it would try to update all the events that had been inserted since the bgupdate started, which could cause a table-scan. Make sure we limit the update correctly. --- changelog.d/14374.bugfix | 1 + synapse/storage/databases/main/events_bg_updates.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14374.bugfix (limited to 'synapse') diff --git a/changelog.d/14374.bugfix b/changelog.d/14374.bugfix new file mode 100644 index 0000000000..8366cfbf8a --- /dev/null +++ b/changelog.d/14374.bugfix @@ -0,0 +1 @@ +Fix a background database update, introduced in Synapse 1.64.0, which could cause poor database performance. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 6e8aeed7b4..9e31798ab1 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1435,16 +1435,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ), ) - endpoint = None row = txn.fetchone() if row: endpoint = row[0] + else: + # if the query didn't return a row, we must be almost done. We just + # need to go up to the recorded max_stream_ordering. + endpoint = max_stream_ordering_inclusive - where_clause = "stream_ordering > ?" - args = [min_stream_ordering_exclusive] - if endpoint: - where_clause += " AND stream_ordering <= ?" - args.append(endpoint) + where_clause = "stream_ordering > ? AND stream_ordering <= ?" + args = [min_stream_ordering_exclusive, endpoint] # now do the updates. txn.execute( @@ -1458,13 +1458,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) logger.info( - "populated new `events` columns up to %s/%i: updated %i rows", + "populated new `events` columns up to %i/%i: updated %i rows", endpoint, max_stream_ordering_inclusive, txn.rowcount, ) - if endpoint is None: + if endpoint >= max_stream_ordering_inclusive: # we're done return True -- cgit 1.5.1 From 7894251bcea7714b47e3849e509ea717bb18e9f5 Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 7 Nov 2022 13:38:50 -0800 Subject: Correctly create power level event during initial room creation (#14361) --- changelog.d/14361.bugfix | 1 + synapse/handlers/room.py | 25 +++++++++++++++++++++++-- tests/rest/client/test_rooms.py | 4 ++-- 3 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14361.bugfix (limited to 'synapse') diff --git a/changelog.d/14361.bugfix b/changelog.d/14361.bugfix new file mode 100644 index 0000000000..33ba1d92af --- /dev/null +++ b/changelog.d/14361.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.71.0rc1 where the power level event was incorrectly created during initial room creation. \ No newline at end of file diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f10cfca073..66a50bca6e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1080,6 +1080,19 @@ class RoomCreationHandler: for_batch: bool, **kwargs: Any, ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + """ + Creates an event and associated event context. + Args: + etype: the type of event to be created + content: content of the event + for_batch: whether the event is being created for batch persisting. If + bool for_batch is true, this will create an event using the prev_event_ids, + and will create an event context for the event using the parameters state_map + and current_state_group, thus these parameters must be provided in this + case if for_batch is True. The subsequently created event and context + are suitable for being batched up and bulk persisted to the database + with other similarly created events. + """ nonlocal depth nonlocal prev_event @@ -1139,13 +1152,21 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + # we need the state group of the membership event as it is the current state group + event_to_state = ( + await self._storage_controllers.state.get_state_group_for_events( + [member_event_id] + ) + ) + current_state_group = event_to_state[member_event_id] + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: power_event, power_context = await create_event( - EventTypes.PowerLevels, pl_content, False + EventTypes.PowerLevels, pl_content, True ) current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) @@ -1194,7 +1215,7 @@ class RoomCreationHandler: pl_event, pl_context = await create_event( EventTypes.PowerLevels, power_level_content, - False, + True, ) current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 1084d4ad9d..e919e089cb 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -715,7 +715,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -728,7 +728,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(37, channel.resource_usage.db_txn_count) + self.assertEqual(36, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From a5fcdea090c2396c30dd07c357ce4d9c90004c34 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 8 Nov 2022 17:17:13 +0000 Subject: Remove support for PostgreSQL 10 (#14392) Signed-off-by: Sean Quah --- .ci/scripts/calculate_jobs.py | 2 +- .github/workflows/tests.yml | 2 +- changelog.d/14392.removal | 1 + docs/upgrade.md | 10 ++++++++++ synapse/storage/engines/postgres.py | 4 ++-- 5 files changed, 15 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14392.removal (limited to 'synapse') diff --git a/.ci/scripts/calculate_jobs.py b/.ci/scripts/calculate_jobs.py index c53d4d5ff1..b48174bea2 100755 --- a/.ci/scripts/calculate_jobs.py +++ b/.ci/scripts/calculate_jobs.py @@ -54,7 +54,7 @@ trial_postgres_tests = [ { "python-version": "3.7", "database": "postgres", - "postgres-version": "10", + "postgres-version": "11", "extras": "all", } ] diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fea33abd12..2bc237a0ba 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -409,7 +409,7 @@ jobs: matrix: include: - python-version: "3.7" - postgres-version: "10" + postgres-version: "11" - python-version: "3.11" postgres-version: "14" diff --git a/changelog.d/14392.removal b/changelog.d/14392.removal new file mode 100644 index 0000000000..e96b3de2bd --- /dev/null +++ b/changelog.d/14392.removal @@ -0,0 +1 @@ +Remove support for PostgreSQL 10. diff --git a/docs/upgrade.md b/docs/upgrade.md index 41b06cc253..2aa353e496 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,16 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.72.0 + +## Dropping support for PostgreSQL 10 + +In line with our [deprecation policy](deprecation_policy.md), we've dropped +support for PostgreSQL 10, as it is no longer supported upstream. + +This release of Synapse requires PostgreSQL 11+. + + # Upgrading to v1.71.0 ## Removal of the `generate_short_term_login_token` module API method diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 9bf74bbf59..0c4fd88914 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -81,8 +81,8 @@ class PostgresEngine( allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) # Are we on a supported PostgreSQL version? - if not allow_outdated_version and self._version < 100000: - raise RuntimeError("Synapse requires PostgreSQL 10 or above.") + if not allow_outdated_version and self._version < 110000: + raise RuntimeError("Synapse requires PostgreSQL 11 or above.") with db_conn.cursor() as txn: txn.execute("SHOW SERVER_ENCODING") -- cgit 1.5.1 From e9a4343cb2daa55503bb2a2d1431d83bf9773e68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Nov 2022 09:55:34 -0500 Subject: Drop support for Postgres 10 in full text search code. (#14397) --- changelog.d/14397.removal | 1 + synapse/storage/databases/main/search.py | 50 +++++++++++------------ synapse/storage/engines/postgres.py | 16 -------- tests/storage/test_room_search.py | 69 ++++++++------------------------ 4 files changed, 41 insertions(+), 95 deletions(-) create mode 100644 changelog.d/14397.removal (limited to 'synapse') diff --git a/changelog.d/14397.removal b/changelog.d/14397.removal new file mode 100644 index 0000000000..e96b3de2bd --- /dev/null +++ b/changelog.d/14397.removal @@ -0,0 +1 @@ +Remove support for PostgreSQL 10. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index e9588d1755..3fe433f66c 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -463,18 +463,17 @@ class SearchStore(SearchBackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): search_query = search_term - tsquery_func = self.database_engine.tsquery_func - sql = f""" - SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank, + sql = """ + SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank, room_id, event_id FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) + WHERE vector @@ websearch_to_tsquery('english', ?) """ args = [search_query, search_query] + args - count_sql = f""" + count_sql = """ SELECT room_id, count(*) as count FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) + WHERE vector @@ websearch_to_tsquery('english', ?) """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -523,9 +522,7 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres( - search_query, events, tsquery_func - ) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -604,18 +601,17 @@ class SearchStore(SearchBackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): search_query = search_term - tsquery_func = self.database_engine.tsquery_func - sql = f""" - SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank, + sql = """ + SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank, origin_server_ts, stream_ordering, room_id, event_id FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) AND + WHERE vector @@ websearch_to_tsquery('english', ?) AND """ args = [search_query, search_query] + args - count_sql = f""" + count_sql = """ SELECT room_id, count(*) as count FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) AND + WHERE vector @@ websearch_to_tsquery('english', ?) AND """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -686,9 +682,7 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres( - search_query, events, tsquery_func - ) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -714,7 +708,7 @@ class SearchStore(SearchBackgroundUpdateStore): } async def _find_highlights_in_postgres( - self, search_query: str, events: List[EventBase], tsquery_func: str + self, search_query: str, events: List[EventBase] ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -725,7 +719,6 @@ class SearchStore(SearchBackgroundUpdateStore): Args: search_query events: A list of events - tsquery_func: The tsquery_* function to use when making queries Returns: A set of strings. @@ -758,13 +751,16 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % ( - _to_postgres_options( - { - "StartSel": start_sel, - "StopSel": stop_sel, - "MaxFragments": "50", - } + query = ( + "SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)" + % ( + _to_postgres_options( + { + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxFragments": "50", + } + ) ) ) txn.execute(query, (value, search_query)) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 0c4fd88914..719a517336 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -170,22 +170,6 @@ class PostgresEngine( """Do we support the `RETURNING` clause in insert/update/delete?""" return True - @property - def tsquery_func(self) -> str: - """ - Selects a tsquery_* func to use. - - Ref: https://www.postgresql.org/docs/current/textsearch-controls.html - - Returns: - The function name. - """ - # Postgres 11 added support for websearch_to_tsquery. - assert self._version is not None - if self._version >= 110000: - return "websearch_to_tsquery" - return "plainto_tsquery" - def is_deadlock(self, error: Exception) -> bool: if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 868b5bee84..ef850daa73 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import List, Tuple from unittest.case import SkipTest -from unittest.mock import PropertyMock, patch from twisted.test.proto_helpers import MemoryReactor @@ -220,10 +219,8 @@ class MessageSearchTest(HomeserverTestCase): PHRASE = "the quick brown fox jumps over the lazy dog" - # Each entry is a search query, followed by either a boolean of whether it is - # in the phrase OR a tuple of booleans: whether it matches using websearch - # and using plain search. - COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + # Each entry is a search query, followed by a boolean of whether it is in the phrase. + COMMON_CASES = [ ("nope", False), ("brown", True), ("quick brown", True), @@ -231,13 +228,13 @@ class MessageSearchTest(HomeserverTestCase): ("quick \t brown", True), ("jump", True), ("brown nope", False), - ('"brown quick"', (False, True)), + ('"brown quick"', False), ('"jumps over"', True), - ('"quick fox"', (False, True)), + ('"quick fox"', False), ("nope OR doublenope", False), - ("furphy OR fox", (True, False)), - ("fox -nope", (True, False)), - ("fox -brown", (False, True)), + ("furphy OR fox", True), + ("fox -nope", True), + ("fox -brown", False), ('"fox" quick', True), ('"quick brown', True), ('" quick "', True), @@ -246,11 +243,11 @@ class MessageSearchTest(HomeserverTestCase): # TODO Test non-ASCII cases. # Case that fail on SQLite. - POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + POSTGRES_CASES = [ # SQLite treats NOT as a binary operator. - ("- fox", (False, True)), - ("- nope", (True, False)), - ('"-fox quick', (False, True)), + ("- fox", False), + ("- nope", True), + ('"-fox quick', False), # PostgreSQL skips stop words. ('"the quick brown"', True), ('"over lazy"', True), @@ -275,7 +272,7 @@ class MessageSearchTest(HomeserverTestCase): if isinstance(main_store.database_engine, PostgresEngine): assert main_store.database_engine._version is not None found = main_store.database_engine._version < 140000 - self.COMMON_CASES.append(('"fox quick', (found, True))) + self.COMMON_CASES.append(('"fox quick', found)) def test_tokenize_query(self) -> None: """Test the custom logic to tokenize a user's query.""" @@ -315,16 +312,10 @@ class MessageSearchTest(HomeserverTestCase): ) def _check_test_cases( - self, - store: DataStore, - cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]], - index=0, + self, store: DataStore, cases: List[Tuple[str, bool]] ) -> None: # Run all the test cases versus search_msgs for query, expect_to_contain in cases: - if isinstance(expect_to_contain, tuple): - expect_to_contain = expect_to_contain[index] - result = self.get_success( store.search_msgs([self.room_id], query, ["content.body"]) ) @@ -343,9 +334,6 @@ class MessageSearchTest(HomeserverTestCase): # Run them again versus search_rooms for query, expect_to_contain in cases: - if isinstance(expect_to_contain, tuple): - expect_to_contain = expect_to_contain[index] - result = self.get_success( store.search_rooms([self.room_id], query, ["content.body"], 10) ) @@ -366,38 +354,15 @@ class MessageSearchTest(HomeserverTestCase): """ Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. This test is skipped unless the postgres instance supports websearch_to_tsquery. - """ - - store = self.hs.get_datastores().main - if not isinstance(store.database_engine, PostgresEngine): - raise SkipTest("Test only applies when postgres is used as the database") - - if store.database_engine.tsquery_func != "websearch_to_tsquery": - raise SkipTest( - "Test only applies when postgres supporting websearch_to_tsquery is used as the database" - ) - self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0) - - def test_postgres_non_web_search_for_phrase(self): - """ - Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't - supported by the current postgres version. + See https://www.postgresql.org/docs/current/textsearch-controls.html """ store = self.hs.get_datastores().main if not isinstance(store.database_engine, PostgresEngine): raise SkipTest("Test only applies when postgres is used as the database") - # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path. - with patch( - "synapse.storage.engines.postgres.PostgresEngine.tsquery_func", - new_callable=PropertyMock, - ) as supports_websearch_to_tsquery: - supports_websearch_to_tsquery.return_value = "plainto_tsquery" - self._check_test_cases( - store, self.COMMON_CASES + self.POSTGRES_CASES, index=1 - ) + self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES) def test_sqlite_search(self): """ @@ -407,4 +372,4 @@ class MessageSearchTest(HomeserverTestCase): if not isinstance(store.database_engine, Sqlite3Engine): raise SkipTest("Test only applies when sqlite is used as the database") - self._check_test_cases(store, self.COMMON_CASES, index=0) + self._check_test_cases(store, self.COMMON_CASES) -- cgit 1.5.1 From d10a85ec9eac6f31aa82a5f07d74e5914b18b320 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 10 Nov 2022 12:17:46 +0000 Subject: Quieter logging for stateres failure at missing prev events (#14346) --- changelog.d/14346.misc | 1 + synapse/handlers/federation_event.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14346.misc (limited to 'synapse') diff --git a/changelog.d/14346.misc b/changelog.d/14346.misc new file mode 100644 index 0000000000..9833b0733a --- /dev/null +++ b/changelog.d/14346.misc @@ -0,0 +1 @@ +Concisely log a failure to resolve state due to missing `prev_events`. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9ca5df7c78..f7223b03c3 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1065,10 +1065,9 @@ class FederationEventHandler: state_res_store=StateResolutionStore(self._store), ) - except Exception: + except Exception as e: logger.warning( - "Error attempting to resolve state at missing prev_events", - exc_info=True, + "Error attempting to resolve state at missing prev_events: %s", e ) raise FederationError( "ERROR", -- cgit 1.5.1 From b2c2b030798d0e74d3bf1afb4726465b53620638 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 10 Nov 2022 19:02:27 +0000 Subject: Fix PostgreSQL sometimes using table scans for `event_search` (#14409) PostgreSQL may underestimate the number of distinct `room_id`s in `event_search`, which can cause it to use table scans for queries for multiple rooms. Fix this by setting `n_distinct` on the column. Resolves #14402. Signed-off-by: Sean Quah --- changelog.d/14409.bugfix | 1 + .../11event_search_room_id_n_distinct.sql.postgres | 33 ++++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 changelog.d/14409.bugfix create mode 100644 synapse/storage/schema/main/delta/73/11event_search_room_id_n_distinct.sql.postgres (limited to 'synapse') diff --git a/changelog.d/14409.bugfix b/changelog.d/14409.bugfix new file mode 100644 index 0000000000..f720700653 --- /dev/null +++ b/changelog.d/14409.bugfix @@ -0,0 +1 @@ +Fix PostgreSQL sometimes using table scans for queries against the `event_search` table, taking a long time and a large amount of IO. diff --git a/synapse/storage/schema/main/delta/73/11event_search_room_id_n_distinct.sql.postgres b/synapse/storage/schema/main/delta/73/11event_search_room_id_n_distinct.sql.postgres new file mode 100644 index 0000000000..93cdaefca1 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/11event_search_room_id_n_distinct.sql.postgres @@ -0,0 +1,33 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- By default the postgres statistics collector massively underestimates the +-- number of distinct rooms in `event_search`, which can cause postgres to use +-- table scans for queries for multiple rooms. +-- +-- To work around this we can manually tell postgres the number of distinct rooms +-- by setting `n_distinct` (a negative value here is the number of distinct values +-- divided by the number of rows, so -0.01 means on average there are 100 rows per +-- distinct value). We don't need a particularly accurate number here, as a) we just +-- want it to always use index scans and b) our estimate is going to be better than the +-- one made by the statistics collector. + +ALTER TABLE event_search ALTER COLUMN room_id SET (n_distinct = -0.01); + +-- Ideally we'd do an `ANALYZE event_search (room_id)` here so that +-- the above gets picked up immediately, but that can take a bit of time so we +-- rely on the autovacuum eventually getting run and doing that in the +-- background for us. -- cgit 1.5.1 From 13ca8bb2fc05d338ccf62e6f8d1cbf5021d935ba Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Nov 2022 15:33:34 -0500 Subject: Remove duplicated code to evict entries. (#14410) This code was factored out to a method, but also left in-place. Calling this twice in a row makes no sense: the first call will reduce the size appropriately, but the loop will immediately exit since the cache size was already reduced. --- changelog.d/14410.misc | 1 + synapse/util/caches/stream_change_cache.py | 11 ++--------- 2 files changed, 3 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14410.misc (limited to 'synapse') diff --git a/changelog.d/14410.misc b/changelog.d/14410.misc new file mode 100644 index 0000000000..f085a8bfb2 --- /dev/null +++ b/changelog.d/14410.misc @@ -0,0 +1 @@ +Remove unreachable code. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 330709b8b7..666f4b6895 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -72,7 +72,7 @@ class StreamChangeCache: items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ new_size = math.floor(self._original_max_size * factor) if new_size != self._max_size: @@ -188,14 +188,8 @@ class StreamChangeCache: self._entity_to_key[entity] = stream_pos self._evict() - # if the cache is too big, remove entries - while len(self._cache) > self._max_size: - k, r = self._cache.popitem(0) - self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) - for entity in r: - del self._entity_to_key[entity] - def _evict(self) -> None: + # if the cache is too big, remove entries while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) @@ -203,7 +197,6 @@ class StreamChangeCache: self._entity_to_key.pop(entity, None) def get_max_pos_of_last_change(self, entity: EntityType) -> int: - """Returns an upper bound of the stream id of the last change to an entity. """ -- cgit 1.5.1 From 3a4f80f8c6f39c5549c56c044e10b35064d8d22f Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 11 Nov 2022 10:51:49 +0000 Subject: Merge/remove `Slaved*` stores into `WorkerStores` (#14375) --- changelog.d/14375.misc | 1 + synapse/app/admin_cmd.py | 36 ++++++++--- synapse/app/generic_worker.py | 44 ++++++++++---- synapse/replication/slave/storage/devices.py | 79 ------------------------ synapse/replication/slave/storage/events.py | 79 ------------------------ synapse/replication/slave/storage/filtering.py | 35 ----------- synapse/replication/slave/storage/keys.py | 20 ------ synapse/replication/slave/storage/push_rule.py | 35 ----------- synapse/replication/slave/storage/pushers.py | 47 -------------- synapse/storage/databases/main/__init__.py | 35 ----------- synapse/storage/databases/main/devices.py | 81 ++++++++++++++++++++++--- synapse/storage/databases/main/events_worker.py | 16 +++++ synapse/storage/databases/main/filtering.py | 4 +- synapse/storage/databases/main/push_rule.py | 19 ++++-- synapse/storage/databases/main/pusher.py | 41 +++++++++++-- synapse/storage/databases/main/stream.py | 1 + tests/replication/slave/storage/test_events.py | 6 +- 17 files changed, 202 insertions(+), 377 deletions(-) create mode 100644 changelog.d/14375.misc delete mode 100644 synapse/replication/slave/storage/devices.py delete mode 100644 synapse/replication/slave/storage/events.py delete mode 100644 synapse/replication/slave/storage/filtering.py delete mode 100644 synapse/replication/slave/storage/keys.py delete mode 100644 synapse/replication/slave/storage/push_rule.py delete mode 100644 synapse/replication/slave/storage/pushers.py (limited to 'synapse') diff --git a/changelog.d/14375.misc b/changelog.d/14375.misc new file mode 100644 index 0000000000..d0369b9b8c --- /dev/null +++ b/changelog.d/14375.misc @@ -0,0 +1 @@ +Cleanup old worker datastore classes. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 3c8c00ea5b..165d1c5db0 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -28,10 +28,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.events import EventBase from synapse.handlers.admin import ExfiltrationWriter -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.server import HomeServer from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.account_data import AccountDataWorkerStore @@ -40,10 +36,24 @@ from synapse.storage.databases.main.appservice import ( ApplicationServiceWorkerStore, ) from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.databases.main.devices import DeviceWorkerStore +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore +from synapse.storage.databases.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore +from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.types import StateMap from synapse.util import SYNAPSE_VERSION from synapse.util.logcontext import LoggingContext @@ -52,17 +62,25 @@ logger = logging.getLogger("synapse.app.admin_cmd") class AdminCmdSlavedStore( - SlavedFilteringStore, - SlavedPushRuleStore, - SlavedEventStore, - SlavedDeviceStore, + FilteringWorkerStore, + DeviceWorkerStore, TagsWorkerStore, DeviceInboxWorkerStore, AccountDataWorkerStore, + PushRulesWorkerStore, ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, - RegistrationWorkerStore, + RoomMemberWorkerStore, + RelationsWorkerStore, + EventFederationWorkerStore, + EventPushActionsWorkerStore, + StateGroupWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, ReceiptsWorkerStore, + StreamWorkerStore, + EventsWorkerStore, + RegistrationWorkerStore, RoomWorkerStore, ): def __init__( diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index cb5892f041..51446b49cd 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -48,12 +48,6 @@ from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client import ( account_data, @@ -101,8 +95,16 @@ from synapse.storage.databases.main.appservice import ( from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.directory import DirectoryWorkerStore from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore +from synapse.storage.databases.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.keys import KeyStore from synapse.storage.databases.main.lock import LockStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.metrics import ServerMetricsStore @@ -111,17 +113,25 @@ from synapse.storage.databases.main.monthly_active_users import ( ) from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.profile import ProfileWorkerStore +from synapse.storage.databases.main.push_rule import PushRulesWorkerStore +from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.room_batch import RoomBatchStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.session import SessionStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore +from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.types import JsonDict from synapse.util import SYNAPSE_VERSION from synapse.util.httpresourcetree import create_resource_tree @@ -232,26 +242,36 @@ class GenericWorkerSlavedStore( EndToEndRoomKeyStore, PresenceStore, DeviceInboxWorkerStore, - SlavedDeviceStore, - SlavedPushRuleStore, + DeviceWorkerStore, TagsWorkerStore, AccountDataWorkerStore, - SlavedPusherStore, CensorEventsStore, ClientIpWorkerStore, - SlavedEventStore, - SlavedKeyStore, + # KeyStore isn't really safe to use from a worker, but for now we do so and hope that + # the races it creates aren't too bad. + KeyStore, RoomWorkerStore, RoomBatchStore, DirectoryWorkerStore, + PushRulesWorkerStore, ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ProfileWorkerStore, - SlavedFilteringStore, + FilteringWorkerStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, ServerMetricsStore, + PusherWorkerStore, + RoomMemberWorkerStore, + RelationsWorkerStore, + EventFederationWorkerStore, + EventPushActionsWorkerStore, + StateGroupWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, ReceiptsWorkerStore, + StreamWorkerStore, + EventsWorkerStore, RegistrationWorkerStore, SearchStore, TransactionWorkerStore, diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py deleted file mode 100644 index 6fcade510a..0000000000 --- a/synapse/replication/slave/storage/devices.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, Any, Iterable - -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.devices import DeviceWorkerStore - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SlavedDeviceStore(DeviceWorkerStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - self.hs = hs - - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - - super().__init__(database, db_conn, hs) - - def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() - - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] - ) -> None: - if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(instance_name, token) - self._invalidate_caches_for_devices(token, rows) - elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) - for row in rows: - self._user_signature_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) - - def _invalidate_caches_for_devices( - self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] - ) -> None: - for row in rows: - # The entities are either user IDs (starting with '@') whose devices - # have changed, or remote servers that we need to tell about - # changes. - if row.entity.startswith("@"): - self._device_list_stream_cache.entity_has_changed(row.entity, token) - self.get_cached_devices_for_user.invalidate((row.entity,)) - self._get_cached_user_device.invalidate((row.entity,)) - self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) - - else: - self._device_list_federation_stream_cache.entity_has_changed( - row.entity, token - ) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py deleted file mode 100644 index fe47778cb1..0000000000 --- a/synapse/replication/slave/storage/events.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import TYPE_CHECKING - -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.event_federation import EventFederationWorkerStore -from synapse.storage.databases.main.event_push_actions import ( - EventPushActionsWorkerStore, -) -from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.databases.main.relations import RelationsWorkerStore -from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.storage.databases.main.signatures import SignatureWorkerStore -from synapse.storage.databases.main.state import StateGroupWorkerStore -from synapse.storage.databases.main.stream import StreamWorkerStore -from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -# So, um, we want to borrow a load of functions intended for reading from -# a DataStore, but we don't want to take functions that either write to the -# DataStore or are cached and don't have cache invalidation logic. -# -# Rather than write duplicate versions of those functions, or lift them to -# a common base class, we going to grab the underlying __func__ object from -# the method descriptor on the DataStore and chuck them into our class. - - -class SlavedEventStore( - EventFederationWorkerStore, - RoomMemberWorkerStore, - EventPushActionsWorkerStore, - StreamWorkerStore, - StateGroupWorkerStore, - SignatureWorkerStore, - EventsWorkerStore, - UserErasureWorkerStore, - RelationsWorkerStore, -): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py deleted file mode 100644 index c52679cd60..0000000000 --- a/synapse/replication/slave/storage/filtering.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.filtering import FilteringStore - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SlavedFilteringStore(SQLBaseStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - # Filters are immutable so this cache doesn't need to be expired - get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py deleted file mode 100644 index a00b38c512..0000000000 --- a/synapse/replication/slave/storage/keys.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.databases.main.keys import KeyStore - -# KeyStore isn't really safe to use from a worker, but for now we do so and hope that -# the races it creates aren't too bad. - -SlavedKeyStore = KeyStore diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py deleted file mode 100644 index 5e65eaf1e0..0000000000 --- a/synapse/replication/slave/storage/push_rule.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Iterable - -from synapse.replication.tcp.streams import PushRulesStream -from synapse.storage.databases.main.push_rule import PushRulesWorkerStore - -from .events import SlavedEventStore - - -class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_max_push_rules_stream_id(self) -> int: - return self._push_rules_stream_id_gen.get_current_token() - - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] - ) -> None: - if stream_name == PushRulesStream.NAME: - self._push_rules_stream_id_gen.advance(instance_name, token) - for row in rows: - self.get_push_rules_for_user.invalidate((row.user_id,)) - self.push_rules_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py deleted file mode 100644 index 44ed20e424..0000000000 --- a/synapse/replication/slave/storage/pushers.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING, Any, Iterable - -from synapse.replication.tcp.streams import PushersStream -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.pusher import PusherWorkerStore - -from ._slaved_id_tracker import SlavedIdTracker - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SlavedPusherStore(PusherWorkerStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - self._pushers_id_gen = SlavedIdTracker( # type: ignore - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) - - def get_pushers_stream_token(self) -> int: - return self._pushers_id_gen.get_current_token() - - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] - ) -> None: - if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index cfaedf5e0c..0e47592be3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -26,9 +26,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict, get_domain_from_id -from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore @@ -138,41 +136,8 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._device_list_id_gen = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - super().__init__(database, db_conn, hs) - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - self._stream_order_on_start = self.get_room_max_stream_ordering() - self._min_stream_order_on_start = self.get_room_min_stream_ordering() - - def get_device_stream_token(self) -> int: - # TODO: shouldn't this be moved to `DeviceWorkerStore`? - return self._device_list_id_gen.get_current_token() - async def get_users(self) -> List[JsonDict]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 979dd4e17e..aa58c2adc3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import ( TYPE_CHECKING, @@ -39,6 +38,8 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -49,6 +50,11 @@ from synapse.storage.database import ( from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -80,9 +86,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) + if hs.config.worker.worker_app is None: + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + ) + else: + self._device_list_id_gen = SlavedIdTracker( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + ) + # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). - device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined] + device_list_max = self._device_list_id_gen.get_current_token() device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict( db_conn, "device_lists_stream", @@ -136,6 +165,39 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + self._invalidate_caches_for_devices(token, rows) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + for row in rows: + self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) + + def _invalidate_caches_for_devices( + self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] + ) -> None: + for row in rows: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. + if row.entity.startswith("@"): + self._device_list_stream_cache.entity_has_changed(row.entity, token) + self.get_cached_devices_for_user.invalidate((row.entity,)) + self._get_cached_user_device.invalidate((row.entity,)) + self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) + + else: + self._device_list_federation_stream_cache.entity_has_changed( + row.entity, token + ) + + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. @@ -677,11 +739,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): }, ) - @abc.abstractmethod - def get_device_stream_token(self) -> int: - """Get the current stream id from the _device_list_id_gen""" - ... - @trace @cancellable async def get_user_devices_from_cache( @@ -1481,6 +1538,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): + # Because we have write access, this will be a StreamIdGenerator + # (see DeviceWorkerStore.__init__) + _device_list_id_gen: AbstractStreamIdGenerator + def __init__( self, database: DatabasePool, @@ -1805,7 +1866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context, ) - async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -2044,7 +2105,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [], ) - async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: return await self.db_pool.runInteraction( "add_device_list_outbound_pokes", add_device_list_outbound_pokes_txn, @@ -2058,7 +2119,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates during partial joins. """ - async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.simple_upsert( table="device_lists_remote_pending", keyvalues={ diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 69fea452ad..a79091952a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import AsyncLruCache +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -233,6 +234,21 @@ class EventsWorkerStore(SQLBaseStore): db_conn, "events", "stream_ordering", step=-1 ) + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache: StreamChangeCache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + if hs.config.worker.run_background_tasks: # We periodically clean out old transaction ID mappings self._clock.looping_call( diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index cb9ee08fa8..12f3b601f1 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -24,7 +24,7 @@ from synapse.types import JsonDict from synapse.util.caches.descriptors import cached -class FilteringStore(SQLBaseStore): +class FilteringWorkerStore(SQLBaseStore): @cached(num_args=2) async def get_user_filter( self, user_localpart: str, filter_id: Union[int, str] @@ -46,6 +46,8 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) + +class FilteringStore(FilteringWorkerStore): async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index b6c15f29f8..8ae10f6127 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, + Iterable, List, Mapping, Optional, @@ -31,6 +31,7 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -90,8 +91,6 @@ def _load_rules( return filtered_rules -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. class PushRulesWorkerStore( ApplicationServiceWorkerStore, PusherWorkerStore, @@ -99,7 +98,6 @@ class PushRulesWorkerStore( ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore, - metaclass=abc.ABCMeta, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. @@ -136,14 +134,23 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - @abc.abstractmethod def get_max_push_rules_stream_id(self) -> int: """Get the position of the push rules stream. Returns: int """ - raise NotImplementedError() + return self._push_rules_stream_id_gen.get_current_token() + + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + for row in rows: + self.get_push_rules_for_user.invalidate((row.user_id,)) + self.push_rules_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 01206950a9..4a01562d45 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,13 +27,19 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -52,9 +58,21 @@ class PusherWorkerStore(SQLBaseStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) + + if hs.config.worker.worker_app is None: + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) + else: + self._pushers_id_gen = SlavedIdTracker( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", @@ -96,6 +114,16 @@ class PusherWorkerStore(SQLBaseStore): yield PusherConfig(**r) + def get_pushers_stream_token(self) -> int: + return self._pushers_id_gen.get_current_token() + + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == PushersStream.NAME: + self._pushers_id_gen.advance(instance_name, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) + async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str ) -> Iterator[PusherConfig]: @@ -545,8 +573,9 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): - def get_pushers_stream_token(self) -> int: - return self._pushers_id_gen.get_current_token() + # Because we have write access, this will be a StreamIdGenerator + # (see PusherWorkerStore.__init__) + _pushers_id_gen: AbstractStreamIdGenerator async def add_pusher( self, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 09ce855aa8..cc27ec3804 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -415,6 +415,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() def get_room_max_stream_ordering(self) -> int: """Get the stream_ordering of regular events that we have committed up to diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index d42e36cdf1..96f3880923 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -21,11 +21,11 @@ from synapse.api.constants import ReceiptTypes from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource -from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.databases.main.event_push_actions import ( NotifCounts, RoomNotifCounts, ) +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition @@ -58,9 +58,9 @@ def patch__eq__(cls): return unpatch -class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): +class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): - STORE_TYPE = SlavedEventStore + STORE_TYPE = EventsWorkerStore def setUp(self): # Patch up the equality operator for events so that we can check -- cgit 1.5.1 From a3623af74e0af0d2f6cbd37b47dc54a1acd314d5 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Fri, 11 Nov 2022 19:38:17 +0400 Subject: Add an Admin API endpoint for looking up users based on 3PID (#14405) --- changelog.d/14405.feature | 1 + docs/admin_api/user_admin_api.md | 39 ++++++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 25 +++++++++ tests/rest/admin/test_user.py | 107 ++++++++++++++++++++++++++++++++++----- 5 files changed, 161 insertions(+), 13 deletions(-) create mode 100644 changelog.d/14405.feature (limited to 'synapse') diff --git a/changelog.d/14405.feature b/changelog.d/14405.feature new file mode 100644 index 0000000000..d3ba89b597 --- /dev/null +++ b/changelog.d/14405.feature @@ -0,0 +1 @@ +Add an [Admin API](https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html) endpoint for user lookup based on third-party ID (3PID). Contributed by @ashfame. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index c95d6c9b05..880bef4194 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -1197,3 +1197,42 @@ Returns a `404` HTTP status code if no user was found, with a response body like ``` _Added in Synapse 1.68.0._ + + +### Find a user based on their Third Party ID (ThreePID or 3PID) + +The API is: + +``` +GET /_synapse/admin/v1/threepid/$medium/users/$address +``` + +When a user matched the given address for the given medium, an HTTP code `200` with a response body like the following is returned: + +```json +{ + "user_id": "@hello:example.org" +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `medium` - Kind of third-party ID, either `email` or `msisdn`. +- `address` - Value of the third-party ID. + +The `address` may have characters that are not URL-safe, so it is advised to URL-encode those parameters. + +**Errors** + +Returns a `404` HTTP status code if no user was found, with a response body like this: + +```json +{ + "errcode":"M_NOT_FOUND", + "error":"User not found" +} +``` + +_Added in Synapse 1.72.0._ diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 885669f9c7..c62ea22116 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -81,6 +81,7 @@ from synapse.rest.admin.users import ( ShadowBanRestServlet, UserAdminServlet, UserByExternalId, + UserByThreePid, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -277,6 +278,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMessagesRestServlet(hs).register(http_server) RoomTimestampToEventRestServlet(hs).register(http_server) UserByExternalId(hs).register(http_server) + UserByThreePid(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 15ac2059aa..1951b8a9f2 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1224,3 +1224,28 @@ class UserByExternalId(RestServlet): raise NotFoundError("User not found") return HTTPStatus.OK, {"user_id": user_id} + + +class UserByThreePid(RestServlet): + """Find a user based on 3PID of a particular medium""" + + PATTERNS = admin_patterns("/threepid/(?P[^/]*)/users/(?P
[^/]*)") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, + request: SynapseRequest, + medium: str, + address: str, + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + user_id = await self._store.get_user_id_by_threepid(medium, address) + + if user_id is None: + raise NotFoundError("User not found") + + return HTTPStatus.OK, {"user_id": user_id} diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 63410ffdf1..e8c9457794 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -41,14 +41,12 @@ from tests.unittest import override_config class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, profile.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = "/_synapse/admin/v1/register" self.registration_handler = Mock() @@ -446,7 +444,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): class UsersListTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1108,7 +1105,6 @@ class UserDevicesTestCase(unittest.HomeserverTestCase): class DeactivateAccountTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1382,7 +1378,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): class UserRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -2803,7 +2798,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): class UserMembershipRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -2960,7 +2954,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): class PushersRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3089,7 +3082,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): class UserMediaRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3881,7 +3873,6 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ], ) class WhoisRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3961,7 +3952,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): class ShadowBanRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4042,7 +4032,6 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): class RateLimitTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4268,7 +4257,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): class AccountDataTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4358,7 +4346,6 @@ class AccountDataTestCase(unittest.HomeserverTestCase): class UsersByExternalIdTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4442,3 +4429,97 @@ class UsersByExternalIdTestCase(unittest.HomeserverTestCase): {"user_id": self.other_user}, channel.json_body, ) + + +class UsersByThreePidTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.get_success( + self.store.user_add_threepid( + self.other_user, "email", "user@email.com", 1, 1 + ) + ) + self.get_success( + self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1) + ) + + def test_no_auth(self) -> None: + """Try to look up a user without authentication.""" + url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" + + channel = self.make_request( + "GET", + url, + ) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_medium_does_not_exist(self) -> None: + """Tests that both a lookup for a medium that does not exist and a user that + doesn't exist with that third party ID returns a 404""" + # test for unknown medium + url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + # test for unknown user with a known medium + url = "/_synapse/admin/v1/threepid/email/users/unknown" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_success(self) -> None: + """Tests a successful medium + address lookup""" + # test for email medium with encoded value of user@email.com + url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) + + # test for msidn medium with encoded value of +1-12345678 + url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) -- cgit 1.5.1 From fb66fae84b165e7bd132bc7cbc5732485ceee827 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 14 Nov 2022 08:13:11 -0500 Subject: Clean-up events persistance code (#14411) By removing unused variables and making some arguments required which are always provided. --- changelog.d/14411.misc | 1 + synapse/storage/controllers/persist_events.py | 2 -- synapse/storage/databases/main/events.py | 9 +++------ 3 files changed, 4 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14411.misc (limited to 'synapse') diff --git a/changelog.d/14411.misc b/changelog.d/14411.misc new file mode 100644 index 0000000000..f5cca5c833 --- /dev/null +++ b/changelog.d/14411.misc @@ -0,0 +1 @@ +Clean-up event persistence code. diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 06e71a8053..48976dc570 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -716,8 +716,6 @@ class EventsPersistenceStorageController: ) if not is_still_joined: logger.info("Server no longer in room %s", room_id) - latest_event_ids = set() - current_state = {} delta.no_longer_in_room = True state_delta_for_room[room_id] = delta diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 00880bb37d..c4acff5be6 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -355,9 +355,9 @@ class PersistEventsStore: txn: LoggingTransaction, *, events_and_contexts: List[Tuple[EventBase, EventContext]], - inhibit_local_membership_updates: bool = False, - state_delta_for_room: Optional[Dict[str, DeltaState]] = None, - new_forward_extremities: Optional[Dict[str, Set[str]]] = None, + inhibit_local_membership_updates: bool, + state_delta_for_room: Dict[str, DeltaState], + new_forward_extremities: Dict[str, Set[str]], ) -> None: """Insert some number of room events into the necessary database tables. @@ -384,9 +384,6 @@ class PersistEventsStore: PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. """ - state_delta_for_room = state_delta_for_room or {} - new_forward_extremities = new_forward_extremities or {} - all_events_and_contexts = events_and_contexts min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering -- cgit 1.5.1 From 2cc592584ae9f225216b7663e9144ac6f565b757 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 14 Nov 2022 13:46:29 +0000 Subject: Remove unused type-ignores (#14433) * Remove unused type-ignores Oversights in #14427 and #14429. * Changelog --- changelog.d/14433.misc | 1 + scripts-dev/release.py | 4 +--- synapse/streams/events.py | 9 ++++++--- 3 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14433.misc (limited to 'synapse') diff --git a/changelog.d/14433.misc b/changelog.d/14433.misc new file mode 100644 index 0000000000..08a350b13b --- /dev/null +++ b/changelog.d/14433.misc @@ -0,0 +1 @@ +Fix mypy errors introduced by bumping the locked version of `attrs` and `gitpython`. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index c82c58c54b..bf47b6c713 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -219,9 +219,7 @@ def _prepare() -> None: update_branch(repo) # Create the new release branch - # Type ignore will no longer be needed after GitPython 3.1.28. - # See https://github.com/gitpython-developers/GitPython/pull/1419 - repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type] + repo.create_head(release_branch_name, commit=base_branch) # Special-case SyTest: we don't actually prepare any files so we may # as well push it now (and only when we create a release branch; diff --git a/synapse/streams/events.py b/synapse/streams/events.py index bcd840bd88..f331e1af16 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -45,9 +45,12 @@ class _EventSourcesInner: class EventSources: def __init__(self, hs: "HomeServer"): self.sources = _EventSourcesInner( - # mypy thinks attribute.type is `Optional`, but we know it's never `None` here since - # all the attributes of `_EventSourcesInner` are annotated. - *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc] + # mypy previously warned that attribute.type is `Optional`, but we know it's + # never `None` here since all the attributes of `_EventSourcesInner` are + # annotated. + # As of the stubs in attrs 22.1.0, `attr.fields()` now returns Any, + # so the call to `attribute.type` is not checked. + *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) ) self.store = hs.get_datastores().main -- cgit 1.5.1 From 36097e88c4da51fce6556a58c49bd675f4cf20ab Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 14 Nov 2022 17:31:36 +0000 Subject: Remove slaved id tracker (#14376) This matches the multi instance writer ID generator class which can both handle advancing the current token over replication and by calling the database. --- changelog.d/14376.misc | 1 + synapse/replication/slave/__init__.py | 13 ------ synapse/replication/slave/storage/__init__.py | 13 ------ .../slave/storage/_slaved_id_tracker.py | 50 ---------------------- synapse/storage/databases/main/account_data.py | 30 +++++-------- synapse/storage/databases/main/devices.py | 36 ++++++---------- synapse/storage/databases/main/events_worker.py | 35 ++++++--------- synapse/storage/databases/main/push_rule.py | 17 ++++---- synapse/storage/databases/main/pusher.py | 24 ++++------- synapse/storage/databases/main/receipts.py | 18 ++++---- synapse/storage/util/id_generators.py | 13 ++++-- 11 files changed, 74 insertions(+), 176 deletions(-) create mode 100644 changelog.d/14376.misc delete mode 100644 synapse/replication/slave/__init__.py delete mode 100644 synapse/replication/slave/storage/__init__.py delete mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py (limited to 'synapse') diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14376.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py deleted file mode 100644 index 8f3f953ed4..0000000000 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple - -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id - - -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - table: str, - column: str, - extra_tables: Optional[List[Tuple[str, str]]] = None, - step: int = 1, - ): - self.step = step - self._current = _load_current_id(db_conn, table, column, step) - if extra_tables: - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) - - def advance(self, instance_name: Optional[str], new_id: int) -> None: - self._current = (max if self.step > 0 else min)(self._current, new_id) - - def get_current_token(self) -> int: - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c38b8a9e5a..282687ebce 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index aa58c2adc3..3e5c16b15b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,6 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -86,28 +85,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - else: - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a79091952a..7a003ab88f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -213,26 +212,20 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8ae10f6127..12ad44dbb3 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,6 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -111,14 +110,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 4a01562d45..fee37b9ce4 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -59,20 +58,15 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) - else: - self._pushers_id_gen = SlavedIdTracker( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dc6989527e..64519587f8 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import EduTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2dfe4c0b66..1af0af1266 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, + is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) + self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") + # Advance should never be called on a writer instance, only over replication + if self._is_writer: + raise Exception("Replication is not supported by writer StreamIdGenerator") + + self._current = (max if self._step > 0 else min)(self._current, new_id) def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + if self._is_writer: + return self._current + with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step -- cgit 1.5.1 From 634359b083eae319d7f065114851590431b7c7fb Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 15 Nov 2022 10:43:17 +0000 Subject: Update docstring to clarify that `get_partial_state_events_batch` does not just give you completely arbitrary partial-state events. (#14417) --- changelog.d/14417.misc | 1 + synapse/storage/databases/main/events_worker.py | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14417.misc (limited to 'synapse') diff --git a/changelog.d/14417.misc b/changelog.d/14417.misc new file mode 100644 index 0000000000..7527fe97c2 --- /dev/null +++ b/changelog.d/14417.misc @@ -0,0 +1 @@ +Update docstring to clarify that `get_partial_state_events_batch` does not just give you completely arbitrary partial-state events. \ No newline at end of file diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7a003ab88f..296e50d661 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -2228,7 +2228,15 @@ class EventsWorkerStore(SQLBaseStore): return result is not None async def get_partial_state_events_batch(self, room_id: str) -> List[str]: - """Get a list of events in the given room that have partial state""" + """ + Get a list of events in the given room that: + - have partial state; and + - are ready to be resynced (because they have no prev_events that are + partial-stated) + + See the docstring on `_get_partial_state_events_batch_txn` for more + information. + """ return await self.db_pool.runInteraction( "get_partial_state_events_batch", self._get_partial_state_events_batch_txn, -- cgit 1.5.1 From b5ab2c428a1c5edd634ff084019811e5f6b963d8 Mon Sep 17 00:00:00 2001 From: Tuomas Ojamies Date: Tue, 15 Nov 2022 13:55:00 +0100 Subject: Support using SSL on worker endpoints. (#14128) * Fix missing SSL support in worker endpoints. * Add changelog * SSL for Replication endpoint * Remove unit test change * Refactor listener creation to reduce duplicated code * Fix the logger message * Update synapse/app/_base.py Co-authored-by: Patrick Cloke * Update synapse/app/_base.py Co-authored-by: Patrick Cloke * Update synapse/app/_base.py Co-authored-by: Patrick Cloke * Add config documentation for new TLS option Co-authored-by: Tuomas Ojamies Co-authored-by: Patrick Cloke Co-authored-by: Olivier Wilkinson (reivilibre) --- changelog.d/14128.misc | 1 + docs/usage/configuration/config_documentation.md | 20 +++++++++ synapse/app/_base.py | 53 +++++++++++++++++++++++- synapse/app/generic_worker.py | 28 ++++--------- synapse/app/homeserver.py | 34 ++------------- synapse/config/workers.py | 7 ++++ synapse/replication/http/_base.py | 10 ++++- 7 files changed, 100 insertions(+), 53 deletions(-) create mode 100644 changelog.d/14128.misc (limited to 'synapse') diff --git a/changelog.d/14128.misc b/changelog.d/14128.misc new file mode 100644 index 0000000000..29168ef955 --- /dev/null +++ b/changelog.d/14128.misc @@ -0,0 +1 @@ +Add TLS support for generic worker endpoints. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 9a6bd08d01..f5937dd902 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3893,6 +3893,26 @@ Example configuration: worker_replication_http_port: 9093 ``` --- +### `worker_replication_http_tls` + +Whether TLS should be used for talking to the HTTP replication port on the main +Synapse process. +The main Synapse process defines this with the `tls` option on its [listener](#listeners) that +has the `replication` resource enabled. + +**Please note:** by default, it is not safe to expose replication ports to the +public Internet, even with TLS enabled. +See [`worker_replication_secret`](#worker_replication_secret). + +Defaults to `false`. + +*Added in Synapse 1.72.0.* + +Example configuration: +```yaml +worker_replication_http_tls: true +``` +--- ### `worker_listeners` A worker can handle HTTP requests. To do so, a `worker_listeners` option diff --git a/synapse/app/_base.py b/synapse/app/_base.py index a683ebf4cb..8f5b1a20f5 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -47,6 +47,7 @@ from twisted.internet.tcp import Port from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.python.threadpool import ThreadPool +from twisted.web.resource import Resource import synapse.util.caches from synapse.api.constants import MAX_PDU_SIZE @@ -55,12 +56,13 @@ from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config import ConfigError from synapse.config._base import format_config_error from synapse.config.homeserver import HomeServerConfig -from synapse.config.server import ManholeConfig +from synapse.config.server import ListenerConfig, ManholeConfig from synapse.crypto import context_factory from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.handlers.auth import load_legacy_password_auth_providers +from synapse.http.site import SynapseSite from synapse.logging.context import PreserveLoggingContext from synapse.logging.opentracing import init_tracer from synapse.metrics import install_gc_manager, register_threadpool @@ -357,6 +359,55 @@ def listen_tcp( return r # type: ignore[return-value] +def listen_http( + listener_config: ListenerConfig, + root_resource: Resource, + version_string: str, + max_request_body_size: int, + context_factory: IOpenSSLContextFactory, + reactor: IReactorSSL = reactor, +) -> List[Port]: + port = listener_config.port + bind_addresses = listener_config.bind_addresses + tls = listener_config.tls + + assert listener_config.http_options is not None + + site_tag = listener_config.http_options.tag + if site_tag is None: + site_tag = str(port) + + site = SynapseSite( + "synapse.access.%s.%s" % ("https" if tls else "http", site_tag), + site_tag, + listener_config, + root_resource, + version_string, + max_request_body_size=max_request_body_size, + reactor=reactor, + ) + if tls: + # refresh_certificate should have been called before this. + assert context_factory is not None + ports = listen_ssl( + bind_addresses, + port, + site, + context_factory, + reactor=reactor, + ) + logger.info("Synapse now listening on TCP port %d (TLS)", port) + else: + ports = listen_tcp( + bind_addresses, + port, + site, + reactor=reactor, + ) + logger.info("Synapse now listening on TCP port %d", port) + return ports + + def listen_ssl( bind_addresses: Collection[str], port: int, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 51446b49cd..1d9aef45c2 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -44,7 +44,7 @@ from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource, OptionsResource from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseRequest, SynapseSite +from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -288,15 +288,9 @@ class GenericWorkerServer(HomeServer): DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore def _listen_http(self, listener_config: ListenerConfig) -> None: - port = listener_config.port - bind_addresses = listener_config.bind_addresses assert listener_config.http_options is not None - site_tag = listener_config.http_options.tag - if site_tag is None: - site_tag = str(port) - # We always include a health resource. resources: Dict[str, Resource] = {"/health": HealthResource()} @@ -395,23 +389,15 @@ class GenericWorkerServer(HomeServer): root_resource = create_resource_tree(resources, OptionsResource()) - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - max_request_body_size=max_request_body_size(self.config), - reactor=self.get_reactor(), - ), + _base.listen_http( + listener_config, + root_resource, + self.version_string, + max_request_body_size(self.config), + self.tls_server_context_factory, reactor=self.get_reactor(), ) - logger.info("Synapse worker now listening on port %d", port) - def start_listening(self) -> None: for listener in self.config.worker.worker_listeners: if listener.type == "http": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index de3f08876f..4f4fee4782 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -37,8 +37,7 @@ from synapse.api.urls import ( from synapse.app import _base from synapse.app._base import ( handle_startup_exception, - listen_ssl, - listen_tcp, + listen_http, max_request_body_size, redirect_stdio_to_logs, register_start, @@ -53,7 +52,6 @@ from synapse.http.server import ( RootOptionsRedirectResource, StaticResource, ) -from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -83,8 +81,6 @@ class SynapseHomeServer(HomeServer): self, config: HomeServerConfig, listener_config: ListenerConfig ) -> Iterable[Port]: port = listener_config.port - bind_addresses = listener_config.bind_addresses - tls = listener_config.tls # Must exist since this is an HTTP listener. assert listener_config.http_options is not None site_tag = listener_config.http_options.tag @@ -140,37 +136,15 @@ class SynapseHomeServer(HomeServer): else: root_resource = OptionsResource() - site = SynapseSite( - "synapse.access.%s.%s" % ("https" if tls else "http", site_tag), - site_tag, + ports = listen_http( listener_config, create_resource_tree(resources, root_resource), self.version_string, - max_request_body_size=max_request_body_size(self.config), + max_request_body_size(self.config), + self.tls_server_context_factory, reactor=self.get_reactor(), ) - if tls: - # refresh_certificate should have been called before this. - assert self.tls_server_context_factory is not None - ports = listen_ssl( - bind_addresses, - port, - site, - self.tls_server_context_factory, - reactor=self.get_reactor(), - ) - logger.info("Synapse now listening on TCP port %d (TLS)", port) - - else: - ports = listen_tcp( - bind_addresses, - port, - site, - reactor=self.get_reactor(), - ) - logger.info("Synapse now listening on TCP port %d", port) - return ports def _configure_named_resource( diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 0fb725dd8f..88b3168cbc 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -67,6 +67,7 @@ class InstanceLocationConfig: host: str port: int + tls: bool = False @attr.s @@ -149,6 +150,12 @@ class WorkerConfig(Config): # The port on the main synapse for HTTP replication endpoint self.worker_replication_http_port = config.get("worker_replication_http_port") + # The tls mode on the main synapse for HTTP replication endpoint. + # For backward compatibility this defaults to False. + self.worker_replication_http_tls = config.get( + "worker_replication_http_tls", False + ) + # The shared secret used for authentication when connecting to the main synapse. self.worker_replication_secret = config.get("worker_replication_secret", None) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index acb0bd18f7..5e661f8c73 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -184,8 +184,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): client = hs.get_simple_http_client() local_instance_name = hs.get_instance_name() + # The value of these option should match the replication listener settings master_host = hs.config.worker.worker_replication_host master_port = hs.config.worker.worker_replication_http_port + master_tls = hs.config.worker.worker_replication_http_tls instance_map = hs.config.worker.instance_map @@ -205,9 +207,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if instance_name == "master": host = master_host port = master_port + tls = master_tls elif instance_name in instance_map: host = instance_map[instance_name].host port = instance_map[instance_name].port + tls = instance_map[instance_name].tls else: raise Exception( "Instance %r not in 'instance_map' config" % (instance_name,) @@ -238,7 +242,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): "Unknown METHOD on %s replication endpoint" % (cls.NAME,) ) - uri = "http://%s:%s/_synapse/replication/%s/%s" % ( + # Here the protocol is hard coded to be http by default or https in case the replication + # port is set to have tls true. + scheme = "https" if tls else "http" + uri = "%s://%s:%s/_synapse/replication/%s/%s" % ( + scheme, host, port, cls.NAME, -- cgit 1.5.1 From 63cc56affa3872443fffcac655413a8d9ffabfe4 Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Tue, 15 Nov 2022 16:29:30 +0100 Subject: Send content rules with pattern_type to clients (#14356) --- changelog.d/14356.bugfix | 1 + synapse/push/clientformat.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14356.bugfix (limited to 'synapse') diff --git a/changelog.d/14356.bugfix b/changelog.d/14356.bugfix new file mode 100644 index 0000000000..288d58a540 --- /dev/null +++ b/changelog.d/14356.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.66 which would not send certain pushrules to clients. Contributed by Nico. diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 7095ae83f9..622a1e35c5 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -44,6 +44,12 @@ def format_push_rules_for_user( rulearray.append(template_rule) + pattern_type = template_rule.pop("pattern_type", None) + if pattern_type == "user_id": + template_rule["pattern"] = user.to_string() + elif pattern_type == "user_localpart": + template_rule["pattern"] = user.localpart + template_rule["enabled"] = enabled if "conditions" not in template_rule: @@ -93,10 +99,14 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]: if len(rule.conditions) != 1: return None thecond = rule.conditions[0] - if "pattern" not in thecond: - return None + templaterule = {"actions": rule.actions} - templaterule["pattern"] = thecond["pattern"] + if "pattern" in thecond: + templaterule["pattern"] = thecond["pattern"] + elif "pattern_type" in thecond: + templaterule["pattern_type"] = thecond["pattern_type"] + else: + return None else: # This should not be reached unless this function is not kept in sync # with PRIORITY_CLASS_INVERSE_MAP. -- cgit 1.5.1 From 258b5285b6b486526dffef9431c2ab063913f42b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 15 Nov 2022 16:36:43 +0000 Subject: Fix typechecking errors introduced in #14128 (#14455) * Fix typechecking errors introduced in #14128 * Changelog * Correct annotations so that context_factory works if you don't use TLS --- changelog.d/14455.misc | 1 + synapse/app/_base.py | 4 ++-- synapse/server.py | 5 +++-- 3 files changed, 6 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14455.misc (limited to 'synapse') diff --git a/changelog.d/14455.misc b/changelog.d/14455.misc new file mode 100644 index 0000000000..29168ef955 --- /dev/null +++ b/changelog.d/14455.misc @@ -0,0 +1 @@ +Add TLS support for generic worker endpoints. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 8f5b1a20f5..41d2732ef9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -364,8 +364,8 @@ def listen_http( root_resource: Resource, version_string: str, max_request_body_size: int, - context_factory: IOpenSSLContextFactory, - reactor: IReactorSSL = reactor, + context_factory: Optional[IOpenSSLContextFactory], + reactor: ISynapseReactor = reactor, ) -> List[Port]: port = listener_config.port bind_addresses = listener_config.bind_addresses diff --git a/synapse/server.py b/synapse/server.py index c4e025af22..f0a60d0056 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -221,8 +221,6 @@ class HomeServer(metaclass=abc.ABCMeta): # instantiated during setup() for future return by get_datastores() DATASTORE_CLASS = abc.abstractproperty() - tls_server_context_factory: Optional[IOpenSSLContextFactory] - def __init__( self, hostname: str, @@ -258,6 +256,9 @@ class HomeServer(metaclass=abc.ABCMeta): self._module_web_resources: Dict[str, Resource] = {} self._module_web_resources_consumed = False + # This attribute is set by the free function `refresh_certificate`. + self.tls_server_context_factory: Optional[IOpenSSLContextFactory] = None + def register_module_web_resource(self, path: str, resource: Resource) -> None: """Allows a module to register a web resource to be served at the given path. -- cgit 1.5.1 From 1eed795fc56d95df3968e37f3a4db92f24513e15 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 15 Nov 2022 17:35:19 +0000 Subject: Include heroes in partial join responses' state (#14442) * Pull out hero selection logic * Include heroes in partial join response's state * Changelog * Fixup trial test * Remove TODO --- changelog.d/14442.feature | 1 + synapse/federation/federation_server.py | 23 +++++++++++++++++---- synapse/handlers/sync.py | 20 +++---------------- synapse/storage/databases/main/roommember.py | 30 ++++++++++++++++++++++++++++ tests/federation/test_federation_server.py | 11 ++++++---- 5 files changed, 60 insertions(+), 25 deletions(-) create mode 100644 changelog.d/14442.feature (limited to 'synapse') diff --git a/changelog.d/14442.feature b/changelog.d/14442.feature new file mode 100644 index 0000000000..917e7edfb3 --- /dev/null +++ b/changelog.d/14442.feature @@ -0,0 +1 @@ +Faster joins: include heroes' membership events in the partial join response, for rooms without a name or canonical alias. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 59e351595b..bb20af6e91 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -74,6 +74,8 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.lock import Lock +from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary +from synapse.storage.roommember import MemberSummary from synapse.types import JsonDict, StateMap, get_domain_from_id from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results @@ -691,8 +693,9 @@ class FederationServer(FederationBase): state_event_ids: Collection[str] servers_in_room: Optional[Collection[str]] if caller_supports_partial_state: + summary = await self.store.get_room_summary(room_id) state_event_ids = _get_event_ids_for_partial_state_join( - event, prev_state_ids + event, prev_state_ids, summary ) servers_in_room = await self.state.get_hosts_in_room_at_events( room_id, event_ids=event.prev_event_ids() @@ -1495,6 +1498,7 @@ class FederationHandlerRegistry: def _get_event_ids_for_partial_state_join( join_event: EventBase, prev_state_ids: StateMap[str], + summary: Dict[str, MemberSummary], ) -> Collection[str]: """Calculate state to be retuned in a partial_state send_join @@ -1521,8 +1525,19 @@ def _get_event_ids_for_partial_state_join( if current_membership_event_id is not None: state_event_ids.add(current_membership_event_id) - # TODO: return a few more members: - # - those with invites - # - those that are kicked? / banned + name_id = prev_state_ids.get((EventTypes.Name, "")) + canonical_alias_id = prev_state_ids.get((EventTypes.CanonicalAlias, "")) + if not name_id and not canonical_alias_id: + # Also include the hero members of the room (for DM rooms without a title). + # To do this properly, we should select the correct subset of membership events + # from `prev_state_ids`. Instead, we are lazier and use the (cached) + # `get_room_summary` function, which is based on the current state of the room. + # This introduces races; we choose to ignore them because a) they should be rare + # and b) even if it's wrong, joining servers will get the full state eventually. + heroes = extract_heroes_from_room_summary(summary, join_event.state_key) + for hero in heroes: + membership_event_id = prev_state_ids.get((EventTypes.Member, hero)) + if membership_event_id: + state_event_ids.add(membership_event_id) return state_event_ids diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1db5d68021..259456b55d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -41,6 +41,7 @@ from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts +from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -805,18 +806,6 @@ class SyncHandler: if canonical_alias and canonical_alias.content.get("alias"): return summary - me = sync_config.user.to_string() - - joined_user_ids = [ - r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me - ] - invited_user_ids = [ - r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me - ] - gone_user_ids = [ - r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me - ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] - # FIXME: only build up a member_ids list for our heroes member_ids = {} for membership in ( @@ -828,11 +817,8 @@ class SyncHandler: for user_id, event_id in details.get(membership, empty_ms).members: member_ids[user_id] = event_id - # FIXME: order by stream ordering rather than as returned by SQL - if joined_user_ids or invited_user_ids: - summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5] - else: - summary["m.heroes"] = sorted(gone_user_ids)[0:5] + me = sync_config.user.to_string() + summary["m.heroes"] = extract_heroes_from_room_summary(details, me) if not sync_config.filter_collection.lazy_load_members(): return summary diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e56a13f21e..f02c1d7ea7 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1517,6 +1517,36 @@ class RoomMemberStore( await self.db_pool.runInteraction("forget_membership", f) +def extract_heroes_from_room_summary( + details: Mapping[str, MemberSummary], me: str +) -> List[str]: + """Determine the users that represent a room, from the perspective of the `me` user. + + The rules which say which users we select are specified in the "Room Summary" + section of + https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync + + Returns a list (possibly empty) of heroes' mxids. + """ + empty_ms = MemberSummary([], 0) + + joined_user_ids = [ + r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me + ] + invited_user_ids = [ + r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me + ] + gone_user_ids = [ + r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me + ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] + + # FIXME: order by stream ordering rather than as returned by SQL + if joined_user_ids or invited_user_ids: + return sorted(joined_user_ids + invited_user_ids)[0:5] + else: + return sorted(gone_user_ids)[0:5] + + @attr.s(slots=True, auto_attribs=True) class _JoinedHostsCache: """The cached data used by the `_get_joined_hosts_cache`.""" diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 3a6ef221ae..177e5b5afc 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -212,7 +212,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): self.assertEqual(r[("m.room.member", joining_user)].membership, "join") @override_config({"experimental_features": {"msc3706_enabled": True}}) - def test_send_join_partial_state(self): + def test_send_join_partial_state(self) -> None: """When MSC3706 support is enabled, /send_join should return partial state""" joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME join_result = self._make_join(joining_user) @@ -240,6 +240,9 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ("m.room.power_levels", ""), ("m.room.join_rules", ""), ("m.room.history_visibility", ""), + # Users included here because they're heroes. + ("m.room.member", "@kermit:test"), + ("m.room.member", "@fozzie:test"), ], ) @@ -249,9 +252,9 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ] self.assertCountEqual( returned_auth_chain_events, - [ - ("m.room.member", "@kermit:test"), - ], + # TODO: change the test so that we get at least one event in the auth chain + # here. + [], ) # the room should show that the new user is a member -- cgit 1.5.1 From 5cb6ad3b87caaadaedc3cc57e5513feb459b519d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Nov 2022 11:14:38 +0000 Subject: Fix HTML templates missing correct HTML tags (#14448) --- changelog.d/14448.bugfix | 1 + synapse/res/templates/invalid_token.html | 1 + synapse/res/templates/notif_mail.html | 2 ++ synapse/res/templates/password_reset.html | 1 + synapse/res/templates/password_reset_confirmation.html | 1 + synapse/res/templates/password_reset_failure.html | 1 + synapse/res/templates/password_reset_success.html | 1 + synapse/res/templates/recaptcha.html | 1 + synapse/res/templates/registration.html | 1 + synapse/res/templates/registration_failure.html | 1 + synapse/res/templates/registration_success.html | 1 + synapse/res/templates/registration_token.html | 1 + synapse/res/templates/sso_account_deactivated.html | 1 + synapse/res/templates/sso_auth_account_details.html | 1 + synapse/res/templates/sso_auth_bad_user.html | 1 + synapse/res/templates/sso_auth_confirm.html | 1 + synapse/res/templates/sso_auth_success.html | 1 + synapse/res/templates/sso_error.html | 1 + synapse/res/templates/sso_login_idp_picker.html | 1 + synapse/res/templates/sso_new_user_consent.html | 1 + synapse/res/templates/sso_redirect_confirm.html | 1 + synapse/res/templates/terms.html | 1 + 22 files changed, 23 insertions(+) create mode 100644 changelog.d/14448.bugfix (limited to 'synapse') diff --git a/changelog.d/14448.bugfix b/changelog.d/14448.bugfix new file mode 100644 index 0000000000..4bf1c183f6 --- /dev/null +++ b/changelog.d/14448.bugfix @@ -0,0 +1 @@ +Fix rendering of some HTML templates (including emails). Introduced in v1.71.0. diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html index d0b1dae669..b19e3023a1 100644 --- a/synapse/res/templates/invalid_token.html +++ b/synapse/res/templates/invalid_token.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Invalid renewal token.{% endblock %} {% block body %} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html index 939d40315f..2add9dd859 100644 --- a/synapse/res/templates/notif_mail.html +++ b/synapse/res/templates/notif_mail.html @@ -1,3 +1,5 @@ +{% extends "_base.html" %} + {% block title %}New activity in room{% endblock %} {% block header %} diff --git a/synapse/res/templates/password_reset.html b/synapse/res/templates/password_reset.html index de5a9ec68f..1f267946c8 100644 --- a/synapse/res/templates/password_reset.html +++ b/synapse/res/templates/password_reset.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Password reset{% endblock %} {% block body %} diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html index 0eac64b6a8..fabb9a6ed5 100644 --- a/synapse/res/templates/password_reset_confirmation.html +++ b/synapse/res/templates/password_reset_confirmation.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Password reset confirmation{% endblock %} {% block body %} diff --git a/synapse/res/templates/password_reset_failure.html b/synapse/res/templates/password_reset_failure.html index 977babdb40..9990e860f9 100644 --- a/synapse/res/templates/password_reset_failure.html +++ b/synapse/res/templates/password_reset_failure.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Password reset failure{% endblock %} {% block body %} diff --git a/synapse/res/templates/password_reset_success.html b/synapse/res/templates/password_reset_success.html index 0e99fad7ff..edada513ab 100644 --- a/synapse/res/templates/password_reset_success.html +++ b/synapse/res/templates/password_reset_success.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Password reset success{% endblock %} {% block body %} diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html index feaf3f6aed..8204928cdf 100644 --- a/synapse/res/templates/recaptcha.html +++ b/synapse/res/templates/recaptcha.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Authentication{% endblock %} {% block header %} diff --git a/synapse/res/templates/registration.html b/synapse/res/templates/registration.html index 189960a832..cdb815665e 100644 --- a/synapse/res/templates/registration.html +++ b/synapse/res/templates/registration.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Registration{% endblock %} {% block body %} diff --git a/synapse/res/templates/registration_failure.html b/synapse/res/templates/registration_failure.html index 3debe9301d..ae2a9cae2c 100644 --- a/synapse/res/templates/registration_failure.html +++ b/synapse/res/templates/registration_failure.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Registration failure{% endblock %} {% block body %} diff --git a/synapse/res/templates/registration_success.html b/synapse/res/templates/registration_success.html index e2dd020a9e..6d45111796 100644 --- a/synapse/res/templates/registration_success.html +++ b/synapse/res/templates/registration_success.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Your email has now been validated{% endblock %} {% block body %} diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html index 2ee5866ba5..ee4e5295e7 100644 --- a/synapse/res/templates/registration_token.html +++ b/synapse/res/templates/registration_token.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Authentication{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html index c634229840..b85d96cc74 100644 --- a/synapse/res/templates/sso_account_deactivated.html +++ b/synapse/res/templates/sso_account_deactivated.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}SSO account deactivated{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index b516333373..11636d7f5d 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Create your account{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html index 69fdcc9ef0..819d79a461 100644 --- a/synapse/res/templates/sso_auth_bad_user.html +++ b/synapse/res/templates/sso_auth_bad_user.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Authentication failed{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html index 2d106e0ae4..3927d6eda3 100644 --- a/synapse/res/templates/sso_auth_confirm.html +++ b/synapse/res/templates/sso_auth_confirm.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Confirm it's you{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html index 56150eaefe..afeffb7191 100644 --- a/synapse/res/templates/sso_auth_success.html +++ b/synapse/res/templates/sso_auth_success.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Authentication successful{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index e394a92623..6fa36c11c9 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Authentication failed{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index a2772ca9ef..58b0b3121c 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Choose identity provider{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html index 126887d26c..fda29928d1 100644 --- a/synapse/res/templates/sso_new_user_consent.html +++ b/synapse/res/templates/sso_new_user_consent.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Agree to terms and conditions{% endblock %} {% block header %} diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html index 887ee0d294..cc2e7b3a5b 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Continue to your account{% endblock %} {% block header %} diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html index 977c3d0bc7..ffabebdd8b 100644 --- a/synapse/res/templates/terms.html +++ b/synapse/res/templates/terms.html @@ -1,3 +1,4 @@ +{% extends "_base.html" %} {% block title %}Authentication{% endblock %} {% block header %} -- cgit 1.5.1 From 945a0928c793c0bd8573e179583d983187e5f392 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Nov 2022 12:09:33 +0000 Subject: Don't filter state in /context response (#14461) We don't filter state usually, so doing so here is a waste of time. This is not much of an issue for clients that enable lazy loading of members, since there will be fewer state events. --- changelog.d/14461.misc | 1 + synapse/handlers/room.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14461.misc (limited to 'synapse') diff --git a/changelog.d/14461.misc b/changelog.d/14461.misc new file mode 100644 index 0000000000..cdfa577a4c --- /dev/null +++ b/changelog.d/14461.misc @@ -0,0 +1 @@ +Improve performance of `/context` in large rooms. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 66a50bca6e..6dcfd86fdf 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1451,7 +1451,7 @@ class RoomContextHandler: events_before=events_before, event=event, events_after=events_after, - state=await filter_evts(state_events), + state=state_events, aggregations=aggregations, start=await token.copy_and_replace( StreamKeyType.ROOM, results.start -- cgit 1.5.1 From d63814fd736fed5d3d45ff3af5e6d3bfae50c439 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Nov 2022 13:50:07 +0000 Subject: Revert "Remove slaved id tracker (#14376)" (#14463) This reverts commit 36097e88c4da51fce6556a58c49bd675f4cf20ab. --- changelog.d/14376.misc | 1 - synapse/replication/slave/__init__.py | 13 ++++++ synapse/replication/slave/storage/__init__.py | 13 ++++++ .../slave/storage/_slaved_id_tracker.py | 50 ++++++++++++++++++++++ synapse/storage/databases/main/account_data.py | 30 ++++++++----- synapse/storage/databases/main/devices.py | 36 ++++++++++------ synapse/storage/databases/main/events_worker.py | 35 +++++++++------ synapse/storage/databases/main/push_rule.py | 17 ++++---- synapse/storage/databases/main/pusher.py | 24 +++++++---- synapse/storage/databases/main/receipts.py | 18 ++++---- synapse/storage/util/id_generators.py | 13 ++---- 11 files changed, 176 insertions(+), 74 deletions(-) delete mode 100644 changelog.d/14376.misc create mode 100644 synapse/replication/slave/__init__.py create mode 100644 synapse/replication/slave/storage/__init__.py create mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py (limited to 'synapse') diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc deleted file mode 100644 index 2ca326fea6..0000000000 --- a/changelog.d/14376.misc +++ /dev/null @@ -1 +0,0 @@ -Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py new file mode 100644 index 0000000000..f43a360a80 --- /dev/null +++ b/synapse/replication/slave/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..f43a360a80 --- /dev/null +++ b/synapse/replication/slave/storage/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py new file mode 100644 index 0000000000..8f3f953ed4 --- /dev/null +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -0,0 +1,50 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple + +from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id + + +class SlavedIdTracker(AbstractStreamIdTracker): + """Tracks the "current" stream ID of a stream with a single writer. + + See `AbstractStreamIdTracker` for more details. + + Note that this class does not work correctly when there are multiple + writers. + """ + + def __init__( + self, + db_conn: LoggingDatabaseConnection, + table: str, + column: str, + extra_tables: Optional[List[Tuple[str, str]]] = None, + step: int = 1, + ): + self.step = step + self._current = _load_current_id(db_conn, table, column, step) + if extra_tables: + for table, column in extra_tables: + self.advance(None, _load_current_id(db_conn, table, column)) + + def advance(self, instance_name: Optional[str], new_id: int) -> None: + self._current = (max if self.step > 0 else min)(self._current, new_id) + + def get_current_token(self) -> int: + return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + return self.get_current_token() diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 282687ebce..c38b8a9e5a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,6 +27,7 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -67,11 +68,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) if isinstance(database.engine, PostgresEngine): + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) + self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -93,13 +95,21 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - is_writer=self._instance_name in hs.config.worker.writers.account_data, - ) + if self._instance_name in hs.config.worker.writers.account_data: + self._can_write_to_account_data = True + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + else: + self._account_data_id_gen = SlavedIdTracker( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3e5c16b15b..aa58c2adc3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,6 +38,7 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -85,19 +86,28 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - is_writer=hs.config.worker.worker_app is None, - ) + if hs.config.worker.worker_app is None: + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + ) + else: + self._device_list_id_gen = SlavedIdTracker( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 296e50d661..467d20253d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,6 +59,7 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -212,20 +213,26 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - is_writer=hs.get_instance_name() in hs.config.worker.writers.events, - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - is_writer=hs.get_instance_name() in hs.config.worker.writers.events, - ) + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 12ad44dbb3..8ae10f6127 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,6 +30,7 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -110,14 +111,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "push_rules_stream", - "stream_id", - is_writer=hs.config.worker.worker_app is None, - ) + if hs.config.worker.worker_app is None: + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, "push_rules_stream", "stream_id" + ) + else: + self._push_rules_stream_id_gen = SlavedIdTracker( + db_conn, "push_rules_stream", "stream_id" + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index fee37b9ce4..4a01562d45 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,6 +27,7 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -58,15 +59,20 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - is_writer=hs.config.worker.worker_app is None, - ) + if hs.config.worker.worker_app is None: + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) + else: + self._pushers_id_gen = SlavedIdTracker( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 64519587f8..dc6989527e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,6 +27,7 @@ from typing import ( ) from synapse.api.constants import EduTypes +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -60,9 +61,6 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() - - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -89,12 +87,14 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - self._receipts_id_gen = StreamIdGenerator( - db_conn, - "receipts_linearized", - "stream_id", - is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, - ) + if hs.get_instance_name() in hs.config.worker.writers.receipts: + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + else: + self._receipts_id_gen = SlavedIdTracker( + db_conn, "receipts_linearized", "stream_id" + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 1af0af1266..2dfe4c0b66 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,13 +186,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, - is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) - self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -206,11 +204,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # Advance should never be called on a writer instance, only over replication - if self._is_writer: - raise Exception("Replication is not supported by writer StreamIdGenerator") - - self._current = (max if self._step > 0 else min)(self._current, new_id) + # `StreamIdGenerator` should only be used when there is a single writer, + # so replication should never happen. + raise Exception("Replication is not supported by StreamIdGenerator") def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -253,9 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: - if self._is_writer: - return self._current - with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step -- cgit 1.5.1 From 882277008c7b43ab26e3445ab94a38aa25ad0965 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:01:22 +0000 Subject: Fix background updates failing to add unique indexes on receipts (#14453) As part of the database migration to support threaded receipts, there is a possible window in between `73/08thread_receipts_non_null.sql.postgres` removing the original unique constraints on `receipts_linearized` and `receipts_graph` and the `reeipts_linearized_unique_index` and `receipts_graph_unique_index` background updates from `72/08thread_receipts.sql` completing where the unique constraints on `receipts_linearized` and `receipts_graph` are missing. Any emulated upserts on these tables must therefore be performed with a lock held, otherwise duplicate rows can end up in the tables when there are concurrent emulated upserts. Fix the missing lock. Note that emulated upserts no longer happen by default on sqlite, since the minimum supported version of sqlite supports native upserts by default now. Finally, clean up any duplicate receipts that may have crept in before trying to create the `receipts_graph_unique_index` and `receipts_linearized_unique_index` unique indexes. Signed-off-by: Sean Quah --- changelog.d/14453.bugfix | 1 + synapse/storage/databases/main/receipts.py | 171 ++++++++++++++++++--- tests/storage/databases/main/test_receipts.py | 209 ++++++++++++++++++++++++++ 3 files changed, 357 insertions(+), 24 deletions(-) create mode 100644 changelog.d/14453.bugfix create mode 100644 tests/storage/databases/main/test_receipts.py (limited to 'synapse') diff --git a/changelog.d/14453.bugfix b/changelog.d/14453.bugfix new file mode 100644 index 0000000000..4969e5450c --- /dev/null +++ b/changelog.d/14453.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 where the background updates to add non-thread unique indexes on receipts could fail when upgrading from 1.67.0 or earlier. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dc6989527e..fbf27497ec 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -113,24 +113,6 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) - self.db_pool.updates.register_background_index_update( - "receipts_linearized_unique_index", - index_name="receipts_linearized_unique_index", - table="receipts_linearized", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - - self.db_pool.updates.register_background_index_update( - "receipts_graph_unique_index", - index_name="receipts_graph_unique_index", - table="receipts_graph", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @@ -702,9 +684,6 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_linearized has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) return rx_ts @@ -862,14 +841,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_graph has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) class ReceiptsBackgroundUpdateStore(SQLBaseStore): POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering" + RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME = "receipts_linearized_unique_index" + RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME = "receipts_graph_unique_index" def __init__( self, @@ -883,6 +861,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, self._populate_receipt_event_stream_ordering, ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_linearized_unique_index, + ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_graph_unique_index, + ) async def _populate_receipt_event_stream_ordering( self, progress: JsonDict, batch_size: int @@ -938,6 +924,143 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): return batch_size + async def _create_receipts_index(self, index_name: str, table: str) -> None: + """Adds a unique index on `(room_id, receipt_type, user_id)` to the given + receipts table, for non-thread receipts.""" + + def _create_index(conn: LoggingDatabaseConnection) -> None: + conn.rollback() + + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. + if isinstance(self.database_engine, PostgresEngine): + conn.set_session(autocommit=True) + + try: + c = conn.cursor() + + # Now that the duplicates are gone, we can create the index. + concurrently = ( + "CONCURRENTLY" + if isinstance(self.database_engine, PostgresEngine) + else "" + ) + sql = f""" + CREATE UNIQUE INDEX {concurrently} {index_name} + ON {table}(room_id, receipt_type, user_id) + WHERE thread_id IS NULL + """ + c.execute(sql) + finally: + if isinstance(self.database_engine, PostgresEngine): + conn.set_session(autocommit=False) + + await self.db_pool.runWithConnection(_create_index) + + async def _background_receipts_linearized_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_linearized`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT MAX(stream_id), room_id, receipt_type, user_id + FROM receipts_linearized + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn)) + + # Then remove duplicate receipts, keeping the one with the highest + # `stream_id`. There should only be a single receipt with any given + # `stream_id`. + for max_stream_id, room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_linearized + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL AND + stream_id < ? + """ + txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self._create_receipts_index( + "receipts_linearized_unique_index", + "receipts_linearized", + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + + async def _background_receipts_graph_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_graph`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT room_id, receipt_type, user_id FROM receipts_graph + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[str, str, str]], list(txn)) + + # Then remove all duplicate receipts. + # We could be clever and try to keep the latest receipt out of every set of + # duplicates, but it's far simpler to remove them all. + for room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_graph + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL + """ + txn.execute(sql, (room_id, receipt_type, user_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self._create_receipts_index( + "receipts_graph_unique_index", + "receipts_graph", + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore): pass diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py new file mode 100644 index 0000000000..c4f12d81d7 --- /dev/null +++ b/tests/storage/databases/main/test_receipts.py @@ -0,0 +1,209 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Sequence, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.store = hs.get_datastores().main + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + self.other_room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def _test_background_receipts_unique_index( + self, + update_name: str, + index_name: str, + table: str, + receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]], + expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]], + ): + """Test that the background update to uniqueify non-thread receipts in + the given receipts table works properly. + + Args: + update_name: The name of the background update to test. + index_name: The name of the index that the background update creates. + table: The table of receipts that the background update fixes. + receipts: The test data containing duplicate receipts. + A list of receipt rows to insert, grouped by + `(room_id, receipt_type, user_id)`. + expected_unique_receipts: A dictionary of `(room_id, receipt_type, user_id)` + keys and expected receipt key-values after duplicate receipts have been + removed. + """ + # First, undo the background update. + def drop_receipts_unique_index(txn: LoggingTransaction) -> None: + txn.execute(f"DROP INDEX IF EXISTS {index_name}") + + self.get_success( + self.store.db_pool.runInteraction( + "drop_receipts_unique_index", + drop_receipts_unique_index, + ) + ) + + # Populate the receipts table, including duplicates. + for (room_id, receipt_type, user_id), rows in receipts.items(): + for row in rows: + self.get_success( + self.store.db_pool.simple_insert( + table, + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "thread_id": None, + "data": "{}", + **row, + }, + ) + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": update_name, + "progress_json": "{}", + }, + ) + ) + + self.store.db_pool.updates._all_done = False + + self.wait_for_background_updates() + + # Check that the remaining receipts match expectations. + for ( + room_id, + receipt_type, + user_id, + ), expected_row in expected_unique_receipts.items(): + # Include the receipt key in the returned columns, for more informative + # assertion messages. + columns = ["room_id", "receipt_type", "user_id"] + if expected_row is not None: + columns += expected_row.keys() + + rows = self.get_success( + self.store.db_pool.simple_select_list( + table=table, + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + # `simple_select_onecol` does not support NULL filters, + # so skip the filter on `thread_id`. + }, + retcols=columns, + desc="get_receipt", + ) + ) + + if expected_row is not None: + self.assertEqual( + len(rows), + 1, + f"Background update did not leave behind latest receipt in {table}", + ) + self.assertEqual( + rows[0], + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + **expected_row, + }, + ) + else: + self.assertEqual( + len(rows), + 0, + f"Background update did not remove all duplicate receipts from {table}", + ) + + def test_background_receipts_linearized_unique_index(self): + """Test that the background update to uniqueify non-thread receipts in + `receipts_linearized` works properly. + """ + self._test_background_receipts_unique_index( + "receipts_linearized_unique_index", + "receipts_linearized_unique_index", + "receipts_linearized", + receipts={ + (self.room_id, "m.read", self.user_id): [ + {"stream_id": 5, "event_id": "$some_event"}, + {"stream_id": 6, "event_id": "$some_event"}, + ], + (self.other_room_id, "m.read", self.user_id): [ + {"stream_id": 7, "event_id": "$some_event"} + ], + }, + expected_unique_receipts={ + (self.room_id, "m.read", self.user_id): {"stream_id": 6}, + (self.other_room_id, "m.read", self.user_id): {"stream_id": 7}, + }, + ) + + def test_background_receipts_graph_unique_index(self): + """Test that the background update to uniqueify non-thread receipts in + `receipts_graph` works properly. + """ + self._test_background_receipts_unique_index( + "receipts_graph_unique_index", + "receipts_graph_unique_index", + "receipts_graph", + receipts={ + (self.room_id, "m.read", self.user_id): [ + { + "event_ids": '["$some_event"]', + }, + { + "event_ids": '["$some_event"]', + }, + ], + (self.other_room_id, "m.read", self.user_id): [ + { + "event_ids": '["$some_event"]', + } + ], + }, + expected_unique_receipts={ + (self.room_id, "m.read", self.user_id): None, + (self.other_room_id, "m.read", self.user_id): { + "event_ids": '["$some_event"]' + }, + }, + ) -- cgit 1.5.1 From d8cc86eff484b6f570f55a5badb337080c6e4dcd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Nov 2022 10:25:24 -0500 Subject: Remove redundant types from comments. (#14412) Remove type hints from comments which have been added as Python type hints. This helps avoid drift between comments and reality, as well as removing redundant information. Also adds some missing type hints which were simple to fill in. --- changelog.d/14412.misc | 1 + synapse/api/errors.py | 2 +- synapse/config/logger.py | 5 ++- synapse/crypto/keyring.py | 9 +++-- synapse/events/__init__.py | 3 +- synapse/federation/transport/client.py | 11 +++--- synapse/federation/transport/server/_base.py | 4 +-- synapse/handlers/e2e_keys.py | 2 +- synapse/handlers/e2e_room_keys.py | 5 +-- synapse/handlers/federation.py | 4 +-- synapse/handlers/identity.py | 2 +- synapse/handlers/oidc.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/handlers/saml.py | 4 +-- synapse/http/additional_resource.py | 3 +- synapse/http/federation/matrix_federation_agent.py | 9 +++-- synapse/http/matrixfederationclient.py | 3 +- synapse/http/proxyagent.py | 20 +++++------ synapse/http/server.py | 2 +- synapse/http/site.py | 2 +- synapse/logging/context.py | 39 +++++++++++----------- synapse/logging/opentracing.py | 4 +-- synapse/module_api/__init__.py | 7 ++-- synapse/replication/http/_base.py | 2 +- synapse/rest/admin/users.py | 5 +-- synapse/rest/client/login.py | 2 +- synapse/rest/media/v1/media_repository.py | 4 +-- synapse/rest/media/v1/thumbnailer.py | 4 +-- synapse/server_notices/consent_server_notices.py | 5 ++- .../resource_limits_server_notices.py | 12 ++++--- synapse/storage/controllers/persist_events.py | 5 ++- synapse/storage/databases/main/devices.py | 2 +- synapse/storage/databases/main/e2e_room_keys.py | 8 ++--- synapse/storage/databases/main/end_to_end_keys.py | 7 ++-- synapse/storage/databases/main/events.py | 22 ++++++------ synapse/storage/databases/main/events_worker.py | 2 +- .../storage/databases/main/monthly_active_users.py | 8 ++--- synapse/storage/databases/main/registration.py | 6 ++-- synapse/storage/databases/main/room.py | 8 +++-- synapse/storage/databases/main/user_directory.py | 9 +++-- synapse/types.py | 4 +-- synapse/util/async_helpers.py | 3 +- synapse/util/caches/__init__.py | 2 +- synapse/util/caches/deferred_cache.py | 2 +- synapse/util/caches/dictionary_cache.py | 9 ++--- synapse/util/caches/expiringcache.py | 2 +- synapse/util/caches/lrucache.py | 8 ++--- synapse/util/ratelimitutils.py | 2 +- synapse/util/threepids.py | 2 +- synapse/util/wheel_timer.py | 4 +-- tests/http/__init__.py | 7 ++-- tests/replication/slave/storage/test_events.py | 7 ++-- tests/replication/test_multi_media_repo.py | 14 ++++---- .../test_resource_limits_server_notices.py | 10 +++--- tests/unittest.py | 18 +++++----- 55 files changed, 174 insertions(+), 176 deletions(-) create mode 100644 changelog.d/14412.misc (limited to 'synapse') diff --git a/changelog.d/14412.misc b/changelog.d/14412.misc new file mode 100644 index 0000000000..4da061d461 --- /dev/null +++ b/changelog.d/14412.misc @@ -0,0 +1 @@ +Remove duplicated type information from type hints. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 400dd12aba..e2cfcea0f2 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -713,7 +713,7 @@ class HttpResponseException(CodeMessageException): set to the reason code from the HTTP response. Returns: - SynapseError: + The error converted to a SynapseError. """ # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 94d1150415..5468b963a2 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -317,10 +317,9 @@ def setup_logging( Set up the logging subsystem. Args: - config (LoggingConfig | synapse.config.worker.WorkerConfig): - configuration data + config: configuration data - use_worker_options (bool): True to use the 'worker_log_config' option + use_worker_options: True to use the 'worker_log_config' option instead of 'log_config'. logBeginner: The Twisted logBeginner to use. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index c88afb2986..dd9b8089ec 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -213,7 +213,7 @@ class Keyring: def verify_json_objects_for_server( self, server_and_json: Iterable[Tuple[str, dict, int]] - ) -> List[defer.Deferred]: + ) -> List["defer.Deferred[None]"]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. @@ -226,10 +226,9 @@ class Keyring: valid. Returns: - List: for each input triplet, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. + For each input triplet, a deferred indicating success or failure to + verify each json object's signature for the given server_name. The + deferreds run their callbacks in the sentinel logcontext. """ return [ run_in_background( diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 030c3ca408..8aca9a3ab9 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -597,8 +597,7 @@ def _event_type_from_format_version( format_version: The event format version Returns: - type: A type that can be initialized as per the initializer of - `FrozenEvent` + A type that can be initialized as per the initializer of `FrozenEvent` """ if format_version == EventFormatVersions.ROOM_V1_V2: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cd39d4d111..a3cfc701cd 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -280,12 +280,11 @@ class TransportLayerClient: Note that this does not append any events to any graphs. Args: - destination (str): address of remote homeserver - room_id (str): room to join/leave - user_id (str): user to be joined/left - membership (str): one of join/leave - params (dict[str, str|Iterable[str]]): Query parameters to include in the - request. + destination: address of remote homeserver + room_id: room to join/leave + user_id: user to be joined/left + membership: one of join/leave + params: Query parameters to include in the request. Returns: Succeeds when we get a 2xx HTTP response. The result diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 1db8009d6c..cdaf0d5de7 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -224,10 +224,10 @@ class BaseFederationServlet: With arguments: - origin (unicode|None): The authenticated server_name of the calling server, + origin (str|None): The authenticated server_name of the calling server, unless REQUIRE_AUTH is set to False and authentication failed. - content (unicode|None): decoded json body of the request. None if the + content (str|None): decoded json body of the request. None if the request was a GET. query (dict[bytes, list[bytes]]): Query params from the request. url-decoded diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a9912c467d..bf1221f523 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -870,7 +870,7 @@ class E2eKeysHandler: - signatures of the user's master key by the user's devices. Args: - user_id (string): the user uploading the keys + user_id: the user uploading the keys signatures (dict[string, dict]): map of devices to signed keys Returns: diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 28dc08c22a..83f53ceb88 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -377,8 +377,9 @@ class E2eRoomKeysHandler: """Deletes a given version of the user's e2e_room_keys backup Args: - user_id(str): the user whose current backup version we're deleting - version(str): the version id of the backup being deleted + user_id: the user whose current backup version we're deleting + version: Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. Raises: NotFoundError: if this backup version doesn't exist """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5fc3b8bc8c..188f0956ef 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1596,8 +1596,8 @@ class FederationHandler: Fetch the complexity of a remote room over federation. Args: - remote_room_hosts (list[str]): The remote servers to ask. - room_id (str): The room ID to ask about. + remote_room_hosts: The remote servers to ask. + room_id: The room ID to ask about. Returns: Dict contains the complexity diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 93d09e9939..848e46eb9b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -711,7 +711,7 @@ class IdentityHandler: inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. - id_access_token (str): The access token to authenticate to the identity + id_access_token: The access token to authenticate to the identity server with Returns: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 867973dcca..41c675f408 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -787,7 +787,7 @@ class OidcProvider: Must include an ``access_token`` field. Returns: - UserInfo: an object representing the user. + an object representing the user. """ logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0066d63987..b7bc787636 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -201,7 +201,7 @@ class BasePresenceHandler(abc.ABC): """Get the current presence state for multiple users. Returns: - dict: `user_id` -> `UserPresenceState` + A mapping of `user_id` -> `UserPresenceState` """ states = {} missing = [] diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 9602f0d0bb..874860d461 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -441,7 +441,7 @@ class DefaultSamlMappingProvider: client_redirect_url: where the client wants to redirect to Returns: - dict: A dict containing new user attributes. Possible keys: + A dict containing new user attributes. Possible keys: * mxid_localpart (str): Required. The localpart of the user's mxid * displayname (str): The displayname of the user * emails (list[str]): Any emails for the user @@ -483,7 +483,7 @@ class DefaultSamlMappingProvider: Args: config: A dictionary containing configuration options for this provider Returns: - SamlConfig: A custom config object for this module + A custom config object for this module """ # Parse config options and use defaults where necessary mxid_source_attribute = config.get("mxid_source_attribute", "uid") diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 6a9f6635d2..8729630581 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -45,8 +45,7 @@ class AdditionalResource(DirectServeJsonResource): Args: hs: homeserver - handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): - function to be called to handle the request. + handler: function to be called to handle the request. """ super().__init__() self._handler = handler diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 2f0177f1e2..0359231e7d 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -155,11 +155,10 @@ class MatrixFederationAgent: a file for a file upload). Or None if the request is to have no body. Returns: - Deferred[twisted.web.iweb.IResponse]: - fires when the header of the response has been received (regardless of the - response status code). Fails if there is any problem which prevents that - response from being received (including problems that prevent the request - from being sent). + A deferred which fires when the header of the response has been received + (regardless of the response status code). Fails if there is any problem + which prevents that response from being received (including problems that + prevent the request from being sent). """ # We use urlparse as that will set `port` to None if there is no # explicit port. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3c35b1d2c7..b92f1d3d1a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -951,8 +951,7 @@ class MatrixFederationHttpClient: args: query params Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The - result will be the decoded JSON body. + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: HttpResponseException: If we get an HTTP response code >= 300 diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 1f8227896f..18899bc6d1 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -34,7 +34,7 @@ from twisted.web.client import ( ) from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS +from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials @@ -134,7 +134,7 @@ class ProxyAgent(_AgentBase): uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> defer.Deferred: + ) -> "defer.Deferred[IResponse]": """ Issue a request to the server indicated by the given uri. @@ -157,17 +157,17 @@ class ProxyAgent(_AgentBase): a file upload). Or, None if the request is to have no body. Returns: - Deferred[IResponse]: completes when the header of the response has - been received (regardless of the response status code). + A deferred which completes when the header of the response has + been received (regardless of the response status code). - Can fail with: - SchemeNotSupported: if the uri is not http or https + Can fail with: + SchemeNotSupported: if the uri is not http or https - twisted.internet.error.TimeoutError if the server we are connecting - to (proxy or destination) does not accept a connection before - connectTimeout. + twisted.internet.error.TimeoutError if the server we are connecting + to (proxy or destination) does not accept a connection before + connectTimeout. - ... other things too. + ... other things too. """ uri = uri.strip() if not _VALID_URI.match(uri): diff --git a/synapse/http/server.py b/synapse/http/server.py index b26e34bceb..051a1899a0 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -267,7 +267,7 @@ class HttpServer(Protocol): request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. This should return either tuple of (code, response), or None. - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ diff --git a/synapse/http/site.py b/synapse/http/site.py index 3dbd541fed..6a1dbf7f33 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -400,7 +400,7 @@ class SynapseRequest(Request): be sure to call finished_processing. Args: - servlet_name (str): the name of the servlet which will be + servlet_name: the name of the servlet which will be processing this request. This is used in the metrics. It is possible to update this afterwards by updating diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 6a08ffed64..f62bea968f 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -117,8 +117,7 @@ class ContextResourceUsage: """Create a new ContextResourceUsage Args: - copy_from (ContextResourceUsage|None): if not None, an object to - copy stats from + copy_from: if not None, an object to copy stats from """ if copy_from is None: self.reset() @@ -162,7 +161,7 @@ class ContextResourceUsage: """Add another ContextResourceUsage's stats to this one's. Args: - other (ContextResourceUsage): the other resource usage object + other: the other resource usage object """ self.ru_utime += other.ru_utime self.ru_stime += other.ru_stime @@ -342,7 +341,7 @@ class LoggingContext: called directly. Returns: - LoggingContext: the current logging context + The current logging context """ warnings.warn( "synapse.logging.context.LoggingContext.current_context() is deprecated " @@ -362,7 +361,8 @@ class LoggingContext: called directly. Args: - context(LoggingContext): The context to activate. + context: The context to activate. + Returns: The context that was previously active """ @@ -474,8 +474,7 @@ class LoggingContext: """Get resources used by this logcontext so far. Returns: - ContextResourceUsage: a *copy* of the object tracking resource - usage so far + A *copy* of the object tracking resource usage so far """ # we always return a copy, for consistency res = self._resource_usage.copy() @@ -663,7 +662,8 @@ def current_context() -> LoggingContextOrSentinel: def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: """Set the current logging context in thread local storage Args: - context(LoggingContext): The context to activate. + context: The context to activate. + Returns: The context that was previously active """ @@ -700,7 +700,7 @@ def nested_logging_context(suffix: str) -> LoggingContext: suffix: suffix to add to the parent context's 'name'. Returns: - LoggingContext: new logging context. + A new logging context. """ curr_context = current_context() if not curr_context: @@ -898,20 +898,19 @@ def defer_to_thread( on it. Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked, and whose threadpool we should use for the - function. + reactor: The reactor in whose main thread the Deferred will be invoked, + and whose threadpool we should use for the function. Normally this will be hs.get_reactor(). - f (callable): The function to call. + f: The function to call. args: positional arguments to pass to f. kwargs: keyword arguments to pass to f. Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an + A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) @@ -939,20 +938,20 @@ def defer_to_threadpool( on it. Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked. Normally this will be hs.get_reactor(). + reactor: The reactor in whose main thread the Deferred will be invoked. + Normally this will be hs.get_reactor(). - threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for - running `f`. Normally this will be hs.get_reactor().getThreadPool(). + threadpool: The threadpool to use for running `f`. Normally this will be + hs.get_reactor().getThreadPool(). - f (callable): The function to call. + f: The function to call. args: positional arguments to pass to f. kwargs: keyword arguments to pass to f. Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an + A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ curr_context = current_context() diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8ce5a2a338..b69060854f 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -721,7 +721,7 @@ def inject_header_dict( destination: address of entity receiving the span context. Must be given unless check_destination is False. The context will only be injected if the destination matches the opentracing whitelist - check_destination (bool): If false, destination will be ignored and the context + check_destination: If false, destination will be ignored and the context will always be injected. Note: @@ -780,7 +780,7 @@ def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str destination: the name of the remote server. Returns: - dict: the active span's context if opentracing is enabled, otherwise empty. + the active span's context if opentracing is enabled, otherwise empty. """ if destination and not whitelisted_homeserver(destination): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 30e689d00d..1adc1fd64f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -787,7 +787,7 @@ class ModuleApi: Added in Synapse v0.25.0. Args: - access_token(str): access token + access_token: access token Returns: twisted.internet.defer.Deferred - resolves once the access token @@ -832,7 +832,7 @@ class ModuleApi: **kwargs: named args to be passed to func Returns: - Deferred[object]: result of func + Result of func """ # type-ignore: See https://github.com/python/mypy/issues/8862 return defer.ensureDeferred( @@ -924,8 +924,7 @@ class ModuleApi: to represent 'any') of the room state to acquire. Returns: - twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: - The filtered state events in the room. + The filtered state events in the room. """ state_ids = yield defer.ensureDeferred( self._storage_controllers.state.get_current_state_ids( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 5e661f8c73..3f4d3fc51a 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -153,7 +153,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): argument list. Returns: - dict: If POST/PUT request then dictionary must be JSON serialisable, + If POST/PUT request then dictionary must be JSON serialisable, otherwise must be appropriate for adding as query args. """ return {} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 1951b8a9f2..6e0c44be2a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -903,8 +903,9 @@ class PushersRestServlet(RestServlet): @user:server/pushers Returns: - pushers: Dictionary containing pushers information. - total: Number of pushers in dictionary `pushers`. + A dictionary with keys: + pushers: Dictionary containing pushers information. + total: Number of pushers in dictionary `pushers`. """ PATTERNS = admin_patterns("/users/(?P[^/]*)/pushers$") diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 05706b598c..8adced41e5 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -350,7 +350,7 @@ class LoginRestServlet(RestServlet): auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: - result: Dictionary of account information after successful login. + Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 328c0c5477..40b0d39eb2 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -344,8 +344,8 @@ class MediaRepository: download from remote server. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 9b93b9b4f6..a48a4de92a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -138,7 +138,7 @@ class Thumbnailer: """Rescales the image to the given dimensions. Returns: - BytesIO: the bytes of the encoded image ready to be written to disk + The bytes of the encoded image ready to be written to disk """ with self._resize(width, height) as scaled: return self._encode_image(scaled, output_type) @@ -155,7 +155,7 @@ class Thumbnailer: max_height: The largest possible height. Returns: - BytesIO: the bytes of the encoded image ready to be written to disk + The bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_width = width diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 698ca742ed..94025ba41f 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -113,9 +113,8 @@ def copy_with_str_subst(x: Any, substitutions: Any) -> Any: """Deep-copy a structure, carrying out string substitutions on any strings Args: - x (object): structure to be copied - substitutions (object): substitutions to be made - passed into the - string '%' operator + x: structure to be copied + substitutions: substitutions to be made - passed into the string '%' operator Returns: copy of x diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 3134cd2d3d..a31a2c99a7 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -170,11 +170,13 @@ class ResourceLimitsServerNotices: room_id: The room id of the server notices room Returns: - bool: Is the room currently blocked - list: The list of pinned event IDs that are unrelated to limit blocking - This list can be used as a convenience in the case where the block - is to be lifted and the remaining pinned event references need to be - preserved + Tuple of: + Is the room currently blocked + + The list of pinned event IDs that are unrelated to limit blocking + This list can be used as a convenience in the case where the block + is to be lifted and the remaining pinned event references need to be + preserved """ currently_blocked = False pinned_state_event = None diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 48976dc570..33ffef521b 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -204,9 +204,8 @@ class _EventPeristenceQueue(Generic[_PersistResult]): process to to so, calling the per_item_callback for each item. Args: - room_id (str): - task (_EventPersistQueueTask): A _PersistEventsTask or - _UpdateCurrentStateTask to process. + room_id: + task: A _PersistEventsTask or _UpdateCurrentStateTask to process. Returns: the result returned by the `_per_item_callback` passed to diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index aa58c2adc3..e114c733d1 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -535,7 +535,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): limit: Maximum number of device updates to return Returns: - List: List of device update tuples: + List of device update tuples: - user_id - device_id - stream_id diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index af59be6b48..6240f9a75e 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -391,10 +391,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): Returns: A dict giving the info metadata for this backup version, with fields including: - version(str) - algorithm(str) - auth_data(object): opaque dict supplied by the client - etag(int): tag of the keys in the backup + version (str) + algorithm (str) + auth_data (object): opaque dict supplied by the client + etag (int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2a4f58ed92..cf33e73e2b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -412,10 +412,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """Retrieve a number of one-time keys for a user Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - key_ids(list[str]): list of key ids (excluding algorithm) to - retrieve + user_id: id of user to get keys for + device_id: id of device to get keys for + key_ids: list of key ids (excluding algorithm) to retrieve Returns: A map from (algorithm, key_id) to json string for key diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index c4acff5be6..d68f127f9b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1279,9 +1279,10 @@ class PersistEventsStore: Pick the earliest non-outlier if there is one, else the earliest one. Args: - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts: + Returns: - list[(EventBase, EventContext)]: filtered list + filtered list """ new_events_and_contexts: OrderedDict[ str, Tuple[EventBase, EventContext] @@ -1307,9 +1308,8 @@ class PersistEventsStore: """Update min_depth for each room Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: @@ -1580,13 +1580,11 @@ class PersistEventsStore: """Update all the miscellaneous tables for new events Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. + txn: db connection + events_and_contexts: events we are persisting + all_events_and_contexts: all events that we were going to persist. + This includes events we've already persisted, etc, that wouldn't + appear in events_and_context. inhibit_local_membership_updates: Stop the local_current_membership from being updated by these events. This should be set to True for backfilled events because backfilled events in the past do diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 467d20253d..8a104f7e93 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1589,7 +1589,7 @@ class EventsWorkerStore(SQLBaseStore): room_id: The room ID to query. Returns: - dict[str:float] of complexity version to complexity. + Map of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index efd136a864..db9a24db5e 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -217,7 +217,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: - reserved_users (tuple): reserved users to preserve + reserved_users: reserved users to preserve """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) @@ -370,8 +370,8 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): should not appear in the MAU stats). Args: - txn (cursor): - user_id (str): user to add/update + txn: + user_id: user to add/update """ assert ( self._update_on_this_worker @@ -401,7 +401,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): add the user to the monthly active tables Args: - user_id(str): the user_id to query + user_id: the user_id to query """ assert ( self._update_on_this_worker diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 5167089e03..31f0f2bd3d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -953,7 +953,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Returns user id from threepid Args: - txn (cursor): + txn: medium: threepid medium e.g. email address: threepid address e.g. me@example.com @@ -1283,8 +1283,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Sets an expiration date to the account with the given user ID. Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be + user_id: User ID to set an expiration date for. + use_delta: If set to False, the expiration date for the user will be now + validity period. If set to True, this expiration date will be a random value in the [now + period - d ; now + period] range, d being a delta equal to 10% of the validity period. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d97f8f60e..4fbaefad73 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2057,7 +2057,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): Args: report_id: ID of reported event in database Returns: - event_report: json list of information from event report + JSON dict of information from an event report or None if the + report does not exist. """ def _get_event_report_txn( @@ -2130,8 +2131,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None Returns: - event_reports: json list of event reports - count: total number of event reports matching the filter criteria + Tuple of: + json list of event reports + total number of event reports matching the filter criteria """ def _get_event_reports_paginate_txn( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ddb25b5cea..698d6f7515 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -185,9 +185,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): - who should be in the user_directory. Args: - progress (dict) - batch_size (int): Maximum number of state events to process - per cycle. + progress + batch_size: Maximum number of state events to process per cycle. Returns: number of events processed. @@ -708,10 +707,10 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns the rooms that a user is in. Args: - user_id(str): Must be a local user + user_id: Must be a local user Returns: - list: user_id + List of room IDs """ rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", diff --git a/synapse/types.py b/synapse/types.py index 773f0438d5..f2d436ddc3 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -143,8 +143,8 @@ class Requester: Requester. Args: - store (DataStore): Used to convert AS ID to AS object - input (dict): A dict produced by `serialize` + store: Used to convert AS ID to AS object + input: A dict produced by `serialize` Returns: Requester diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7f1d41eb3c..d24c4f68c4 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -217,7 +217,8 @@ async def concurrently_execute( limit: Maximum number of conccurent executions. Returns: - Deferred: Resolved when all function invocations have finished. + None, when all function invocations have finished. The return values + from those functions are discarded. """ it = iter(args) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f7c3a6794e..9387632d0d 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -197,7 +197,7 @@ def register_cache( resize_callback: A function which can be called to resize the cache. Returns: - CacheMetric: an object which provides inc_{hits,misses,evictions} methods + an object which provides inc_{hits,misses,evictions} methods """ if resizable: if not resize_callback: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index bcb1cba362..bf7bd351e0 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -153,7 +153,7 @@ class DeferredCache(Generic[KT, VT]): Args: key: callback: Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics + update_metrics: whether to update the cache hit rate metrics Returns: A Deferred which completes with the result. Note that this may later fail diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index fa91479c97..5eaf70c7ab 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -169,10 +169,11 @@ class DictionaryCache(Generic[KT, DKT, DV]): if it is in the cache. Returns: - DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry` - will contain include the keys that are in the cache. If None then - will either return the full dict if in the cache, or the empty - dict (with `full` set to False) if it isn't. + If `dict_keys` is not None then `DictionaryEntry` will contain include + the keys that are in the cache. + + If None then will either return the full dict if in the cache, or the + empty dict (with `full` set to False) if it isn't. """ if dict_keys is None: # The caller wants the full set of dictionary keys for this cache key diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index c6a5d0dfc0..01ad02af67 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -207,7 +207,7 @@ class ExpiringCache(Generic[KT, VT]): items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ new_size = int(self._original_max_size * factor) if new_size != self._max_size: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index aa93109d13..dcf0eac3bf 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -389,11 +389,11 @@ class LruCache(Generic[KT, VT]): cache_name: The name of this cache, for the prometheus metrics. If unset, no metrics will be reported on this cache. - cache_type (type): + cache_type: type of underlying cache to be used. Typically one of dict or TreeCache. - size_callback (func(V) -> int | None): + size_callback: metrics_collection_callback: metrics collection callback. This is called early in the metrics @@ -403,7 +403,7 @@ class LruCache(Generic[KT, VT]): Ignored if cache_name is None. - apply_cache_factor_from_config (bool): If true, `max_size` will be + apply_cache_factor_from_config: If true, `max_size` will be multiplied by a cache factor derived from the homeserver config clock: @@ -796,7 +796,7 @@ class LruCache(Generic[KT, VT]): items from the cache. Returns: - bool: Whether the cache changed size or not. + Whether the cache changed size or not. """ if not self.apply_cache_factor_from_config: return False diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 9f64fed0d7..2aceb1a47f 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -183,7 +183,7 @@ class FederationRateLimiter: # Handle request ... Args: - host (str): Origin of incoming request. + host: Origin of incoming request. Returns: context manager which returns a deferred. diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 1e9c2faa64..54bc7589fd 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -48,7 +48,7 @@ async def check_3pid_allowed( registration: whether we want to bind the 3PID as part of registering a new user. Returns: - bool: whether the 3PID medium/address is allowed to be added to this HS + whether the 3PID medium/address is allowed to be added to this HS """ if not await hs.get_password_auth_provider().is_3pid_allowed( medium, address, registration diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 177e198e7e..b1ec7f4bd8 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -90,10 +90,10 @@ class WheelTimer(Generic[T]): """Fetch any objects that have timed out Args: - now (ms): Current time in msec + now: Current time in msec Returns: - list: List of objects that have timed out + List of objects that have timed out """ now_key = int(now / self.bucket_size) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index e74f7f5b48..093537adef 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import os.path import subprocess +from typing import List from zope.interface import implementer @@ -70,14 +71,14 @@ subjectAltName = %(sanentries)s """ -def create_test_cert_file(sanlist): +def create_test_cert_file(sanlist: List[bytes]) -> str: """build an x509 certificate file Args: - sanlist: list[bytes]: a list of subjectAltName values for the cert + sanlist: a list of subjectAltName values for the cert Returns: - str: the path to the file + The path to the file """ global cert_file_count csr_filename = "server.csr" diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 96f3880923..dce71f7334 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -143,6 +143,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): self.persist(type="m.room.create", key="", creator=USER_ID) self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + assert event.internal_metadata.stream_ordering is not None self.replicate() @@ -230,6 +231,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): j2 = self.persist( type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) + assert j2.internal_metadata.stream_ordering is not None self.replicate() expected_pos = PersistedEventPosition( @@ -287,6 +289,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): ) ) self.replicate() + assert j2.internal_metadata.stream_ordering is not None event_source = RoomEventSource(self.hs) event_source.store = self.slaved_store @@ -336,10 +339,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): event_id = 0 - def persist(self, backfill=False, **kwargs): + def persist(self, backfill=False, **kwargs) -> FrozenEvent: """ Returns: - synapse.events.FrozenEvent: The event that was persisted. + The event that was persisted. """ event, context = self.build_event(**kwargs) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 13aa5eb51a..96cdf2c45b 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -15,8 +15,9 @@ import logging import os from typing import Optional, Tuple +from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from twisted.web.server import Request @@ -102,7 +103,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): ) # fish the test server back out of the server-side TLS protocol. - http_server = server_tls_protocol.wrappedProtocol + http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment] # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) @@ -238,16 +239,15 @@ def get_connection_factory(): return test_server_connection_factory -def _build_test_server(connection_creator): +def _build_test_server( + connection_creator: IOpenSSLServerConnectionCreator, +) -> TLSMemoryBIOProtocol: """Construct a test server This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + connection_creator: thing to build SSL connections Returns: TLSMemoryBIOProtocol diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index bf403045e9..7cbc40736c 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -350,14 +351,15 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.assertTrue(notice_in_room, "No server notice in room") - def _trigger_notice_and_join(self): + def _trigger_notice_and_join(self) -> Tuple[str, str, str]: """Creates enough active users to hit the MAU limit and trigger a system notice about it, then joins the system notices room with one of the users created. Returns: - user_id (str): The ID of the user that joined the room. - tok (str): The access token of the user that joined the room. - room_id (str): The ID of the room that's been joined. + A tuple of: + user_id: The ID of the user that joined the room. + tok: The access token of the user that joined the room. + room_id: The ID of the room that's been joined. """ user_id = None tok = None diff --git a/tests/unittest.py b/tests/unittest.py index 5116be338e..a120c2976c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -360,13 +360,13 @@ class HomeserverTestCase(TestCase): store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock): """ Make and return a homeserver. Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. + clock: The Clock, associated with the reactor. Returns: A homeserver suitable for testing. @@ -426,9 +426,8 @@ class HomeserverTestCase(TestCase): Args: reactor: A Twisted Reactor, or something that pretends to be one. - clock (synapse.util.Clock): The Clock, associated with the reactor. - homeserver (synapse.server.HomeServer): The HomeServer to test - against. + clock: The Clock, associated with the reactor. + homeserver: The HomeServer to test against. Function to optionally be overridden in subclasses. """ @@ -452,11 +451,10 @@ class HomeserverTestCase(TestCase): given content. Args: - method (bytes/unicode): The HTTP request method ("verb"). - path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. - escaped UTF-8 & spaces and such). - content (bytes or dict): The body of the request. JSON-encoded, if - a dict. + method: The HTTP request method ("verb"). + path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces + and such). content (bytes or dict): The body of the request. + JSON-encoded, if a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake -- cgit 1.5.1 From 618e4ab81b70e37bdb8e9224bd84fcfe4b15bdea Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:25:35 +0000 Subject: Fix an invalid comparison of `UserPresenceState` to `str` (#14393) --- changelog.d/14393.bugfix | 1 + synapse/handlers/presence.py | 2 +- tests/handlers/test_presence.py | 41 +++++++++++++++++++++++++++++++++++------ tests/module_api/test_api.py | 3 +++ tests/replication/_base.py | 7 ++++++- 5 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14393.bugfix (limited to 'synapse') diff --git a/changelog.d/14393.bugfix b/changelog.d/14393.bugfix new file mode 100644 index 0000000000..97177bc62f --- /dev/null +++ b/changelog.d/14393.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.58.0 where a user with presence state 'org.matrix.msc3026.busy' would mistakenly be set to 'online' when calling `/sync` or `/events` on a worker process. \ No newline at end of file diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b7bc787636..cf08737d11 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -478,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler): return _NullContextManager() prev_state = await self.current_state_for_user(user_id) - if prev_state != PresenceState.BUSY: + if prev_state.state != PresenceState.BUSY: # We set state here but pass ignore_status_msg = True as we don't want to # cause the status message to be cleared. # Note that this causes last_active_ts to be incremented which is not diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c96dc6caf2..c5981ff965 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -15,6 +15,7 @@ from typing import Optional from unittest.mock import Mock, call +from parameterized import parameterized from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState @@ -37,6 +38,7 @@ from synapse.rest.client import room from synapse.types import UserID, get_domain_from_id from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase class PresenceUpdateTestCase(unittest.HomeserverTestCase): @@ -505,7 +507,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(state, new_state) -class PresenceHandlerTestCase(unittest.HomeserverTestCase): +class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): def prepare(self, reactor, clock, hs): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() @@ -716,20 +718,47 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): # our status message should be the same as it was before self.assertEqual(state.status_msg, status_msg) - def test_set_presence_from_syncing_keeps_busy(self): - """Test that presence set by syncing doesn't affect busy status""" - # while this isn't the default - self.presence_handler._busy_presence_enabled = True + @parameterized.expand([(False,), (True,)]) + @unittest.override_config( + { + "experimental_features": { + "msc3026_enabled": True, + }, + } + ) + def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): + """Test that presence set by syncing doesn't affect busy status + Args: + test_with_workers: If True, check the presence state of the user by calling + /sync against a worker, rather than the main process. + """ user_id = "@test:server" status_msg = "I'm busy!" + # By default, we call /sync against the main process. + worker_to_sync_against = self.hs + if test_with_workers: + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. + worker_to_sync_against = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "presence_writer"} + ) + + # Set presence to BUSY self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg) + # Perform a sync with a presence state other than busy. This should NOT change + # our presence status; we only change from busy if we explicitly set it via + # /presence/*. self.get_success( - self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE) + worker_to_sync_against.get_presence_handler().user_syncing( + user_id, True, PresenceState.ONLINE + ) ) + # Check against the main process that the user's presence did not change. state = self.get_success( self.presence_handler.get_state(UserID.from_string(user_id)) ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 02cef6f876..058ca57e55 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -778,8 +778,11 @@ def _test_sending_local_online_presence_to_local_user( worker process. The test users will still sync with the main process. The purpose of testing with a worker is to check whether a Synapse module running on a worker can inform other workers/ the main process that they should include additional presence when a user next syncs. + If this argument is True, `test_case` MUST be an instance of BaseMultiWorkerStreamTestCase. """ if test_with_workers: + assert isinstance(test_case, BaseMultiWorkerStreamTestCase) + # Create a worker process to make module_api calls against worker_hs = test_case.make_worker_hs( "synapse.app.generic_worker", {"worker_name": "presence_writer"} diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 121f3d8d65..3029a16dda 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -542,8 +542,13 @@ class FakeRedisPubSubProtocol(Protocol): self.send("OK") elif command == b"GET": self.send(None) + + # Connection keep-alives. + elif command == b"PING": + self.send("PONG") + else: - raise Exception("Unknown command") + raise Exception(f"Unknown command: {command}") def send(self, msg): """Send a message back to the client.""" -- cgit 1.5.1 From c15e9a0edb696990365ac5a4e5be847b5ae23921 Mon Sep 17 00:00:00 2001 From: realtyem Date: Wed, 16 Nov 2022 16:16:25 -0600 Subject: Remove need for `worker_main_http_uri` setting to use /keys/upload. (#14400) --- changelog.d/14400.misc | 1 + docker/configure_workers_and_start.py | 5 +- docs/workers.md | 7 +-- synapse/app/generic_worker.py | 103 +--------------------------------- synapse/config/workers.py | 6 ++ synapse/replication/http/devices.py | 67 ++++++++++++++++++++++ synapse/rest/client/keys.py | 68 ++++++++++++++++------ 7 files changed, 130 insertions(+), 127 deletions(-) create mode 100644 changelog.d/14400.misc (limited to 'synapse') diff --git a/changelog.d/14400.misc b/changelog.d/14400.misc new file mode 100644 index 0000000000..6e025329c4 --- /dev/null +++ b/changelog.d/14400.misc @@ -0,0 +1 @@ +Remove the `worker_main_http_uri` configuration setting. This is now handled via internal replication. diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 62b1bab297..c1e1544536 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -213,10 +213,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "listener_resources": ["client", "replication"], "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"], "shared_extra_conf": {}, - "worker_extra_conf": ( - "worker_main_http_uri: http://127.0.0.1:%d" - % (MAIN_PROCESS_HTTP_LISTENER_PORT,) - ), + "worker_extra_conf": "", }, "account_data": { "app": "synapse.app.generic_worker", diff --git a/docs/workers.md b/docs/workers.md index 7ee8801161..4604650803 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -135,8 +135,8 @@ In the config file for each worker, you must specify: [`worker_replication_http_port`](usage/configuration/config_documentation.md#worker_replication_http_port)). * If handling HTTP requests, a [`worker_listeners`](usage/configuration/config_documentation.md#worker_listeners) option with an `http` listener. - * If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for - the main process (`worker_main_http_uri`). + * **Synapse 1.71 and older:** if handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for + the main process (`worker_main_http_uri`). This config option is no longer required and is ignored when running Synapse 1.72 and newer. For example: @@ -221,7 +221,6 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ # Encryption requests - # Note that ^/_matrix/client/(r0|v3|unstable)/keys/upload/ requires `worker_main_http_uri` ^/_matrix/client/(r0|v3|unstable)/keys/query$ ^/_matrix/client/(r0|v3|unstable)/keys/changes$ ^/_matrix/client/(r0|v3|unstable)/keys/claim$ @@ -376,7 +375,7 @@ responsible for - persisting them to the DB, and finally - updating the events stream. -Because load is sharded in this way, you *must* restart all worker instances when +Because load is sharded in this way, you *must* restart all worker instances when adding or removing event persisters. An `event_persister` should not be mistaken for an `event_creator`. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1d9aef45c2..74909b7d4a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -14,14 +14,12 @@ # limitations under the License. import logging import sys -from typing import Dict, List, Optional, Tuple +from typing import Dict, List -from twisted.internet import address from twisted.web.resource import Resource import synapse import synapse.events -from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.api.urls import ( CLIENT_API_PREFIX, FEDERATION_PREFIX, @@ -43,8 +41,6 @@ from synapse.config.logger import setup_logging from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource, OptionsResource -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -70,12 +66,12 @@ from synapse.rest.client import ( versions, voip, ) -from synapse.rest.client._base import client_patterns from synapse.rest.client.account import ThreepidRestServlet, WhoamiRestServlet from synapse.rest.client.devices import DevicesRestServlet from synapse.rest.client.keys import ( KeyChangesServlet, KeyQueryServlet, + KeyUploadServlet, OneTimeKeyServlet, ) from synapse.rest.client.register import ( @@ -132,107 +128,12 @@ from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.types import JsonDict from synapse.util import SYNAPSE_VERSION from synapse.util.httpresourcetree import create_resource_tree logger = logging.getLogger("synapse.app.generic_worker") -class KeyUploadServlet(RestServlet): - """An implementation of the `KeyUploadServlet` that responds to read only - requests, but otherwise proxies through to the master instance. - """ - - PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - - def __init__(self, hs: HomeServer): - """ - Args: - hs: server - """ - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.http_client = hs.get_simple_http_client() - self.main_uri = hs.config.worker.worker_main_http_uri - - async def on_POST( - self, request: SynapseRequest, device_id: Optional[str] - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - if device_id is not None: - # passing the device_id here is deprecated; however, we allow it - # for now for compatibility with older clients. - if requester.device_id is not None and device_id != requester.device_id: - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) - else: - device_id = requester.device_id - - if device_id is None: - raise SynapseError( - 400, "To upload keys, you must pass device_id when authenticating" - ) - - if body: - # They're actually trying to upload something, proxy to main synapse. - - # Proxy headers from the original request, such as the auth headers - # (in case the access token is there) and the original IP / - # User-Agent of the request. - headers: Dict[bytes, List[bytes]] = { - header: list(request.requestHeaders.getRawHeaders(header, [])) - for header in (b"Authorization", b"User-Agent") - } - # Add the previous hop to the X-Forwarded-For header. - x_forwarded_for = list( - request.requestHeaders.getRawHeaders(b"X-Forwarded-For", []) - ) - # we use request.client here, since we want the previous hop, not the - # original client (as returned by request.getClientAddress()). - if isinstance(request.client, (address.IPv4Address, address.IPv6Address)): - previous_host = request.client.host.encode("ascii") - # If the header exists, add to the comma-separated list of the first - # instance of the header. Otherwise, generate a new header. - if x_forwarded_for: - x_forwarded_for = [x_forwarded_for[0] + b", " + previous_host] - x_forwarded_for.extend(x_forwarded_for[1:]) - else: - x_forwarded_for = [previous_host] - headers[b"X-Forwarded-For"] = x_forwarded_for - - # Replicate the original X-Forwarded-Proto header. Note that - # XForwardedForRequest overrides isSecure() to give us the original protocol - # used by the client, as opposed to the protocol used by our upstream proxy - # - which is what we want here. - headers[b"X-Forwarded-Proto"] = [ - b"https" if request.isSecure() else b"http" - ] - - try: - result = await self.http_client.post_json_get_json( - self.main_uri + request.uri.decode("ascii"), body, headers=headers - ) - except HttpResponseException as e: - raise e.to_synapse_error() from e - except RequestSendFailed as e: - raise SynapseError(502, "Failed to talk to master") from e - - return 200, result - else: - # Just interested in counts. - result = await self.store.count_e2e_one_time_keys(user_id, device_id) - return 200, {"one_time_key_counts": result} - - class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 88b3168cbc..c4e2273a95 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -162,7 +162,13 @@ class WorkerConfig(Config): self.worker_name = config.get("worker_name", self.worker_app) self.instance_name = self.worker_name or "master" + # FIXME: Remove this check after a suitable amount of time. self.worker_main_http_uri = config.get("worker_main_http_uri", None) + if self.worker_main_http_uri is not None: + logger.warning( + "The config option worker_main_http_uri is unused since Synapse 1.72. " + "It can be safely removed from your configuration." + ) # This option is really only here to support `--manhole` command line # argument. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 3d63645726..c21629def8 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -78,5 +79,71 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return 200, user_devices +class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): + """Ask master to upload keys for the user and send them out over federation to + update other servers. + + For now, only the master is permitted to handle key upload requests; + any worker can handle key query requests (since they're read-only). + + Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on + the main process to accomplish this. + + Defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload + Request format(borrowed and expanded from KeyUploadServlet): + + POST /_synapse/replication/upload_keys_for_user + + { + "user_id": "", + "device_id": "", + "keys": { + ....this part can be found in KeyUploadServlet in rest/client/keys.py.... + } + } + + Response is equivalent to ` /_matrix/client/v3/keys/upload` found in KeyUploadServlet + + """ + + NAME = "upload_keys_for_user" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: + + return { + "user_id": user_id, + "device_id": device_id, + "keys": keys, + } + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + content = parse_json_object_from_request(request) + + user_id = content["user_id"] + device_id = content["device_id"] + keys = content["keys"] + + results = await self.e2e_keys_handler.upload_keys_for_user( + user_id, device_id, keys + ) + + return 200, results + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationUserDevicesResyncRestServlet(hs).register(http_server) + ReplicationUploadKeysForUserRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index f653d2a3e1..ee038c7192 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -27,6 +27,7 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag +from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict, StreamToken from synapse.util.cancellation import cancellable @@ -43,24 +44,48 @@ class KeyUploadServlet(RestServlet): Content-Type: application/json { - "device_keys": { - "user_id": "", - "device_id": "", - "valid_until_ts": , - "algorithms": [ - "m.olm.curve25519-aes-sha2", - ] - "keys": { - ":": "", + "device_keys": { + "user_id": "", + "device_id": "", + "valid_until_ts": , + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ] + "keys": { + ":": "", + }, + "signatures:" { + "" { + ":": "" + } + } + }, + "fallback_keys": { + ":": "", + "signed_:": { + "fallback": true, + "key": "", + "signatures": { + "": { + ":": "" + } + } + } + } + "one_time_keys": { + ":": "" }, - "signatures:" { - "" { - ":": "" - } } }, - "one_time_keys": { - ":": "" - }, } + + response, e.g.: + + { + "one_time_key_counts": { + "curve25519": 10, + "signed_curve25519": 20 + } + } + """ PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") @@ -71,6 +96,13 @@ class KeyUploadServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() + if hs.config.worker.worker_app is None: + # if main process + self.key_uploader = self.e2e_keys_handler.upload_keys_for_user + else: + # then a worker + self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs) + async def on_POST( self, request: SynapseRequest, device_id: Optional[str] ) -> Tuple[int, JsonDict]: @@ -109,8 +141,8 @@ class KeyUploadServlet(RestServlet): 400, "To upload keys, you must pass device_id when authenticating" ) - result = await self.e2e_keys_handler.upload_keys_for_user( - user_id, device_id, body + result = await self.key_uploader( + user_id=user_id, device_id=device_id, keys=body ) return 200, result -- cgit 1.5.1 From 115f0eb2334b13665e5c112bd87f95ea393c9047 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 16 Nov 2022 22:16:46 +0000 Subject: Reintroduce #14376, with bugfix for monoliths (#14468) * Add tests for StreamIdGenerator * Drive-by: annotate all defs * Revert "Revert "Remove slaved id tracker (#14376)" (#14463)" This reverts commit d63814fd736fed5d3d45ff3af5e6d3bfae50c439, which in turn reverted 36097e88c4da51fce6556a58c49bd675f4cf20ab. This restores the latter. * Fix StreamIdGenerator not handling unpersisted IDs Spotted by @erikjohnston. Closes #14456. * Changelog Co-authored-by: Nick Mills-Barrett Co-authored-by: Erik Johnston --- changelog.d/14376.misc | 1 + changelog.d/14468.misc | 1 + mypy.ini | 3 + synapse/replication/slave/__init__.py | 13 -- synapse/replication/slave/storage/__init__.py | 13 -- .../slave/storage/_slaved_id_tracker.py | 50 ------- synapse/storage/databases/main/account_data.py | 30 ++-- synapse/storage/databases/main/devices.py | 36 ++--- synapse/storage/databases/main/events_worker.py | 35 ++--- synapse/storage/databases/main/push_rule.py | 17 +-- synapse/storage/databases/main/pusher.py | 24 ++- synapse/storage/databases/main/receipts.py | 18 +-- synapse/storage/util/id_generators.py | 13 +- tests/storage/test_id_generators.py | 162 +++++++++++++++++++-- 14 files changed, 230 insertions(+), 186 deletions(-) create mode 100644 changelog.d/14376.misc create mode 100644 changelog.d/14468.misc delete mode 100644 synapse/replication/slave/__init__.py delete mode 100644 synapse/replication/slave/storage/__init__.py delete mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py (limited to 'synapse') diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14376.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/changelog.d/14468.misc b/changelog.d/14468.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14468.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/mypy.ini b/mypy.ini index 8f1141a239..53512b2584 100644 --- a/mypy.ini +++ b/mypy.ini @@ -117,6 +117,9 @@ disallow_untyped_defs = True [mypy-tests.state.test_profile] disallow_untyped_defs = True +[mypy-tests.storage.test_id_generators] +disallow_untyped_defs = True + [mypy-tests.storage.test_profile] disallow_untyped_defs = True diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py deleted file mode 100644 index 8f3f953ed4..0000000000 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple - -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id - - -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - table: str, - column: str, - extra_tables: Optional[List[Tuple[str, str]]] = None, - step: int = 1, - ): - self.step = step - self._current = _load_current_id(db_conn, table, column, step) - if extra_tables: - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) - - def advance(self, instance_name: Optional[str], new_id: int) -> None: - self._current = (max if self.step > 0 else min)(self._current, new_id) - - def get_current_token(self) -> int: - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c38b8a9e5a..282687ebce 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e114c733d1..57230df5ae 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,6 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -86,28 +85,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - else: - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8a104f7e93..01e935edef 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -213,26 +212,20 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8ae10f6127..12ad44dbb3 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,6 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -111,14 +110,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 4a01562d45..fee37b9ce4 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -59,20 +58,15 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) - else: - self._pushers_id_gen = SlavedIdTracker( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index fbf27497ec..a580e4bdda 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import EduTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2dfe4c0b66..0d7108f01b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, + is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) + self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") + # Advance should never be called on a writer instance, only over replication + if self._is_writer: + raise Exception("Replication is not supported by writer StreamIdGenerator") + + self._current = (max if self._step > 0 else min)(self._current, new_id) def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + if not self._is_writer: + return self._current + with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 2d8d1f860f..d6a2b8d274 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -16,15 +16,157 @@ from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import IncorrectDatabaseSetup -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS +class StreamIdGeneratorTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + data TEXT + ); + """ + ) + txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") + + def _create_id_generator(self) -> StreamIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: + return StreamIdGenerator( + db_conn=conn, + table="foobar", + column="stream_id", + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def test_initial_value(self) -> None: + """Check that we read the current token from the DB.""" + id_gen = self._create_id_generator() + self.assertEqual(id_gen.get_current_token(), 123) + + def test_single_gen_next(self) -> None: + """Check that we correctly increment the current token from the DB.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + async with id_gen.get_next() as next_id: + # We haven't persisted `next_id` yet; current token is still 123 + self.assertEqual(id_gen.get_current_token(), 123) + # But we did learn what the next value is + self.assertEqual(next_id, 124) + + # Once the context manager closes we assume that the `next_id` has been + # written to the DB. + self.assertEqual(id_gen.get_current_token(), 124) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts(self) -> None: + """Check that we handle overlapping calls to gen_next sensibly.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist each in turn. + await ctx1.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 124) + await ctx2.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 125) + await ctx3.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts_closed_in_different_order(self) -> None: + """Check that we handle overlapping calls to gen_next, even when their IDs + created and persisted in different orders.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist them in a different order, starting with 126 from ctx3. + await ctx3.__aexit__(None, None, None) + # We haven't persisted 124 from ctx1 yet---current token is still 123. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now persist 124 from ctx1. + await ctx1.__aexit__(None, None, None) + # Current token is then 124, waiting for 125 to be persisted. + self.assertEqual(id_gen.get_current_token(), 124) + + # Finally persist 125 from ctx2. + await ctx2.__aexit__(None, None, None) + # Current token is then 126 (skipping over 125). + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_gen_next_while_still_waiting_for_persistence(self) -> None: + """Check that we handle overlapping calls to gen_next.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request two new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + + # Persist ctx2 first. + await ctx2.__aexit__(None, None, None) + # Still waiting on ctx1's ID to be persisted. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now request a third stream ID. It should be 126 (the smallest ID that + # we've not yet handed out.) + self.assertEqual(await ctx3.__aenter__(), 126) + + self.get_success(test_gen_next()) + + class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" @@ -48,9 +190,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -446,7 +588,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("master", 3) # Now we add a row *without* updating the stream ID - def _insert(txn): + def _insert(txn: Cursor) -> None: txn.execute("INSERT INTO foobar VALUES (26, 'master')") self.get_success(self.db_pool.runInteraction("_insert", _insert)) @@ -481,9 +623,9 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -617,9 +759,9 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -641,7 +783,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name: str, number: int, update_stream_table: bool = True, - ): + ) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ -- cgit 1.5.1 From 75888c2b1f5ec1c865c4690627bf101f7e0dffb9 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 17 Nov 2022 17:01:14 +0100 Subject: Faster joins: do not wait for full state when creating events to send (#14403) Signed-off-by: Mathieu Velten --- changelog.d/14403.misc | 1 + synapse/events/builder.py | 1 + synapse/state/__init__.py | 8 +++++++- 3 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14403.misc (limited to 'synapse') diff --git a/changelog.d/14403.misc b/changelog.d/14403.misc new file mode 100644 index 0000000000..ff28a2712a --- /dev/null +++ b/changelog.d/14403.misc @@ -0,0 +1 @@ +Faster joins: do not wait for full state when creating events to send. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index e2ee10dd3d..d62906043f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -128,6 +128,7 @@ class EventBuilder: state_filter=StateFilter.from_types( auth_types_for_event(self.room_version, self) ), + await_full_state=False, ) auth_event_ids = self._event_auth_handler.compute_auth_events( self, state_ids diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6f3dd0463e..833ffec3de 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -190,6 +190,7 @@ class StateHandler: room_id: str, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Fetch the state after each of the given event IDs. Resolve them and return. @@ -200,13 +201,18 @@ class StateHandler: Args: room_id: the room_id containing the given events. event_ids: the events whose state should be fetched and resolved. + await_full_state: if `True`, will block if we do not yet have complete state + at the given `event_id`s, regardless of whether `state_filter` is + satisfied by partial state. Returns: the state dict (a mapping from (event_type, state_key) -> event_id) which holds the resolution of the states after the given event IDs. """ logger.debug("calling resolve_state_groups from compute_state_after_events") - ret = await self.resolve_state_groups_for_events(room_id, event_ids) + ret = await self.resolve_state_groups_for_events( + room_id, event_ids, await_full_state + ) return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( -- cgit 1.5.1 From e7132c3f81acbc50c1923cad7eeab96d3b2e05fd Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 17 Nov 2022 16:09:56 +0000 Subject: Fix check to ignore blank lines in incoming TCP replication (#14449) --- changelog.d/14449.misc | 1 + synapse/replication/tcp/protocol.py | 2 +- synapse/storage/database.py | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14449.misc (limited to 'synapse') diff --git a/changelog.d/14449.misc b/changelog.d/14449.misc new file mode 100644 index 0000000000..320c0b6fae --- /dev/null +++ b/changelog.d/14449.misc @@ -0,0 +1 @@ +Fix type logic in TCP replication code that prevented correctly ignoring blank commands. \ No newline at end of file diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 7763ffb2d0..56a5c21910 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -245,7 +245,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self._parse_and_dispatch_line(line) def _parse_and_dispatch_line(self, line: bytes) -> None: - if line.strip() == "": + if line.strip() == b"": # Ignore blank lines return diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4717c9728a..0dc44b246c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -569,15 +569,15 @@ class DatabasePool: retcols=["update_name"], desc="check_background_updates", ) - updates = [x["update_name"] for x in updates] + background_update_names = [x["update_name"] for x in updates] for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): - if update_name not in updates: + if update_name not in background_update_names: logger.debug("Now safe to upsert in %s", table) self._unsafe_to_upsert_tables.discard(table) # If there's any updates still running, reschedule to run. - if updates: + if background_update_names: self._clock.call_later( 15.0, run_as_background_process, -- cgit 1.5.1 From 01a052789266179c70c10ea6a6253c64fd9990d2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 17 Nov 2022 16:11:08 +0000 Subject: Fix version that `worker_main_http_uri` is redundant from (#14476) * Fix version that `worker_main_http_uri` is redundant from * Changelog --- changelog.d/14476.misc | 1 + docs/workers.md | 4 ++-- synapse/config/workers.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14476.misc (limited to 'synapse') diff --git a/changelog.d/14476.misc b/changelog.d/14476.misc new file mode 100644 index 0000000000..6e025329c4 --- /dev/null +++ b/changelog.d/14476.misc @@ -0,0 +1 @@ +Remove the `worker_main_http_uri` configuration setting. This is now handled via internal replication. diff --git a/docs/workers.md b/docs/workers.md index 4604650803..27e54c5846 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -135,8 +135,8 @@ In the config file for each worker, you must specify: [`worker_replication_http_port`](usage/configuration/config_documentation.md#worker_replication_http_port)). * If handling HTTP requests, a [`worker_listeners`](usage/configuration/config_documentation.md#worker_listeners) option with an `http` listener. - * **Synapse 1.71 and older:** if handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for - the main process (`worker_main_http_uri`). This config option is no longer required and is ignored when running Synapse 1.72 and newer. + * **Synapse 1.72 and older:** if handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for + the main process (`worker_main_http_uri`). This config option is no longer required and is ignored when running Synapse 1.73 and newer. For example: diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c4e2273a95..913b83e174 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -166,7 +166,7 @@ class WorkerConfig(Config): self.worker_main_http_uri = config.get("worker_main_http_uri", None) if self.worker_main_http_uri is not None: logger.warning( - "The config option worker_main_http_uri is unused since Synapse 1.72. " + "The config option worker_main_http_uri is unused since Synapse 1.73. " "It can be safely removed from your configuration." ) -- cgit 1.5.1 From 78e23eea056cbf75b9478140f17699195dd490f2 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 18 Nov 2022 18:10:01 +0000 Subject: Reduce default third party invite rate limit to 216 invites per day (#14487) The previous default was the same as the `rc_message` rate limit, which defaults to 17,280 per day. Signed-off-by: Sean Quah --- changelog.d/14487.misc | 1 + synapse/config/ratelimiting.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14487.misc (limited to 'synapse') diff --git a/changelog.d/14487.misc b/changelog.d/14487.misc new file mode 100644 index 0000000000..f6b47a1d8e --- /dev/null +++ b/changelog.d/14487.misc @@ -0,0 +1 @@ +Reduce default third party invite rate limit to 216 invites per day. diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 1ed001e105..5c13fe428a 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -150,8 +150,5 @@ class RatelimitConfig(Config): self.rc_third_party_invite = RatelimitSettings( config.get("rc_third_party_invite", {}), - defaults={ - "per_second": self.rc_message.per_second, - "burst_count": self.rc_message.burst_count, - }, + defaults={"per_second": 0.0025, "burst_count": 5}, ) -- cgit 1.5.1 From e1b15f25f3ad4b45b381544ca6b3cd2caf43d25d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 18 Nov 2022 19:56:42 +0000 Subject: Fix /key/v2/server calls with URL-unsafe key IDs (#14490) Co-authored-by: Patrick Cloke --- changelog.d/14490.misc | 1 + synapse/crypto/keyring.py | 2 +- tests/crypto/test_keyring.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14490.misc (limited to 'synapse') diff --git a/changelog.d/14490.misc b/changelog.d/14490.misc new file mode 100644 index 0000000000..c0a4daa885 --- /dev/null +++ b/changelog.d/14490.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 0.9 where it would fail to fetch server keys whose IDs contain a forward slash. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index dd9b8089ec..ed15f88350 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -857,7 +857,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): response = await self.client.get_json( destination=server_name, path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), + + urllib.parse.quote(requested_key_id, safe=""), ignore_backoff=True, # we only give the remote server 10s to respond. It should be an # easy request to handle, so if it doesn't reply within 10s, it's diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 820a1a54e2..63628aa6b0 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -469,6 +469,18 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) + def test_keyid_containing_forward_slash(self) -> None: + """We should url-encode any url unsafe chars in key ids. + + Detects https://github.com/matrix-org/synapse/issues/14488. + """ + fetcher = ServerKeyFetcher(self.hs) + self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0)) + + self.http_client.get_json.assert_called_once() + args, kwargs = self.http_client.get_json.call_args + self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato") + class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): -- cgit 1.5.1 From 1526ff389f02d14d0df729bd6ea35836e758c449 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Mon, 21 Nov 2022 16:46:14 +0100 Subject: Faster joins: filter out non local events when a room doesn't have its full state (#14404) Signed-off-by: Mathieu Velten --- changelog.d/14404.misc | 1 + synapse/federation/sender/per_destination_queue.py | 1 + synapse/handlers/federation.py | 15 +++++++---- synapse/visibility.py | 29 +++++++++++++++++++--- tests/test_visibility.py | 10 ++++---- 5 files changed, 43 insertions(+), 13 deletions(-) create mode 100644 changelog.d/14404.misc (limited to 'synapse') diff --git a/changelog.d/14404.misc b/changelog.d/14404.misc new file mode 100644 index 0000000000..b9ab525f2b --- /dev/null +++ b/changelog.d/14404.misc @@ -0,0 +1 @@ +Faster joins: filter out non local events when a room doesn't have its full state. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 084c45a95c..3ae5e8634c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -505,6 +505,7 @@ class PerDestinationQueue: new_pdus = await filter_events_for_server( self._storage_controllers, self._destination, + self._server_name, new_pdus, redact=False, ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 188f0956ef..d92582fd5c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -379,6 +379,7 @@ class FederationHandler: filtered_extremities = await filter_events_for_server( self._storage_controllers, self.server_name, + self.server_name, events_to_check, redact=False, check_history_visibility_only=True, @@ -1231,7 +1232,9 @@ class FederationHandler: async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: - await self._event_auth_handler.assert_host_in_room(room_id, origin) + # We allow partially joined rooms since in this case we are filtering out + # non-local events in `filter_events_for_server`. + await self._event_auth_handler.assert_host_in_room(room_id, origin, True) # Synapse asks for 100 events per backfill request. Do not allow more. limit = min(limit, 100) @@ -1252,7 +1255,7 @@ class FederationHandler: ) events = await filter_events_for_server( - self._storage_controllers, origin, events + self._storage_controllers, origin, self.server_name, events ) return events @@ -1283,7 +1286,7 @@ class FederationHandler: await self._event_auth_handler.assert_host_in_room(event.room_id, origin) events = await filter_events_for_server( - self._storage_controllers, origin, [event] + self._storage_controllers, origin, self.server_name, [event] ) event = events[0] return event @@ -1296,7 +1299,9 @@ class FederationHandler: latest_events: List[str], limit: int, ) -> List[EventBase]: - await self._event_auth_handler.assert_host_in_room(room_id, origin) + # We allow partially joined rooms since in this case we are filtering out + # non-local events in `filter_events_for_server`. + await self._event_auth_handler.assert_host_in_room(room_id, origin, True) # Only allow up to 20 events to be retrieved per request. limit = min(limit, 20) @@ -1309,7 +1314,7 @@ class FederationHandler: ) missing_events = await filter_events_for_server( - self._storage_controllers, origin, missing_events + self._storage_controllers, origin, self.server_name, missing_events ) return missing_events diff --git a/synapse/visibility.py b/synapse/visibility.py index 40a9c5b53f..b443857571 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -563,7 +563,8 @@ def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str: async def filter_events_for_server( storage: StorageControllers, - server_name: str, + target_server_name: str, + local_server_name: str, events: List[EventBase], redact: bool = True, check_history_visibility_only: bool = False, @@ -603,7 +604,7 @@ async def filter_events_for_server( # if the server is either in the room or has been invited # into the room. for ev in memberships.values(): - assert get_domain_from_id(ev.state_key) == server_name + assert get_domain_from_id(ev.state_key) == target_server_name memtype = ev.membership if memtype == Membership.JOIN: @@ -622,6 +623,24 @@ async def filter_events_for_server( # to no users having been erased. erased_senders = {} + # Filter out non-local events when we are in the middle of a partial join, since our servers + # list can be out of date and we could leak events to servers not in the room anymore. + # This can also be true for local events but we consider it to be an acceptable risk. + + # We do this check as a first step and before retrieving membership events because + # otherwise a room could be fully joined after we retrieve those, which would then bypass + # this check but would base the filtering on an outdated view of the membership events. + + partial_state_invisible_events = set() + if not check_history_visibility_only: + for e in events: + sender_domain = get_domain_from_id(e.sender) + if ( + sender_domain != local_server_name + and await storage.main.is_partial_state_room(e.room_id) + ): + partial_state_invisible_events.add(e) + # Let's check to see if all the events have a history visibility # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). @@ -636,7 +655,7 @@ async def filter_events_for_server( if event_to_history_vis[e.event_id] not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) ], - server_name, + target_server_name, ) to_return = [] @@ -645,6 +664,10 @@ async def filter_events_for_server( visible = check_event_is_visible( event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) ) + + if e in partial_state_invisible_events: + visible = False + if visible and not erased: to_return.append(e) elif redact: diff --git a/tests/test_visibility.py b/tests/test_visibility.py index c385b2f8d4..d0b9ad5454 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -61,7 +61,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", "hs", events_to_filter ) ) @@ -83,7 +83,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_server( - self._storage_controllers, "remote_hs", [outlier] + self._storage_controllers, "remote_hs", "hs", [outlier] ) ), [outlier], @@ -94,7 +94,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "remote_hs", [outlier, evt] + self._storage_controllers, "remote_hs", "local_hs", [outlier, evt] ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") @@ -106,7 +106,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # be redacted) filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "other_server", [outlier, evt] + self._storage_controllers, "other_server", "local_hs", [outlier, evt] ) ) self.assertEqual(filtered[0], outlier) @@ -141,7 +141,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... and the filtering happens. filtered = self.get_success( filter_events_for_server( - self._storage_controllers, "test_server", events_to_filter + self._storage_controllers, "test_server", "local_hs", events_to_filter ) ) -- cgit 1.5.1 From 1799a54a545618782840a60950ef4b64da9ee24d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 07:26:11 -0500 Subject: Batch fetch bundled annotations (#14491) Avoid an n+1 query problem and fetch the bundled aggregations for m.annotation relations in a single query instead of a query per event. This applies similar logic for as was previously done for edits in 8b309adb436c162510ed1402f33b8741d71fc058 (#11660) and threads in b65acead428653b988351ae8d7b22127a22039cd (#11752). --- changelog.d/14491.feature | 1 + synapse/handlers/relations.py | 197 ++++++++++++++++------------ synapse/storage/databases/main/relations.py | 139 ++++++++++++-------- synapse/util/caches/descriptors.py | 2 +- tests/rest/client/test_relations.py | 4 +- 5 files changed, 202 insertions(+), 141 deletions(-) create mode 100644 changelog.d/14491.feature (limited to 'synapse') diff --git a/changelog.d/14491.feature b/changelog.d/14491.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14491.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e71dda970..ca94239f61 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,7 +13,16 @@ # limitations under the License. import enum import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, +) import attr @@ -259,48 +268,64 @@ class RelationsHandler: e.msg, ) - async def get_annotations_for_event( - self, - event_id: str, - room_id: str, - limit: int = 5, - ignored_users: FrozenSet[str] = frozenset(), - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + async def get_annotations_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[JsonDict]]: + """Get a list of annotations to the given events, grouped by event type and aggregation key, sorted by count. - This is used e.g. to get the what and how many reactions have happend + This is used e.g. to get the what and how many reactions have happened on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. ignored_users: The users ignored by the requesting user. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ # Get the base results for all users. - full_results = await self._main_store.get_aggregation_groups_for_event( - event_id, room_id, limit + full_results = await self._main_store.get_aggregation_groups_for_events( + event_ids ) + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in full_results.items() + if results + } + # Then subtract off the results for any ignored users. ignored_results = await self._main_store.get_aggregation_groups_for_users( - event_id, room_id, limit, ignored_users + [event_id for event_id, results in full_results.items() if results], + ignored_users, ) - filtered_results = [] - for result in full_results: - key = (result["type"], result["key"]) - if key in ignored_results: - result = result.copy() - result["count"] -= ignored_results[key] - if result["count"] <= 0: - continue - filtered_results.append(result) + filtered_results = {} + for event_id, results in full_results.items(): + # If no annotations, skip. + if not results: + continue + + # If there are not ignored results for this event, copy verbatim. + if event_id not in ignored_results: + filtered_results[event_id] = results + continue + + # Otherwise, subtract out the ignored results. + event_ignored_results = ignored_results[event_id] + for result in results: + key = (result["type"], result["key"]) + if key in event_ignored_results: + # Ensure to not modify the cache. + result = result.copy() + result["count"] -= event_ignored_results[key] + if result["count"] <= 0: + continue + filtered_results.setdefault(event_id, []).append(result) return filtered_results @@ -366,59 +391,62 @@ class RelationsHandler: results = {} for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event = summary - - # Subtract off the count of any ignored users. - for ignored_user in ignored_users: - thread_count -= ignored_results.get((event_id, ignored_user), 0) - - # This is gnarly, but if the latest event is from an ignored user, - # attempt to find one that isn't from an ignored user. - if latest_thread_event.sender in ignored_users: - room_id = latest_thread_event.room_id - - # If the root event is not found, something went wrong, do - # not include a summary of the thread. - event = await self._event_handler.get_event(user, room_id, event_id) - if event is None: - continue + # If no thread, skip. + if not summary: + continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, - ) + thread_count, latest_thread_event = summary - # If all found events are from ignored users, do not include - # a summary of the thread. - if not potential_events: - continue + # Subtract off the count of any ignored users. + for ignored_user in ignored_users: + thread_count -= ignored_results.get((event_id, ignored_user), 0) - # The *last* event returned is the one that is cared about. - event = await self._event_handler.get_event( - user, room_id, potential_events[-1].event_id - ) - # It is unexpected that the event will not exist. - if event is None: - logger.warning( - "Unable to fetch latest event in a thread with event ID: %s", - potential_events[-1].event_id, - ) - continue - latest_thread_event = event - - results[event_id] = _ThreadAggregation( - latest_event=latest_thread_event, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=events_by_id[event_id].sender == user_id - or participated[event_id], + # This is gnarly, but if the latest event is from an ignored user, + # attempt to find one that isn't from an ignored user. + if latest_thread_event.sender in ignored_users: + room_id = latest_thread_event.room_id + + # If the root event is not found, something went wrong, do + # not include a summary of the thread. + event = await self._event_handler.get_event(user, room_id, event_id) + if event is None: + continue + + potential_events, _ = await self.get_relations_for_event( + event_id, + event, + room_id, + RelationTypes.THREAD, + ignored_users, ) + # If all found events are from ignored users, do not include + # a summary of the thread. + if not potential_events: + continue + + # The *last* event returned is the one that is cared about. + event = await self._event_handler.get_event( + user, room_id, potential_events[-1].event_id + ) + # It is unexpected that the event will not exist. + if event is None: + logger.warning( + "Unable to fetch latest event in a thread with event ID: %s", + potential_events[-1].event_id, + ) + continue + latest_thread_event = event + + results[event_id] = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=events_by_id[event_id].sender == user_id + or participated[event_id], + ) + return results @trace @@ -496,17 +524,18 @@ class RelationsHandler: # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any annotations (ie, reactions) to bundle with this event. - annotations = await self.get_annotations_for_event( - event.event_id, event.room_id, ignored_users=ignored_users - ) + # Fetch any annotations (ie, reactions) to bundle with this event. + annotations_by_event_id = await self.get_annotations_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, annotations in annotations_by_event_id.items(): if annotations: - results.setdefault( - event.event_id, BundledAggregations() - ).annotations = {"chunk": annotations} + results.setdefault(event_id, BundledAggregations()).annotations = { + "chunk": annotations + } + # Fetch other relations per event. + for event in events_by_id.values(): # Fetch any references to bundle with this event. references, next_token = await self.get_relations_for_event( event.event_id, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ca431002c8..f96a16956a 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -20,6 +20,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -394,106 +395,136 @@ class RelationsWorkerStore(SQLBaseStore): ) return result is not None - @cached(tree=True) - async def get_aggregation_groups_for_event( - self, event_id: str, room_id: str, limit: int = 5 - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + @cached() + async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" + ) + async def get_aggregation_groups_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[JsonDict]]]: + """Get a list of annotations on the given events, grouped by event type and aggregation key, sorted by count. This is used e.g. to get the what and how many reactions have happend on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ + # The number of entries to return per event ID. + limit = 5 - args = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - limit, - ] + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.ANNOTATION) - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + sql = f""" + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE + {clause} + AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ - def _get_aggregation_groups_for_event_txn( + def _get_aggregation_groups_for_events_txn( txn: LoggingTransaction, - ) -> List[JsonDict]: + ) -> Mapping[str, List[JsonDict]]: txn.execute(sql, args) - return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] + result: Dict[str, List[JsonDict]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + event_results = result.setdefault(event_id, []) + + # Limit the number of results per event ID. + if len(event_results) == limit: + continue + + event_results.append({"type": type, "key": key, "count": count}) + + return result return await self.db_pool.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn ) async def get_aggregation_groups_for_users( - self, - event_id: str, - room_id: str, - limit: int, - users: FrozenSet[str] = frozenset(), - ) -> Dict[Tuple[str, str], int]: + self, event_ids: Collection[str], users: FrozenSet[str] + ) -> Dict[str, Dict[Tuple[str, str], int]]: """Fetch the partial aggregations for an event for specific users. This is used, in conjunction with get_aggregation_groups_for_event, to remove information from the results for ignored users. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. users: The users to fetch information for. Returns: - A map of (event type, aggregation key) to a count of users. + A map of event ID to a map of (event type, aggregation key) to a + count of users. """ if not users: return {} - args: List[Union[str, int]] = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - ] + events_sql, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "sender", users + self.database_engine, "annotation.sender", users ) args.extend(users_args) + args.append(RelationTypes.ANNOTATION) sql = f""" - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE {events_sql} AND {users_sql} AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ def _get_aggregation_groups_for_users_txn( txn: LoggingTransaction, - ) -> Dict[Tuple[str, str], int]: - txn.execute(sql, args + [limit]) + ) -> Dict[str, Dict[Tuple[str, str], int]]: + txn.execute(sql, args) - return {(row[0], row[1]): row[2] for row in txn} + result: Dict[str, Dict[Tuple[str, str], int]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + result.setdefault(event_id, {})[(type, key)] = count + + return result return await self.db_pool.runInteraction( "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 75428d19ba..72227359b9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -503,7 +503,7 @@ def cachedList( is specified as a list that is iterated through to lookup keys in the original cache. A new tuple consisting of the (deduplicated) keys that weren't in the cache gets passed to the original function, which is expected to results - in a map of key to value for each passed value. THe new results are stored in the + in a map of key to value for each passed value. The new results are stored in the original cache. Note that any missing values are cached as None. Args: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index e3d801f7a8..2d2b683548 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) def test_nested_thread(self) -> None: """ -- cgit 1.5.1 From 6d7523ef1484ec56f4a6dffdd2ea3d8736b4cc98 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 09:41:09 -0500 Subject: Batch fetch bundled references (#14508) Avoid an n+1 query problem and fetch the bundled aggregations for m.reference relations in a single query instead of a query per event. This applies similar logic for as was previously done for edits in 8b309adb436c162510ed1402f33b8741d71fc058 (#11660; threads in b65acead428653b988351ae8d7b22127a22039cd (#11752); and annotations in 1799a54a545618782840a60950ef4b64da9ee24d (#14491). --- changelog.d/14508.feature | 1 + synapse/handlers/relations.py | 128 +++++++++++++--------------- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events.py | 4 + synapse/storage/databases/main/relations.py | 74 ++++++++++++++-- tests/rest/client/test_relations.py | 4 +- 6 files changed, 133 insertions(+), 79 deletions(-) create mode 100644 changelog.d/14508.feature (limited to 'synapse') diff --git a/changelog.d/14508.feature b/changelog.d/14508.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14508.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ca94239f61..8414be5879 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,16 +13,7 @@ # limitations under the License. import enum import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - FrozenSet, - Iterable, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional import attr @@ -32,7 +23,7 @@ from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamToken, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -181,40 +172,6 @@ class RelationsHandler: return return_value - async def get_relations_for_event( - self, - event_id: str, - event: EventBase, - room_id: str, - relation_type: str, - ignored_users: FrozenSet[str] = frozenset(), - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: - """Get a list of events which relate to an event, ordered by topological ordering. - - Args: - event_id: Fetch events that relate to this event ID. - event: The matching EventBase to event_id. - room_id: The room the event belongs to. - relation_type: The type of relation. - ignored_users: The users ignored by the requesting user. - - Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. - """ - - # Call the underlying storage method, which is cached. - related_events, next_token = await self._main_store.get_relations_for_event( - event_id, event, room_id, relation_type, direction="f" - ) - - # Filter out ignored users and convert to the expected format. - related_events = [ - event for event in related_events if event.sender not in ignored_users - ] - - return related_events, next_token - async def redact_events_related_to( self, requester: Requester, @@ -329,6 +286,46 @@ class RelationsHandler: return filtered_results + async def get_references_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[_RelatedEvent]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to this event ID. + ignored_users: The users ignored by the requesting user. + + Returns: + A map of event IDs to a list related events. + """ + + related_events = await self._main_store.get_references_for_events(event_ids) + + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in related_events.items() + if results + } + + # Filter out ignored users. + results = {} + for event_id, events in related_events.items(): + # If no references, skip. + if not events: + continue + + # Filter ignored users out. + events = [event for event in events if event.sender not in ignored_users] + # If there are no events left, skip this event. + if not events: + continue + + results[event_id] = events + + return results + async def _get_threads_for_events( self, events_by_id: Dict[str, EventBase], @@ -412,14 +409,18 @@ class RelationsHandler: if event is None: continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, + # Attempt to find another event to use as the latest event. + potential_events, _ = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.THREAD, direction="f" ) + # Filter out ignored users. + potential_events = [ + event + for event in potential_events + if event.sender not in ignored_users + ] + # If all found events are from ignored users, do not include # a summary of the thread. if not potential_events: @@ -534,27 +535,16 @@ class RelationsHandler: "chunk": annotations } - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any references to bundle with this event. - references, next_token = await self.get_relations_for_event( - event.event_id, - event, - event.room_id, - RelationTypes.REFERENCE, - ignored_users=ignored_users, - ) + # Fetch any references to bundle with this event. + references_by_event_id = await self.get_references_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, references in references_by_event_id.items(): if references: - aggregations = results.setdefault(event.event_id, BundledAggregations()) - aggregations.references = { + results.setdefault(event_id, BundledAggregations()).references = { "chunk": [{"event_id": ev.event_id} for ev in references] } - if next_token: - aggregations.references["next_batch"] = await next_token.to_string( - self._main_store - ) - # Fetch any edits (but not for redacted events). # # Note that there is no use in limiting edits by ignored users since the @@ -600,7 +590,7 @@ class RelationsHandler: room_id, requester, allow_departed_users=True ) - # Note that ignored users are not passed into get_relations_for_event + # Note that ignored users are not passed into get_threads # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). thread_roots, next_batch = await self._main_store.get_threads( diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ddb7397714..a58668a380 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache( "get_aggregation_groups_for_event", (relates_to,) ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d68f127f9b..0f097a2927 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2049,6 +2049,10 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REFERENCE: + self.store._invalidate_cache_and_stream( + txn, self.store.get_references_for_event, (redacted_relates_to,) + ) if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index f96a16956a..aea96e9d24 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -82,8 +82,6 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str - topological_ordering: Optional[int] - stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -246,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore): txn.execute(sql, where_args + [limit + 1]) events = [] - for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: + topo_orderings: List[int] = [] + stream_orderings: List[int] = [] + for event_id, relation_type, sender, topo_ordering, stream_ordering in cast( + List[Tuple[str, str, str, int, int]], txn + ): # Do not include edits for redacted events as they leak event # content. if not is_redacted or relation_type != RelationTypes.REPLACE: - events.append( - _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) - ) + events.append(_RelatedEvent(event_id, sender)) + topo_orderings.append(topo_ordering) + stream_orderings.append(stream_ordering) # If there are more events, generate the next pagination key from the # last event returned. @@ -261,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore): # Instead of using the last row (which tells us there is more # data), use the last row to be returned. events = events[:limit] + topo_orderings = topo_orderings[:limit] + stream_orderings = stream_orderings[:limit] - topo = events[-1].topological_ordering - token = events[-1].stream_ordering + topo = topo_orderings[-1] + token = stream_orderings[-1] if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. @@ -530,6 +534,60 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn ) + @cached() + async def get_references_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") + async def get_references_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to these event IDs. + + Returns: + A map of event IDs to a list of related event IDs (and their senders). + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REFERENCE) + + sql = f""" + SELECT relates_to_id, ref.event_id, ref.sender + FROM events AS ref + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = ref.room_id + WHERE + {clause} + AND relation_type = ? + ORDER BY ref.topological_ordering, ref.stream_ordering + """ + + def _get_references_for_events_txn( + txn: LoggingTransaction, + ) -> Mapping[str, List[_RelatedEvent]]: + txn.execute(sql, args) + + result: Dict[str, List[_RelatedEvent]] = {} + for relates_to_id, event_id, sender in cast( + List[Tuple[str, str, str]], txn + ): + result.setdefault(relates_to_id, []).append( + _RelatedEvent(event_id, sender) + ) + + return result + + return await self.db_pool.runInteraction( + "_get_references_for_events_txn", _get_references_for_events_txn + ) + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 2d2b683548..b86f341ff5 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) def test_nested_thread(self) -> None: """ -- cgit 1.5.1 From 7eb74600423e00c6982493eed18551d7f294140d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 09:47:32 -0500 Subject: Parallelize calls to fetch bundled aggregations. (#14510) The bundled aggregations for annotations, references, and edits can be parallelized. --- changelog.d/14510.feature | 1 + synapse/handlers/relations.py | 83 ++++++++++++++++++++++++++----------------- 2 files changed, 52 insertions(+), 32 deletions(-) create mode 100644 changelog.d/14510.feature (limited to 'synapse') diff --git a/changelog.d/14510.feature b/changelog.d/14510.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14510.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8414be5879..e96f9999a8 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -20,10 +20,12 @@ import attr from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, UserID +from synapse.util.async_helpers import gather_results from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -525,39 +527,56 @@ class RelationsHandler: # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - # Fetch any annotations (ie, reactions) to bundle with this event. - annotations_by_event_id = await self.get_annotations_for_events( - events_by_id.keys(), ignored_users=ignored_users - ) - for event_id, annotations in annotations_by_event_id.items(): - if annotations: - results.setdefault(event_id, BundledAggregations()).annotations = { - "chunk": annotations - } - - # Fetch any references to bundle with this event. - references_by_event_id = await self.get_references_for_events( - events_by_id.keys(), ignored_users=ignored_users - ) - for event_id, references in references_by_event_id.items(): - if references: - results.setdefault(event_id, BundledAggregations()).references = { - "chunk": [{"event_id": ev.event_id} for ev in references] - } - - # Fetch any edits (but not for redacted events). - # - # Note that there is no use in limiting edits by ignored users since the - # parent event should be ignored in the first place if the user is ignored. - edits = await self._main_store.get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] + async def _fetch_annotations() -> None: + """Fetch any annotations (ie, reactions) to bundle with this event.""" + annotations_by_event_id = await self.get_annotations_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, annotations in annotations_by_event_id.items(): + if annotations: + results.setdefault(event_id, BundledAggregations()).annotations = { + "chunk": annotations + } + + async def _fetch_references() -> None: + """Fetch any references to bundle with this event.""" + references_by_event_id = await self.get_references_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, references in references_by_event_id.items(): + if references: + results.setdefault(event_id, BundledAggregations()).references = { + "chunk": [{"event_id": ev.event_id} for ev in references] + } + + async def _fetch_edits() -> None: + """ + Fetch any edits (but not for redacted events). + + Note that there is no use in limiting edits by ignored users since the + parent event should be ignored in the first place if the user is ignored. + """ + edits = await self._main_store.get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Parallelize the calls for annotations, references, and edits since they + # are unrelated. + await make_deferred_yieldable( + gather_results( + ( + run_in_background(_fetch_annotations), + run_in_background(_fetch_references), + run_in_background(_fetch_edits), + ) + ) ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit return results -- cgit 1.5.1 From 9cae44f49e6bf4f6b8a20ab11a65da417bb1565f Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 22 Nov 2022 16:46:52 +0000 Subject: Track unconverted device list outbound pokes using a position instead (#14516) When a local device list change is added to `device_lists_changes_in_room`, the `converted_to_destinations` flag is set to `FALSE` and the `_handle_new_device_update_async` background process is started. This background process looks for unconverted rows in `device_lists_changes_in_room`, copies them to `device_lists_outbound_pokes` and updates the flag. To update the `converted_to_destinations` flag, the database performs a `DELETE` and `INSERT` internally, which fragments the table. To avoid this, track unconverted rows using a `(stream ID, room ID)` position instead of the flag. From now on, the `converted_to_destinations` column indicates rows that need converting to outbound pokes, but does not indicate whether the conversion has already taken place. Closes #14037. Signed-off-by: Sean Quah --- changelog.d/14516.misc | 1 + synapse/handlers/device.py | 30 +++++- synapse/storage/database.py | 13 +-- synapse/storage/databases/main/devices.py | 107 +++++++++++++-------- .../73/12refactor_device_list_outbound_pokes.sql | 53 ++++++++++ tests/storage/test_devices.py | 3 +- 6 files changed, 158 insertions(+), 49 deletions(-) create mode 100644 changelog.d/14516.misc create mode 100644 synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql (limited to 'synapse') diff --git a/changelog.d/14516.misc b/changelog.d/14516.misc new file mode 100644 index 0000000000..51666c6ffc --- /dev/null +++ b/changelog.d/14516.misc @@ -0,0 +1 @@ +Refactor conversion of device list changes in room to outbound pokes to track unconverted rows using a `(stream ID, room ID)` position instead of updating the `converted_to_destinations` flag on every row. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c597639a7f..da3ddafeae 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -682,13 +682,33 @@ class DeviceHandler(DeviceWorkerHandler): hosts_already_sent_to: Set[str] = set() try: + stream_id, room_id = await self.store.get_device_change_last_converted_pos() + while True: self._handle_new_device_update_new_data = False - rows = await self.store.get_uncoverted_outbound_room_pokes() + max_stream_id = self.store.get_device_stream_token() + rows = await self.store.get_uncoverted_outbound_room_pokes( + stream_id, room_id + ) if not rows: # If the DB returned nothing then there is nothing left to # do, *unless* a new device list update happened during the # DB query. + + # Advance `(stream_id, room_id)`. + # `max_stream_id` comes from *before* the query for unconverted + # rows, which means that any unconverted rows must have a larger + # stream ID. + if max_stream_id > stream_id: + stream_id, room_id = max_stream_id, "" + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + else: + assert max_stream_id == stream_id + # Avoid moving `room_id` backwards. + pass + if self._handle_new_device_update_new_data: continue else: @@ -718,7 +738,6 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=stream_id, hosts=hosts, context=opentracing_context, ) @@ -752,6 +771,12 @@ class DeviceHandler(DeviceWorkerHandler): hosts_already_sent_to.update(hosts) current_stream_id = stream_id + # Advance `(stream_id, room_id)`. + _, _, room_id, stream_id, _ = rows[-1] + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + finally: self._handle_new_device_update_is_processing = False @@ -834,7 +859,6 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=None, hosts=potentially_changed_hosts, context=None, ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0dc44b246c..a14b13aec8 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2075,13 +2075,14 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, ) -> Optional[Dict[str, Any]]: - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) + select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + + if keyvalues: + select_sql += " WHERE %s" % (" AND ".join("%s = ?" % k for k in keyvalues),) + txn.execute(select_sql, list(keyvalues.values())) + else: + txn.execute(select_sql) - txn.execute(select_sql, list(keyvalues.values())) row = txn.fetchone() if not row: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 57230df5ae..37629115ab 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -2008,27 +2008,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def get_uncoverted_outbound_room_pokes( - self, limit: int = 10 + self, start_stream_id: int, start_room_id: str, limit: int = 10 ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: """Get device list changes by room that have not yet been handled and written to `device_lists_outbound_pokes`. + Args: + start_stream_id: Together with `start_room_id`, indicates the position after + which to return device list changes. + start_room_id: Together with `start_stream_id`, indicates the position after + which to return device list changes. + limit: The maximum number of device list changes to return. + Returns: - A list of user ID, device ID, room ID, stream ID and optional opentracing context. + A list of user ID, device ID, room ID, stream ID and optional opentracing + context, in order of ascending (stream ID, room ID). """ sql = """ SELECT user_id, device_id, room_id, stream_id, opentracing_context FROM device_lists_changes_in_room - WHERE NOT converted_to_destinations - ORDER BY stream_id + WHERE + (stream_id, room_id) > (?, ?) AND + stream_id <= ? AND + NOT converted_to_destinations + ORDER BY stream_id ASC, room_id ASC LIMIT ? """ def get_uncoverted_outbound_room_pokes_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: - txn.execute(sql, (limit,)) + txn.execute( + sql, + ( + start_stream_id, + start_room_id, + # Avoid returning rows if there may be uncommitted device list + # changes with smaller stream IDs. + self._device_list_id_gen.get_current_token(), + limit, + ), + ) return [ ( @@ -2050,49 +2071,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, room_id: str, - stream_id: Optional[int], hosts: Collection[str], context: Optional[Dict[str, str]], ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. - - Marks the associated row in `device_lists_changes_in_room` as handled, - if `stream_id` is provided. """ + if not hosts: + return def add_device_list_outbound_pokes_txn( txn: LoggingTransaction, stream_ids: List[int] ) -> None: - if hosts: - self._add_device_outbound_poke_to_stream_txn( - txn, - user_id=user_id, - device_id=device_id, - hosts=hosts, - stream_ids=stream_ids, - context=context, - ) - - if stream_id: - self.db_pool.simple_update_txn( - txn, - table="device_lists_changes_in_room", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "room_id": room_id, - }, - updatevalues={"converted_to_destinations": True}, - ) - - if not hosts: - # If there are no hosts then we don't try and generate stream IDs. - return await self.db_pool.runInteraction( - "add_device_list_outbound_pokes", - add_device_list_outbound_pokes_txn, - [], + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_id=device_id, + hosts=hosts, + stream_ids=stream_ids, + context=context, ) async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: @@ -2156,3 +2153,37 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "get_pending_remote_device_list_updates_for_room", get_pending_remote_device_list_updates_for_room_txn, ) + + async def get_device_change_last_converted_pos(self) -> Tuple[int, str]: + """ + Get the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + + Rows with a strictly greater position where `converted_to_destinations` is + `FALSE` have not been converted. + """ + + row = await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ) + return row["stream_id"], row["room_id"] + + async def set_device_change_last_converted_pos( + self, + stream_id: int, + room_id: str, + ) -> None: + """ + Set the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + """ + + await self.db_pool.simple_update_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + updatevalues={"stream_id": stream_id, "room_id": room_id}, + desc="set_device_change_last_converted_pos", + ) diff --git a/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql new file mode 100644 index 0000000000..93d7fcb79b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql @@ -0,0 +1,53 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Prior to this schema delta, we tracked the set of unconverted rows in +-- `device_lists_changes_in_room` using the `converted_to_destinations` flag. When rows +-- were converted to `device_lists_outbound_pokes`, the `converted_to_destinations` flag +-- would be set. +-- +-- After this schema delta, the `converted_to_destinations` is still populated like +-- before, but the set of unconverted rows is determined by the `stream_id` in the new +-- `device_lists_changes_converted_stream_position` table. +-- +-- If rolled back, Synapse will re-send all device list changes that happened since the +-- schema delta. + +CREATE TABLE IF NOT EXISTS device_lists_changes_converted_stream_position( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + -- The (stream id, room id) of the last row in `device_lists_changes_in_room` that + -- has been converted to `device_lists_outbound_pokes`. Rows with a strictly larger + -- (stream id, room id) where `converted_to_destinations` is `FALSE` have not been + -- converted. + stream_id BIGINT NOT NULL, + -- `room_id` may be an empty string, which compares less than all valid room IDs. + room_id TEXT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO device_lists_changes_converted_stream_position (stream_id, room_id) VALUES ( + ( + SELECT COALESCE( + -- The last converted stream id is the smallest unconverted stream id minus + -- one. + MIN(stream_id) - 1, + -- If there is no unconverted stream id, the last converted stream id is the + -- largest stream id. + -- Otherwise, pick 1, since stream ids start at 2. + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room) + ) FROM device_lists_changes_in_room WHERE NOT converted_to_destinations + ), + '' +); diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f37505b6cf..8e7db2c4ec 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ class DeviceStoreTestCase(HomeserverTestCase): """ for device_id in device_ids: - stream_id = self.get_success( + self.get_success( self.store.add_device_change_to_streams( user_id, [device_id], ["!some:room"] ) @@ -39,7 +39,6 @@ class DeviceStoreTestCase(HomeserverTestCase): user_id=user_id, device_id=device_id, room_id="!some:room", - stream_id=stream_id, hosts=[host], context={}, ) -- cgit 1.5.1 From 6d47b7e32589e816eb766446cc1ff19ea73fc7c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 14:08:04 -0500 Subject: Add a type hint for `get_device_handler()` and fix incorrect types. (#14055) This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places. --- changelog.d/14055.misc | 1 + synapse/handlers/deactivate_account.py | 4 +++ synapse/handlers/device.py | 65 ++++++++++++++++++++++++++-------- synapse/handlers/e2e_keys.py | 61 ++++++++++++++++--------------- synapse/handlers/register.py | 4 +++ synapse/handlers/set_password.py | 6 +++- synapse/handlers/sso.py | 9 +++++ synapse/module_api/__init__.py | 10 +++++- synapse/replication/http/devices.py | 11 ++++-- synapse/rest/admin/__init__.py | 26 ++++++++------ synapse/rest/admin/devices.py | 13 +++++-- synapse/rest/client/devices.py | 17 ++++++--- synapse/rest/client/logout.py | 9 +++-- synapse/server.py | 2 +- tests/handlers/test_device.py | 19 ++++++---- tests/rest/admin/test_device.py | 5 ++- 16 files changed, 185 insertions(+), 77 deletions(-) create mode 100644 changelog.d/14055.misc (limited to 'synapse') diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc new file mode 100644 index 0000000000..02980bc528 --- /dev/null +++ b/changelog.d/14055.misc @@ -0,0 +1 @@ +Add missing type hints to `HomeServer`. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 816e1a6d79..d74d135c0c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError +from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -76,6 +77,9 @@ class DeactivateAccountHandler: True if identity server supports removing threepids, otherwise False. """ + # This can only be called on the main process. + assert isinstance(self._device_handler, DeviceHandler) + # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index da3ddafeae..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: + device_list_updater: "DeviceListWorkerUpdater" + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs @@ -76,6 +78,8 @@ class DeviceWorkerHandler: self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self.device_list_updater = DeviceListWorkerUpdater(hs) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -99,6 +103,19 @@ class DeviceWorkerHandler: log_kv(device_map) return devices + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: """Retrieve the given device @@ -127,7 +144,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken - ) -> Collection[str]: + ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. """ @@ -320,6 +337,8 @@ class DeviceWorkerHandler: class DeviceHandler(DeviceWorkerHandler): + device_list_updater: "DeviceListUpdater" + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, [old_device_id]) return device_id - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[Tuple[str, JsonDict]]: - """Retrieve the information for a dehydrated device. - - Args: - user_id: the user whose dehydrated device we are looking for - Returns: - a tuple whose first item is the device ID, and the second item is - the dehydrated device information - """ - return await self.store.get_dehydrated_device(user_id) - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: @@ -882,7 +888,36 @@ def _update_device_from_client_ips( ) -class DeviceListUpdater: +class DeviceListWorkerUpdater: + "Handles incoming device list updates from federation and contacts the main process over replication" + + def __init__(self, hs: "HomeServer"): + from synapse.replication.http.devices import ( + ReplicationUserDevicesResyncRestServlet, + ) + + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[JsonDict]: + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale + if the attempt to resync failed. + Returns: + A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + return await self._user_device_resync_client(user_id=user_id) + + +class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index bf1221f523..5fe102e2f2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -27,9 +27,9 @@ from twisted.internet import defer from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( JsonDict, UserID, @@ -56,27 +56,23 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() - self._edu_updater = SigningKeyEduUpdater(hs, self) - federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker.worker_app is None - if not self._is_master: - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) - else: + is_master = hs.config.worker.worker_app is None + if is_master: + edu_updater = SigningKeyEduUpdater(hs) + # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( EduTypes.SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the @@ -319,14 +315,13 @@ class E2eKeysHandler: # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - if self._is_master: - resync_results = await self.device_handler.device_list_updater.user_device_resync( + resync_results = ( + await self.device_handler.device_list_updater.user_device_resync( user_id ) - else: - resync_results = await self._user_device_resync_client( - user_id=user_id - ) + ) + if resync_results is None: + raise ValueError("Device resync failed") # Add the device keys to the results. user_devices = resync_results["devices"] @@ -605,6 +600,8 @@ class E2eKeysHandler: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) time_now = self.clock.time_msec() @@ -732,6 +729,8 @@ class E2eKeysHandler: user_id: the user uploading the keys keys: the signing keys """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) # if a master key is uploaded, then check it. Otherwise, load the # stored master key, to check signatures on other keys @@ -823,6 +822,9 @@ class E2eKeysHandler: Raises: SynapseError: if the signatures dict is not valid. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -1200,6 +1202,9 @@ class E2eKeysHandler: A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() @@ -1396,11 +1401,14 @@ class SignatureListItem: class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() - self.e2e_keys_handler = e2e_keys_handler + + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler self._remote_edu_linearizer = Linearizer(name="remote_signing_key") @@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater: user_id: the user whose updates we are processing """ - device_handler = self.e2e_keys_handler.device_handler - device_list_updater = device_handler.device_list_updater - async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: @@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = ( - await device_list_updater.process_cross_signing_key_update( - user_id, - master_key, - self_signing_key, - ) + new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + new_device_ids - await device_handler.notify_device_update(user_id, device_ids) + await self._device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1c7a1866..6307fa9c5d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -841,6 +842,9 @@ class RegistrationHandler: refresh_token = None refresh_token_id = None + # This can only run on the main process. + assert isinstance(self.device_handler, DeviceHandler) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 73861bbd40..bd9d0bb34b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -29,7 +30,10 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + # This can only be instantiated on the main process. + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler async def set_password( self, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 749d7e93b0..e1c0bff1b2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -1035,6 +1036,8 @@ class SsoHandler: ) -> None: """Revoke any devices and in-flight logins tied to a provider session. + Can only be called from the main process. + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1042,6 +1045,12 @@ class SsoHandler: expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ + + # It is expected that this is the main process. + assert isinstance( + self._device_handler, DeviceHandler + ), "revoking SSO sessions can only be called on the main process" + # Invalidate any running user-mapping sessions to_delete = [] for session_id, session in self._username_mapping_sessions.items(): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1adc1fd64f..96a661177a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,6 +86,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( @@ -207,6 +208,7 @@ class ModuleApi: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -784,6 +786,8 @@ class ModuleApi: ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user + Can only be called from the main process. + Added in Synapse v0.25.0. Args: @@ -796,6 +800,10 @@ class ModuleApi: Raises: synapse.api.errors.AuthError: the access token is invalid """ + assert isinstance( + self._device_handler, DeviceHandler + ), "invalidate_access_token can only be called on the main process" + # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( self._auth.get_user_by_access_token(access_token) @@ -805,7 +813,7 @@ class ModuleApi: if device_id: # delete the device, which will also delete its access tokens yield defer.ensureDeferred( - self._hs.get_device_handler().delete_devices(user_id, [device_id]) + self._device_handler.delete_devices(user_id, [device_id]) ) else: # no associated device. Just delete the access token. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index c21629def8..7c4941c3d3 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.device_list_updater = hs.get_device_handler().device_list_updater + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c62ea22116..fb73886df0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ + # Admin servlets aren't registered on workers. + if hs.config.worker.worker_app is not None: + return + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeleteDevicesRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) @@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) - # Some servlets only get registered for the main process. - if hs.config.worker.worker_app is None: - SendServerNoticeServlet(hs).register(http_server) - BackgroundUpdateEnabledRestServlet(hs).register(http_server) - BackgroundUpdateRestServlet(hs).register(http_server) - BackgroundUpdateStartJobRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) + SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + # The following resources can only be run on the main process. + if hs.config.worker.worker_app is None: + DeactivateAccountRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d934880102..3b2f2d9abb 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8f3cbd4ea2..69b803f9f8 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): @@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler class PostBody(RequestBodyModel): device_id: StrictStr diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 23dfa4518f..6d34625ad5 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) diff --git a/synapse/server.py b/synapse/server.py index f0a60d0056..5baae2325e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta): ) @cache_in_self - def get_device_handler(self): + def get_device_handler(self) -> DeviceWorkerHandler: if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,7 @@ from typing import Optional from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.util import Clock @@ -32,7 +32,9 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.store = hs.get_datastores().main return hs @@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_is_preserved_if_exists(self) -> None: @@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res2, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_id_is_made_up_if_unspecified(self) -> None: @@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) + assert dev is not None self.assertEqual(dev["display_name"], "display") def test_get_devices_by_user(self) -> None: @@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ) ) - retrieved_device_id, device_data = self.get_success( - self.handler.get_dehydrated_device(user_id=user_id) - ) + result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + assert result is not None + retrieved_device_id, device_data = result self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index d52aee8f92..03f2112b07 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.handlers.device import DeviceHandler from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") -- cgit 1.5.1 From df390a8e676f514f3deecdcc2d12a6cc6b9e8e1d Mon Sep 17 00:00:00 2001 From: realtyem Date: Tue, 22 Nov 2022 15:33:58 -0600 Subject: Refactor `federation_sender` and `pusher` configuration loading. (#14496) To avoid duplicating the same logic for handling legacy configuration settings. This should help in applying similar logic to other worker types. --- changelog.d/14496.misc | 1 + synapse/config/workers.py | 139 +++++++++++++++++++++++----------------------- 2 files changed, 71 insertions(+), 69 deletions(-) create mode 100644 changelog.d/14496.misc (limited to 'synapse') diff --git a/changelog.d/14496.misc b/changelog.d/14496.misc new file mode 100644 index 0000000000..57fc6cf452 --- /dev/null +++ b/changelog.d/14496.misc @@ -0,0 +1 @@ +Refactor `federation_sender` and `pusher` configuration loading. diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 913b83e174..2580660b6c 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -29,20 +29,6 @@ from ._base import ( ) from .server import DIRECT_TCP_ERROR, ListenerConfig, parse_listener_def -_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """ -The send_federation config option must be disabled in the main -synapse process before they can be run in a separate worker. - -Please add ``send_federation: false`` to the main config -""" - -_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """ -The start_pushers config option must be disabled in the main -synapse process before they can be run in a separate worker. - -Please add ``start_pushers: false`` to the main config -""" - _DEPRECATED_WORKER_DUTY_OPTION_USED = """ The '%s' configuration option is deprecated and will be removed in a future Synapse version. Please use ``%s: name_of_worker`` instead. @@ -182,40 +168,12 @@ class WorkerConfig(Config): ) ) - # Handle federation sender configuration. - # - # There are two ways of configuring which instances handle federation - # sending: - # 1. The old way where "send_federation" is set to false and running a - # `synapse.app.federation_sender` worker app. - # 2. Specifying the workers sending federation in - # `federation_sender_instances`. - # - - send_federation = config.get("send_federation", True) - - federation_sender_instances = config.get("federation_sender_instances") - if federation_sender_instances is None: - # Default to an empty list, which means "another, unknown, worker is - # responsible for it". - federation_sender_instances = [] - - # If no federation sender instances are set we check if - # `send_federation` is set, which means use master - if send_federation: - federation_sender_instances = ["master"] - - if self.worker_app == "synapse.app.federation_sender": - if send_federation: - # If we're running federation senders, and not using - # `federation_sender_instances`, then we should have - # explicitly set `send_federation` to false. - raise ConfigError( - _FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR - ) - - federation_sender_instances = [self.worker_name] - + federation_sender_instances = self._worker_names_performing_this_duty( + config, + "send_federation", + "synapse.app.federation_sender", + "federation_sender_instances", + ) self.send_federation = self.instance_name in federation_sender_instances self.federation_shard_config = ShardedWorkerHandlingConfig( federation_sender_instances @@ -282,27 +240,12 @@ class WorkerConfig(Config): ) # Handle sharded push - start_pushers = config.get("start_pushers", True) - pusher_instances = config.get("pusher_instances") - if pusher_instances is None: - # Default to an empty list, which means "another, unknown, worker is - # responsible for it". - pusher_instances = [] - - # If no pushers instances are set we check if `start_pushers` is - # set, which means use master - if start_pushers: - pusher_instances = ["master"] - - if self.worker_app == "synapse.app.pusher": - if start_pushers: - # If we're running pushers, and not using - # `pusher_instances`, then we should have explicitly set - # `start_pushers` to false. - raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR) - - pusher_instances = [self.instance_name] - + pusher_instances = self._worker_names_performing_this_duty( + config, + "start_pushers", + "synapse.app.pusher", + "pusher_instances", + ) self.start_pushers = self.instance_name in pusher_instances self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) @@ -425,6 +368,64 @@ class WorkerConfig(Config): # (By this point, these are either the same value or only one is not None.) return bool(new_option_should_run_here or legacy_option_should_run_here) + def _worker_names_performing_this_duty( + self, + config: Dict[str, Any], + legacy_option_name: str, + legacy_app_name: str, + modern_instance_list_name: str, + ) -> List[str]: + """ + Retrieves the names of the workers handling a given duty, by either legacy + option or instance list. + + There are two ways of configuring which instances handle a given duty, e.g. + for configuring pushers: + + 1. The old way where "start_pushers" is set to false and running a + `synapse.app.pusher'` worker app. + 2. Specifying the workers sending federation in `pusher_instances`. + + Args: + config: settings read from yaml. + legacy_option_name: the old way of enabling options. e.g. 'start_pushers' + legacy_app_name: The historical app name. e.g. 'synapse.app.pusher' + modern_instance_list_name: the string name of the new instance_list. e.g. + 'pusher_instances' + + Returns: + A list of worker instance names handling the given duty. + """ + + legacy_option = config.get(legacy_option_name, True) + + worker_instances = config.get(modern_instance_list_name) + if worker_instances is None: + # Default to an empty list, which means "another, unknown, worker is + # responsible for it". + worker_instances = [] + + # If no worker instances are set we check if the legacy option + # is set, which means use the main process. + if legacy_option: + worker_instances = ["master"] + + if self.worker_app == legacy_app_name: + if legacy_option: + # If we're using `legacy_app_name`, and not using + # `modern_instance_list_name`, then we should have + # explicitly set `legacy_option_name` to false. + raise ConfigError( + f"The '{legacy_option_name}' config option must be disabled in " + "the main synapse process before they can be run in a separate " + "worker.\n" + f"Please add `{legacy_option_name}: false` to the main config.\n", + ) + + worker_instances = [self.worker_name] + + return worker_instances + def read_arguments(self, args: argparse.Namespace) -> None: # We support a bunch of command line arguments that override options in # the config. A lot of these options have a worker_* prefix when running -- cgit 1.5.1 From 7f78b383ca666c7f49a99b6c5095becb4ed7f1f4 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 22 Nov 2022 15:56:28 -0600 Subject: Optimize `filter_events_for_client` for faster `/messages` - v2 (#14527) Fix #14108 --- changelog.d/14527.misc | 1 + synapse/storage/databases/state/bg_updates.py | 99 +++++++++++++++++++++------ 2 files changed, 80 insertions(+), 20 deletions(-) create mode 100644 changelog.d/14527.misc (limited to 'synapse') diff --git a/changelog.d/14527.misc b/changelog.d/14527.misc new file mode 100644 index 0000000000..3c4c7bf07d --- /dev/null +++ b/changelog.d/14527.misc @@ -0,0 +1 @@ +Speed-up `/messages` with `filter_events_for_client` optimizations. diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index a7fcc564a9..4a4ad0f492 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -93,13 +93,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is # a temporary hack until we can add the right indices in @@ -110,31 +103,91 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): # against `state_groups_state` to fetch the latest state. # It assumes that previous state groups are always numerically # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. # This may return multiple rows per (type, state_key), but last_value # should be the same. sql = """ - WITH RECURSIVE state(state_group) AS ( + WITH RECURSIVE sgs(state_group) AS ( VALUES(?::bigint) UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s + SELECT prev_state_group FROM state_group_edges e, sgs s WHERE s.state_group = e.state_group ) - SELECT DISTINCT ON (type, state_key) - type, state_key, event_id - FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) %s - ORDER BY type, state_key, state_group DESC + %s """ + overall_select_query_args: List[Union[int, str]] = [] + + # This is an optimization to create a select clause per-condition. This + # makes the query planner a lot smarter on what rows should pull out in the + # first place and we end up with something that takes 10x less time to get a + # result. + use_condition_optimization = ( + not state_filter.include_others and not state_filter.is_full() + ) + state_filter_condition_combos: List[Tuple[str, Optional[str]]] = [] + # We don't need to caclculate this list if we're not using the condition + # optimization + if use_condition_optimization: + for etype, state_keys in state_filter.types.items(): + if state_keys is None: + state_filter_condition_combos.append((etype, None)) + else: + for state_key in state_keys: + state_filter_condition_combos.append((etype, state_key)) + # And here is the optimization itself. We don't want to do the optimization + # if there are too many individual conditions. 10 is an arbitrary number + # with no testing behind it but we do know that we specifically made this + # optimization for when we grab the necessary state out for + # `filter_events_for_client` which just uses 2 conditions + # (`EventTypes.RoomHistoryVisibility` and `EventTypes.Member`). + if use_condition_optimization and len(state_filter_condition_combos) < 10: + select_clause_list: List[str] = [] + for etype, skey in state_filter_condition_combos: + if skey is None: + where_clause = "(type = ?)" + overall_select_query_args.extend([etype]) + else: + where_clause = "(type = ? AND state_key = ?)" + overall_select_query_args.extend([etype, skey]) + + select_clause_list.append( + f""" + ( + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + INNER JOIN sgs USING (state_group) + WHERE {where_clause} + ORDER BY type, state_key, state_group DESC + ) + """ + ) + + overall_select_clause = " UNION ".join(select_clause_list) + else: + where_clause, where_args = state_filter.make_sql_filter_clause() + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + overall_select_query_args.extend(where_args) + + overall_select_clause = f""" + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM sgs + ) {where_clause} + ORDER BY type, state_key, state_group DESC + """ + for group in groups: args: List[Union[int, str]] = [group] - args.extend(where_args) + args.extend(overall_select_query_args) - txn.execute(sql % (where_clause,), args) + txn.execute(sql % (overall_select_clause,), args) for row in txn: typ, state_key, event_id = row key = (intern_string(typ), intern_string(state_key)) @@ -142,6 +195,12 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): else: max_entries_returned = state_filter.max_entries_returned() + where_clause, where_args = state_filter.make_sql_filter_clause() + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: -- cgit 1.5.1 From f38d7d79c8ec5c389c51327737bd517a27826bd6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Nov 2022 14:09:00 +0000 Subject: Add another index to `device_lists_changes_in_room` (#14534) This helps avoid reading unnecessarily large amounts of data from the table when querying with a set of room IDs. --- changelog.d/14534.misc | 1 + synapse/storage/databases/main/devices.py | 7 +++++++ .../main/delta/73/13add_device_lists_index.sql | 20 ++++++++++++++++++++ 3 files changed, 28 insertions(+) create mode 100644 changelog.d/14534.misc create mode 100644 synapse/storage/schema/main/delta/73/13add_device_lists_index.sql (limited to 'synapse') diff --git a/changelog.d/14534.misc b/changelog.d/14534.misc new file mode 100644 index 0000000000..5fe79042e5 --- /dev/null +++ b/changelog.d/14534.misc @@ -0,0 +1 @@ +Improve DB performance by reducing amount of data that gets read in `device_lists_changes_in_room`. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 37629115ab..05a193f889 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1441,6 +1441,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): self._remove_duplicate_outbound_pokes, ) + self.db_pool.updates.register_background_index_update( + "device_lists_changes_in_room_by_room_index", + index_name="device_lists_changes_in_room_by_room_idx", + table="device_lists_changes_in_room", + columns=["room_id", "stream_id"], + ) + async def _drop_device_list_streams_non_unique_indexes( self, progress: JsonDict, batch_size: int ) -> int: diff --git a/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql b/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql new file mode 100644 index 0000000000..3725022a13 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/13add_device_lists_index.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Adds an index on `device_lists_changes_in_room (room_id, stream_id)`, which +-- speeds up `/sync` queries. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7313, 'device_lists_changes_in_room_by_room_index', '{}'); -- cgit 1.5.1 From 3b4e1508689cc09eba30509249459a64431558fc Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 24 Nov 2022 09:10:47 +0100 Subject: Faster joins: use servers list approximation in `assert_host_in_room` (#14515) Signed-off-by: Mathieu Velten --- changelog.d/14515.misc | 1 + synapse/handlers/event_auth.py | 28 +++++++++++++++++----------- 2 files changed, 18 insertions(+), 11 deletions(-) create mode 100644 changelog.d/14515.misc (limited to 'synapse') diff --git a/changelog.d/14515.misc b/changelog.d/14515.misc new file mode 100644 index 0000000000..a0effb4dbe --- /dev/null +++ b/changelog.d/14515.misc @@ -0,0 +1 @@ +Faster joins: use servers list approximation received during `send_join` (potentially updated with received membership events) in `assert_host_in_room`. \ No newline at end of file diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 3bbad0271b..f91dbbecb7 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -45,6 +45,7 @@ class EventAuthHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._store = hs.get_datastores().main + self._state_storage_controller = hs.get_storage_controllers().state self._server_name = hs.hostname async def check_auth_rules_from_context( @@ -179,17 +180,22 @@ class EventAuthHandler: this function may return an incorrect result as we are not able to fully track server membership in a room without full state. """ - if not allow_partial_state_rooms and await self._store.is_partial_state_room( - room_id - ): - raise AuthError( - 403, - "Unable to authorise you right now; room is partial-stated here.", - errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE, - ) - - if not await self.is_host_in_room(room_id, host): - raise AuthError(403, "Host not in room.") + if await self._store.is_partial_state_room(room_id): + if allow_partial_state_rooms: + current_hosts = await self._state_storage_controller.get_current_hosts_in_room_or_partial_state_approximation( + room_id + ) + if host not in current_hosts: + raise AuthError(403, "Host not in room (partial-state approx).") + else: + raise AuthError( + 403, + "Unable to authorise you right now; room is partial-stated here.", + errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE, + ) + else: + if not await self.is_host_in_room(room_id, host): + raise AuthError(403, "Host not in room.") async def check_restricted_join_rules( self, -- cgit 1.5.1 From 9af2be192a759c22d189b72cc0a7580cd9de8a37 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 24 Nov 2022 09:09:17 +0000 Subject: Remove legacy Prometheus metrics names. They were deprecated in Synapse v1.69.0 and disabled by default in Synapse v1.71.0. (#14538) --- changelog.d/14538.removal | 1 + docs/upgrade.md | 22 ++ docs/usage/configuration/config_documentation.md | 25 -- synapse/app/_base.py | 16 +- synapse/app/generic_worker.py | 1 - synapse/app/homeserver.py | 1 - synapse/config/metrics.py | 2 - synapse/metrics/__init__.py | 7 +- synapse/metrics/_legacy_exposition.py | 288 ----------------------- synapse/metrics/_twisted_exposition.py | 38 +++ tests/storage/test_event_metrics.py | 7 +- 11 files changed, 70 insertions(+), 338 deletions(-) create mode 100644 changelog.d/14538.removal delete mode 100644 synapse/metrics/_legacy_exposition.py create mode 100644 synapse/metrics/_twisted_exposition.py (limited to 'synapse') diff --git a/changelog.d/14538.removal b/changelog.d/14538.removal new file mode 100644 index 0000000000..d2035ce82a --- /dev/null +++ b/changelog.d/14538.removal @@ -0,0 +1 @@ +Remove legacy Prometheus metrics names. They were deprecated in Synapse v1.69.0 and disabled by default in Synapse v1.71.0. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 2aa353e496..4fe9e4f02e 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,28 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.73.0 + +## Legacy Prometheus metric names have now been removed + +Synapse v1.69.0 included the deprecation of legacy Prometheus metric names +and offered an option to disable them. +Synapse v1.71.0 disabled legacy Prometheus metric names by default. + +This version, v1.73.0, removes those legacy Prometheus metric names entirely. +This also means that the `enable_legacy_metrics` configuration option has been +removed; it will no longer be possible to re-enable the legacy metric names. + +If you use metrics and have not yet updated your Grafana dashboard(s), +Prometheus console(s) or alerting rule(s), please consider doing so when upgrading +to this version. +Note that the included Grafana dashboard was updated in v1.72.0 to correct some +metric names which were missed when legacy metrics were disabled by default. + +See [v1.69.0: Deprecation of legacy Prometheus metric names](#deprecation-of-legacy-prometheus-metric-names) +for more context. + + # Upgrading to v1.72.0 ## Dropping support for PostgreSQL 10 diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index f5937dd902..fae2771fad 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2437,31 +2437,6 @@ Example configuration: enable_metrics: true ``` --- -### `enable_legacy_metrics` - -Set to `true` to publish both legacy and non-legacy Prometheus metric names, -or to `false` to only publish non-legacy Prometheus metric names. -Defaults to `false`. Has no effect if `enable_metrics` is `false`. -**In Synapse v1.67.0 up to and including Synapse v1.70.1, this defaulted to `true`.** - -Legacy metric names include: -- metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules; -- counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard. - -These legacy metric names are unconventional and not compliant with OpenMetrics standards. -They are included for backwards compatibility. - -Example configuration: -```yaml -enable_legacy_metrics: false -``` - -See https://github.com/matrix-org/synapse/issues/11106 for context. - -*Since v1.67.0.* - -**Will be removed in v1.73.0.** ---- ### `sentry` Use this option to enable sentry integration. Provide the DSN assigned to you by sentry diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 41d2732ef9..a5aa2185a2 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -266,26 +266,18 @@ def register_start( reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics( - bind_addresses: Iterable[str], port: int, enable_legacy_metric_names: bool -) -> None: +def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: """ Start Prometheus metrics server. """ from prometheus_client import start_http_server as start_http_server_prometheus - from synapse.metrics import ( - RegistryProxy, - start_http_server as start_http_server_legacy, - ) + from synapse.metrics import RegistryProxy for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) - if enable_legacy_metric_names: - start_http_server_legacy(port, addr=host, registry=RegistryProxy) - else: - _set_prometheus_client_use_created_metrics(False) - start_http_server_prometheus(port, addr=host, registry=RegistryProxy) + _set_prometheus_client_use_created_metrics(False) + start_http_server_prometheus(port, addr=host, registry=RegistryProxy) def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 74909b7d4a..46dc731696 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -320,7 +320,6 @@ class GenericWorkerServer(HomeServer): _base.listen_metrics( listener.bind_addresses, listener.port, - enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, ) else: logger.warning("Unsupported listener type: %s", listener.type) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 4f4fee4782..b9be558c7e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -265,7 +265,6 @@ class SynapseHomeServer(HomeServer): _base.listen_metrics( listener.bind_addresses, listener.port, - enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, ) else: # this shouldn't happen, as the listener type should have been checked diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 6034a0346e..8c1c9bd12d 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -43,8 +43,6 @@ class MetricsConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_metrics = config.get("enable_metrics", False) - self.enable_legacy_metrics = config.get("enable_legacy_metrics", False) - self.report_stats = config.get("report_stats", None) self.report_stats_endpoint = config.get( "report_stats_endpoint", "https://matrix.org/report-usage-stats/push" diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index c3d3daf877..b01372565d 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -47,11 +47,7 @@ from twisted.python.threadpool import ThreadPool # This module is imported for its side effects; flake8 needn't warn that it's unused. import synapse.metrics._reactor_metrics # noqa: F401 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager -from synapse.metrics._legacy_exposition import ( - MetricsResource, - generate_latest, - start_http_server, -) +from synapse.metrics._twisted_exposition import MetricsResource, generate_latest from synapse.metrics._types import Collector from synapse.util import SYNAPSE_VERSION @@ -474,7 +470,6 @@ __all__ = [ "Collector", "MetricsResource", "generate_latest", - "start_http_server", "LaterGauge", "InFlightGauge", "GaugeBucketCollector", diff --git a/synapse/metrics/_legacy_exposition.py b/synapse/metrics/_legacy_exposition.py deleted file mode 100644 index 1459f9d224..0000000000 --- a/synapse/metrics/_legacy_exposition.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2015-2019 Prometheus Python Client Developers -# Copyright 2019 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This code is based off `prometheus_client/exposition.py` from version 0.7.1. - -Due to the renaming of metrics in prometheus_client 0.4.0, this customised -vendoring of the code will emit both the old versions that Synapse dashboards -expect, and the newer "best practice" version of the up-to-date official client. -""" -import logging -import math -import threading -from http.server import BaseHTTPRequestHandler, HTTPServer -from socketserver import ThreadingMixIn -from typing import Any, Dict, List, Type, Union -from urllib.parse import parse_qs, urlparse - -from prometheus_client import REGISTRY, CollectorRegistry -from prometheus_client.core import Sample - -from twisted.web.resource import Resource -from twisted.web.server import Request - -logger = logging.getLogger(__name__) -CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" - - -def floatToGoString(d: Union[int, float]) -> str: - d = float(d) - if d == math.inf: - return "+Inf" - elif d == -math.inf: - return "-Inf" - elif math.isnan(d): - return "NaN" - else: - s = repr(d) - dot = s.find(".") - # Go switches to exponents sooner than Python. - # We only need to care about positive values for le/quantile. - if d > 0 and dot > 6: - mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.") - return f"{mantissa}e+0{dot - 1}" - return s - - -def sample_line(line: Sample, name: str) -> str: - if line.labels: - labelstr = "{{{0}}}".format( - ",".join( - [ - '{}="{}"'.format( - k, - v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), - ) - for k, v in sorted(line.labels.items()) - ] - ) - ) - else: - labelstr = "" - timestamp = "" - if line.timestamp is not None: - # Convert to milliseconds. - timestamp = f" {int(float(line.timestamp) * 1000):d}" - return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) - - -# Mapping from new metric names to legacy metric names. -# We translate these back to their old names when exposing them through our -# legacy vendored exporter. -# Only this legacy exposition module applies these name changes. -LEGACY_METRIC_NAMES = { - "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits", - "synapse_util_caches_cache_size": "synapse_util_caches_cache:size", - "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size", - "synapse_util_caches_cache": "synapse_util_caches_cache:total", - "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size", - "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits", - "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size", - "synapse_util_caches_response_cache": "synapse_util_caches_response_cache:total", - "synapse_federation_client_sent_pdu_destinations": "synapse_federation_client_sent_pdu_destinations:total", - "synapse_federation_client_sent_pdu_destinations_count": "synapse_federation_client_sent_pdu_destinations:count", - "synapse_admin_mau_current": "synapse_admin_mau:current", - "synapse_admin_mau_max": "synapse_admin_mau:max", - "synapse_admin_mau_registered_reserved_users": "synapse_admin_mau:registered_reserved_users", -} - - -def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: - """ - Generate metrics in legacy format. Modern metrics are generated directly - by prometheus-client. - """ - - output = [] - - for metric in registry.collect(): - if not metric.samples: - # No samples, don't bother. - continue - - # Translate to legacy metric name if it has one. - mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name) - mnewname = metric.name - mtype = metric.type - - # OpenMetrics -> Prometheus - if mtype == "counter": - mnewname = mnewname + "_total" - elif mtype == "info": - mtype = "gauge" - mnewname = mnewname + "_info" - elif mtype == "stateset": - mtype = "gauge" - elif mtype == "gaugehistogram": - mtype = "histogram" - elif mtype == "unknown": - mtype = "untyped" - - # Output in the old format for compatibility. - if emit_help: - output.append( - "# HELP {} {}\n".format( - mname, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mname} {mtype}\n") - - om_samples: Dict[str, List[str]] = {} - for s in metric.samples: - for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == mname + suffix: - # OpenMetrics specific sample, put in a gauge at the end. - # (these come from gaugehistograms which don't get renamed, - # so no need to faff with mnewname) - om_samples.setdefault(suffix, []).append(sample_line(s, s.name)) - break - else: - newname = s.name.replace(mnewname, mname) - if ":" in newname and newname.endswith("_total"): - newname = newname[: -len("_total")] - output.append(sample_line(s, newname)) - - for suffix, lines in sorted(om_samples.items()): - if emit_help: - output.append( - "# HELP {}{} {}\n".format( - mname, - suffix, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mname}{suffix} gauge\n") - output.extend(lines) - - # Get rid of the weird colon things while we're at it - if mtype == "counter": - mnewname = mnewname.replace(":total", "") - mnewname = mnewname.replace(":", "_") - - if mname == mnewname: - continue - - # Also output in the new format, if it's different. - if emit_help: - output.append( - "# HELP {} {}\n".format( - mnewname, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mnewname} {mtype}\n") - - for s in metric.samples: - # Get rid of the OpenMetrics specific samples (we should already have - # dealt with them above anyway.) - for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == mname + suffix: - break - else: - sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name) - output.append( - sample_line(s, sample_name.replace(":total", "").replace(":", "_")) - ) - - return "".join(output).encode("utf-8") - - -class MetricsHandler(BaseHTTPRequestHandler): - """HTTP handler that gives metrics from ``REGISTRY``.""" - - registry = REGISTRY - - def do_GET(self) -> None: - registry = self.registry - params = parse_qs(urlparse(self.path).query) - - if "help" in params: - emit_help = True - else: - emit_help = False - - try: - output = generate_latest(registry, emit_help=emit_help) - except Exception: - self.send_error(500, "error generating metric output") - raise - try: - self.send_response(200) - self.send_header("Content-Type", CONTENT_TYPE_LATEST) - self.send_header("Content-Length", str(len(output))) - self.end_headers() - self.wfile.write(output) - except BrokenPipeError as e: - logger.warning( - "BrokenPipeError when serving metrics (%s). Did Prometheus restart?", e - ) - - def log_message(self, format: str, *args: Any) -> None: - """Log nothing.""" - - @classmethod - def factory(cls, registry: CollectorRegistry) -> Type: - """Returns a dynamic MetricsHandler class tied - to the passed registry. - """ - # This implementation relies on MetricsHandler.registry - # (defined above and defaulted to REGISTRY). - - # As we have unicode_literals, we need to create a str() - # object for type(). - cls_name = str(cls.__name__) - MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry}) - return MyMetricsHandler - - -class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): - """Thread per request HTTP server.""" - - # Make worker threads "fire and forget". Beginning with Python 3.7 this - # prevents a memory leak because ``ThreadingMixIn`` starts to gather all - # non-daemon threads in a list in order to join on them at server close. - # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the - # same as Python 3.7's ``ThreadingHTTPServer``. - daemon_threads = True - - -def start_http_server( - port: int, addr: str = "", registry: CollectorRegistry = REGISTRY -) -> None: - """Starts an HTTP server for prometheus metrics as a daemon thread""" - CustomMetricsHandler = MetricsHandler.factory(registry) - httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) - t = threading.Thread(target=httpd.serve_forever) - t.daemon = True - t.start() - - -class MetricsResource(Resource): - """ - Twisted ``Resource`` that serves prometheus metrics. - """ - - isLeaf = True - - def __init__(self, registry: CollectorRegistry = REGISTRY): - self.registry = registry - - def render_GET(self, request: Request) -> bytes: - request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) - response = generate_latest(self.registry) - request.setHeader(b"Content-Length", str(len(response))) - return response diff --git a/synapse/metrics/_twisted_exposition.py b/synapse/metrics/_twisted_exposition.py new file mode 100644 index 0000000000..0abcd14953 --- /dev/null +++ b/synapse/metrics/_twisted_exposition.py @@ -0,0 +1,38 @@ +# Copyright 2015-2019 Prometheus Python Client Developers +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from prometheus_client import REGISTRY, CollectorRegistry, generate_latest + +from twisted.web.resource import Resource +from twisted.web.server import Request + +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" + + +class MetricsResource(Resource): + """ + Twisted ``Resource`` that serves prometheus metrics. + """ + + isLeaf = True + + def __init__(self, registry: CollectorRegistry = REGISTRY): + self.registry = registry + + def render_GET(self, request: Request) -> bytes: + request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) + response = generate_latest(self.registry) + request.setHeader(b"Content-Length", str(len(response))) + return response diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 088fbb247b..6f1135eef4 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from prometheus_client import generate_latest -from synapse.metrics import REGISTRY, generate_latest +from synapse.metrics import REGISTRY from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -53,8 +54,8 @@ class ExtremStatisticsTestCase(HomeserverTestCase): items = list( filter( - lambda x: b"synapse_forward_extremities_" in x, - generate_latest(REGISTRY, emit_help=False).split(b"\n"), + lambda x: b"synapse_forward_extremities_" in x and b"# HELP" not in x, + generate_latest(REGISTRY).split(b"\n"), ) ) -- cgit 1.5.1 From f6c74d1cb2ed966802b01a2b037f09ce7a842c18 Mon Sep 17 00:00:00 2001 From: Benjamin Kampmann Date: Thu, 24 Nov 2022 09:10:51 +0000 Subject: Implement message forward pagination from start when no from is given, fixes #12383 (#14149) Fixes https://github.com/matrix-org/synapse/issues/12383 --- changelog.d/14149.bugfix | 1 + synapse/handlers/pagination.py | 6 ++++++ synapse/streams/events.py | 13 +++++++++++++ tests/rest/admin/test_room.py | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+) create mode 100644 changelog.d/14149.bugfix (limited to 'synapse') diff --git a/changelog.d/14149.bugfix b/changelog.d/14149.bugfix new file mode 100644 index 0000000000..b31c658266 --- /dev/null +++ b/changelog.d/14149.bugfix @@ -0,0 +1 @@ +Fix #12383: paginate room messages from the start if no from is given. Contributed by @gnunicorn . \ No newline at end of file diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a4ca9cb8b4..c572508a02 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -448,6 +448,12 @@ class PaginationHandler: if pagin_config.from_token: from_token = pagin_config.from_token + elif pagin_config.direction == "f": + from_token = ( + await self.hs.get_event_sources().get_start_token_for_pagination( + room_id + ) + ) else: from_token = ( await self.hs.get_event_sources().get_current_token_for_pagination( diff --git a/synapse/streams/events.py b/synapse/streams/events.py index f331e1af16..619eb7f601 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -73,6 +73,19 @@ class EventSources: ) return token + @trace + async def get_start_token_for_pagination(self, room_id: str) -> StreamToken: + """Get the start token for a given room to be used to paginate + events. + + The returned token does not have the current values for fields other + than `room`, since they are not used during pagination. + + Returns: + The start token for pagination. + """ + return StreamToken.START + @trace async def get_current_token_for_pagination(self, room_id: str) -> StreamToken: """Get the current token for a given room to be used to paginate diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d156be82b0..e0f5d54aba 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1857,6 +1857,46 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertIn("end", channel.json_body) + def test_room_messages_backward(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + latest_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?dir=b" % (self.room_id,), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 6, [event["content"] for event in chunk]) + + # in backwards, this is the first event + self.assertEqual(chunk[0]["event_id"], latest_event_id) + + def test_room_messages_forward(self) -> None: + """Test room messages can be retrieved by an admin that isn't in the room.""" + latest_event_id = self.helper.send( + self.room_id, body="message 1", tok=self.user_tok + )["event_id"] + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/messages?dir=f" % (self.room_id,), + access_token=self.admin_user_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 6, [event["content"] for event in chunk]) + + # in forward, this is the last event + self.assertEqual(chunk[5]["event_id"], latest_event_id) + def test_room_messages_purge(self) -> None: """Test room messages can be retrieved by an admin that isn't in the room.""" store = self.hs.get_datastores().main -- cgit 1.5.1 From c2e06c36d4ac2aef9de1a192cdcf9964415d09d2 Mon Sep 17 00:00:00 2001 From: schmop Date: Thu, 24 Nov 2022 11:49:04 +0100 Subject: Fix crash admin media list api when info is None (#14537) Fixes https://github.com/matrix-org/synapse/issues/14536 --- changelog.d/14537.bugfix | 1 + synapse/storage/databases/main/room.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14537.bugfix (limited to 'synapse') diff --git a/changelog.d/14537.bugfix b/changelog.d/14537.bugfix new file mode 100644 index 0000000000..d7ce78d032 --- /dev/null +++ b/changelog.d/14537.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the [List media admin API](https://matrix-org.github.io/synapse/latest/admin_api/media_admin_api.html#list-all-media-in-a-room) would fail when processing an image with broken thumbnail information. \ No newline at end of file diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4fbaefad73..52ad947c6c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -912,7 +912,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): event_json = db_to_json(content_json) content = event_json["content"] content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") + info = content.get("info") + if isinstance(info, dict): + thumbnail_url = info.get("thumbnail_url") + else: + thumbnail_url = None for url in (content_url, thumbnail_url): if not url: -- cgit 1.5.1 From 39cde585bf1e6cf3d32af9302437b37bae7a64b8 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 24 Nov 2022 18:09:47 +0100 Subject: Faster joins: use initial list of servers if we don't have the full state yet (#14408) Signed-off-by: Mathieu Velten Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14408.misc | 1 + synapse/federation/sender/__init__.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14408.misc (limited to 'synapse') diff --git a/changelog.d/14408.misc b/changelog.d/14408.misc new file mode 100644 index 0000000000..2c77d97591 --- /dev/null +++ b/changelog.d/14408.misc @@ -0,0 +1 @@ +Faster joins: send events to initial list of servers if we don't have the full state yet. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 3ad483efe0..fc1d8c88a7 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -434,7 +434,23 @@ class FederationSender(AbstractFederationSender): # If there are no prev event IDs then the state is empty # and so no remote servers in the room destinations = set() - else: + + if destinations is None: + # During partial join we use the set of servers that we got + # when beginning the join. It's still possible that we send + # events to servers that left the room in the meantime, but + # we consider that an acceptable risk since it is only our own + # events that we leak and not other server's ones. + partial_state_destinations = ( + await self.store.get_partial_state_servers_at_join( + event.room_id + ) + ) + + if len(partial_state_destinations) > 0: + destinations = partial_state_destinations + + if destinations is None: # We check the external cache for the destinations, which is # stored per state group. -- cgit 1.5.1 From 09de2aecb05cb46e0513396e2675b24c8beedb68 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Fri, 25 Nov 2022 19:16:50 +0400 Subject: Add support for handling avatar with SSO login (#13917) This commit adds support for handling a provided avatar picture URL when logging in via SSO. Signed-off-by: Ashish Kumar Fixes #9357. --- changelog.d/13917.feature | 1 + docs/usage/configuration/config_documentation.md | 9 +- mypy.ini | 4 +- synapse/handlers/oidc.py | 7 ++ synapse/handlers/sso.py | 111 +++++++++++++++++ tests/handlers/test_sso.py | 145 +++++++++++++++++++++++ 6 files changed, 275 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13917.feature create mode 100644 tests/handlers/test_sso.py (limited to 'synapse') diff --git a/changelog.d/13917.feature b/changelog.d/13917.feature new file mode 100644 index 0000000000..4eb942ab38 --- /dev/null +++ b/changelog.d/13917.feature @@ -0,0 +1 @@ +Adds support for handling avatar in SSO login. Contributed by @ashfame. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index fae2771fad..749af12aac 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2968,10 +2968,17 @@ Options for each entry include: For the default provider, the following settings are available: - * subject_claim: name of the claim containing a unique identifier + * `subject_claim`: name of the claim containing a unique identifier for the user. Defaults to 'sub', which OpenID Connect compliant providers should provide. + * `picture_claim`: name of the claim containing an url for the user's profile picture. + Defaults to 'picture', which OpenID Connect compliant providers should provide + and has to refer to a direct image file such as PNG, JPEG, or GIF image file. + + Currently only supported in monolithic (single-process) server configurations + where the media repository runs within the Synapse process. + * `localpart_template`: Jinja2 template for the localpart of the MXID. If this is not set, the user will be prompted to choose their own username (see the documentation for the `sso_auth_account_details.html` diff --git a/mypy.ini b/mypy.ini index 25b3c93748..0b6e7df267 100644 --- a/mypy.ini +++ b/mypy.ini @@ -119,6 +119,9 @@ disallow_untyped_defs = True [mypy-tests.storage.test_profile] disallow_untyped_defs = True +[mypy-tests.handlers.test_sso] +disallow_untyped_defs = True + [mypy-tests.storage.test_user_directory] disallow_untyped_defs = True @@ -137,7 +140,6 @@ disallow_untyped_defs = False [mypy-tests.utils] disallow_untyped_defs = True - ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. ;; The `typeshed` project maintains stubs here: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 41c675f408..03de6a4ba6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1435,6 +1435,7 @@ class UserAttributeDict(TypedDict): localpart: Optional[str] confirm_localpart: bool display_name: Optional[str] + picture: Optional[str] # may be omitted by older `OidcMappingProviders` emails: List[str] @@ -1520,6 +1521,7 @@ env.filters.update( @attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: subject_claim: str + picture_claim: str localpart_template: Optional[Template] display_name_template: Optional[Template] email_template: Optional[Template] @@ -1539,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @staticmethod def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") + picture_claim = config.get("picture_claim", "picture") def parse_template_config(option_name: str) -> Optional[Template]: if option_name not in config: @@ -1572,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): return JinjaOidcMappingConfig( subject_claim=subject_claim, + picture_claim=picture_claim, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, @@ -1611,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) + picture = userinfo.get("picture") + return UserAttributeDict( localpart=localpart, display_name=display_name, emails=emails, + picture=picture, confirm_localpart=self._config.confirm_localpart, ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e1c0bff1b2..44e70fc4b8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import hashlib +import io import logging from typing import ( TYPE_CHECKING, @@ -138,6 +140,7 @@ class UserAttributes: localpart: Optional[str] confirm_localpart: bool = False display_name: Optional[str] = None + picture: Optional[str] = None emails: Collection[str] = attr.Factory(list) @@ -196,6 +199,10 @@ class SsoHandler: self._error_template = hs.config.sso.sso_error_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() + self._media_repo = ( + hs.get_media_repository() if hs.config.media.can_load_media_repo else None + ) + self._http_client = hs.get_proxied_blacklisted_http_client() # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. @@ -495,6 +502,8 @@ class SsoHandler: await self._profile_handler.set_displayname( user_id_obj, requester, attributes.display_name, True ) + if attributes.picture: + await self.set_avatar(user_id, attributes.picture) await self._auth_handler.complete_sso_login( user_id, @@ -703,8 +712,110 @@ class SsoHandler: await self._store.record_user_external_id( auth_provider_id, remote_user_id, registered_user_id ) + + # Set avatar, if available + if attributes.picture: + await self.set_avatar(registered_user_id, attributes.picture) + return registered_user_id + async def set_avatar(self, user_id: str, picture_https_url: str) -> bool: + """Set avatar of the user. + + This downloads the image file from the URL provided, stores that in + the media repository and then sets the avatar on the user's profile. + + It can detect if the same image is being saved again and bails early by storing + the hash of the file in the `upload_name` of the avatar image. + + Currently, it only supports server configurations which run the media repository + within the same process. + + It silently fails and logs a warning by raising an exception and catching it + internally if: + * it is unable to fetch the image itself (non 200 status code) or + * the image supplied is bigger than max allowed size or + * the image type is not one of the allowed image types. + + Args: + user_id: matrix user ID in the form @localpart:domain as a string. + + picture_https_url: HTTPS url for the picture image file. + + Returns: `True` if the user's avatar has been successfully set to the image at + `picture_https_url`. + """ + if self._media_repo is None: + logger.info( + "failed to set user avatar because out-of-process media repositories " + "are not supported yet " + ) + return False + + try: + uid = UserID.from_string(user_id) + + def is_allowed_mime_type(content_type: str) -> bool: + if ( + self._profile_handler.allowed_avatar_mimetypes + and content_type + not in self._profile_handler.allowed_avatar_mimetypes + ): + return False + return True + + # download picture, enforcing size limit & mime type check + picture = io.BytesIO() + + content_length, headers, uri, code = await self._http_client.get_file( + url=picture_https_url, + output_stream=picture, + max_size=self._profile_handler.max_avatar_size, + is_allowed_content_type=is_allowed_mime_type, + ) + + if code != 200: + raise Exception( + "GET request to download sso avatar image returned {}".format(code) + ) + + # upload name includes hash of the image file's content so that we can + # easily check if it requires an update or not, the next time user logs in + upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest() + + # bail if user already has the same avatar + profile = await self._profile_handler.get_profile(user_id) + if profile["avatar_url"] is not None: + server_name = profile["avatar_url"].split("/")[-2] + media_id = profile["avatar_url"].split("/")[-1] + if server_name == self._server_name: + media = await self._media_repo.store.get_local_media(media_id) + if media is not None and upload_name == media["upload_name"]: + logger.info("skipping saving the user avatar") + return True + + # store it in media repository + avatar_mxc_url = await self._media_repo.create_content( + media_type=headers[b"Content-Type"][0].decode("utf-8"), + upload_name=upload_name, + content=picture, + content_length=content_length, + auth_user=uid, + ) + + # save it as user avatar + await self._profile_handler.set_avatar_url( + uid, + create_requester(uid), + str(avatar_mxc_url), + ) + + logger.info("successfully saved the user avatar") + return True + except Exception: + logger.warning("failed to save the user avatar") + return False + async def complete_sso_ui_auth_request( self, auth_provider_id: str, diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py new file mode 100644 index 0000000000..137deab138 --- /dev/null +++ b/tests/handlers/test_sso.py @@ -0,0 +1,145 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from http import HTTPStatus +from typing import BinaryIO, Callable, Dict, List, Optional, Tuple +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.http_headers import Headers + +from synapse.api.errors import Codes, SynapseError +from synapse.http.client import RawHeaders +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.test_utils import SMALL_PNG, FakeResponse + + +class TestSSOHandler(unittest.HomeserverTestCase): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock(spec=["get_file"]) + self.http_client.get_file.side_effect = mock_get_file + self.http_client.user_agent = b"Synapse Test" + hs = self.setup_test_homeserver( + proxied_blacklisted_http_client=self.http_client + ) + return hs + + async def test_set_avatar(self) -> None: + """Tests successfully setting the avatar of a newly created user""" + handler = self.hs.get_sso_handler() + + # Create a new user to set avatar for + reg_handler = self.hs.get_registration_handler() + user_id = self.get_success(reg_handler.register_user(approved=True)) + + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # Ensure avatar is set on this newly created user, + # so no need to compare for the exact image + profile_handler = self.hs.get_profile_handler() + profile = self.get_success(profile_handler.get_profile(user_id)) + self.assertIsNot(profile["avatar_url"], None) + + @unittest.override_config({"max_avatar_size": 1}) + async def test_set_avatar_too_big_image(self) -> None: + """Tests that saving an avatar fails when it is too big""" + handler = self.hs.get_sso_handler() + + # any random user works since image check is supposed to fail + user_id = "@sso-user:test" + + self.assertFalse( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + @unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]}) + async def test_set_avatar_incorrect_mime_type(self) -> None: + """Tests that saving an avatar fails when its mime type is not allowed""" + handler = self.hs.get_sso_handler() + + # any random user works since image check is supposed to fail + user_id = "@sso-user:test" + + self.assertFalse( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + async def test_skip_saving_avatar_when_not_changed(self) -> None: + """Tests whether saving of avatar correctly skips if the avatar hasn't + changed""" + handler = self.hs.get_sso_handler() + + # Create a new user to set avatar for + reg_handler = self.hs.get_registration_handler() + user_id = self.get_success(reg_handler.register_user(approved=True)) + + # set avatar for the first time, should be a success + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # get avatar picture for comparison after another attempt + profile_handler = self.hs.get_profile_handler() + profile = self.get_success(profile_handler.get_profile(user_id)) + url_to_match = profile["avatar_url"] + + # set same avatar for the second time, should be a success + self.assertTrue( + self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) + ) + + # compare avatar picture's url from previous step + profile = self.get_success(profile_handler.get_profile(user_id)) + self.assertEqual(profile["avatar_url"], url_to_match) + + +async def mock_get_file( + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + is_allowed_content_type: Optional[Callable[[str], bool]] = None, +) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: + + fake_response = FakeResponse(code=404) + if url == "http://my.server/me.png": + fake_response = FakeResponse( + code=200, + headers=Headers( + {"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]} + ), + body=SMALL_PNG, + ) + + if max_size is not None and max_size < len(SMALL_PNG): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + "Requested file is too large > %r bytes" % (max_size,), + Codes.TOO_LARGE, + ) + + if is_allowed_content_type and not is_allowed_content_type("image/png"): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + ( + "Requested file's content type not allowed for this operation: %s" + % "image/png" + ), + ) + + output_stream.write(fake_response.body) + + return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200 -- cgit 1.5.1 From f792dd74e1e6f64cb15d920d87818f47f17e7848 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 28 Nov 2022 13:42:06 +0000 Subject: Remove option to skip locking of tables during emulated upserts (#14469) To perform an emulated upsert into a table safely, we must either: * lock the table, * be the only writer upserting into the table * or rely on another unique index being present. When the 2nd or 3rd cases were applicable, we previously avoided locking the table as an optimization. However, as seen in #14406, it is easy to slip up when adding new schema deltas and corrupt the database. The only time we lock when performing emulated upserts is while waiting for background updates on postgres. On sqlite, we do no locking at all. Let's remove the option to skip locking tables, so that we don't shoot ourselves in the foot again. Signed-off-by: Sean Quah --- changelog.d/14469.misc | 1 + synapse/storage/database.py | 56 +++++++--------------- synapse/storage/databases/main/account_data.py | 8 ---- synapse/storage/databases/main/appservice.py | 2 - synapse/storage/databases/main/devices.py | 9 ---- synapse/storage/databases/main/event_federation.py | 1 - synapse/storage/databases/main/pusher.py | 6 --- synapse/storage/databases/main/room.py | 6 --- synapse/storage/databases/main/room_batch.py | 2 - synapse/storage/databases/main/user_directory.py | 2 - 10 files changed, 19 insertions(+), 74 deletions(-) create mode 100644 changelog.d/14469.misc (limited to 'synapse') diff --git a/changelog.d/14469.misc b/changelog.d/14469.misc new file mode 100644 index 0000000000..a12a21e9ae --- /dev/null +++ b/changelog.d/14469.misc @@ -0,0 +1 @@ +Remove option to skip locking of tables when performing emulated upserts, to avoid a class of bugs in future. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a14b13aec8..55bcb90001 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1129,7 +1129,6 @@ class DatabasePool: values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, desc: str = "simple_upsert", - lock: bool = True, ) -> bool: """Insert a row with values + insertion_values; on conflict, update with values. @@ -1154,21 +1153,12 @@ class DatabasePool: requiring that a unique index exist on the column names used to detect a conflict (i.e. `keyvalues.keys()`). - If there is no such index, we can "emulate" an upsert with a SELECT followed - by either an INSERT or an UPDATE. This is unsafe: we cannot make the same - atomicity guarantees that a native upsert can and are very vulnerable to races - and crashes. Therefore if we wish to upsert without an appropriate unique index, - we must either: - - 1. Acquire a table-level lock before the emulated upsert (`lock=True`), or - 2. VERY CAREFULLY ensure that we are the only thread and worker which will be - writing to this table, in which case we can proceed without a lock - (`lock=False`). - - Generally speaking, you should use `lock=True`. If the table in question has a - unique index[*], this class will use a native upsert (which is atomic and so can - ignore the `lock` argument). Otherwise this class will use an emulated upsert, - in which case we want the safer option unless we been VERY CAREFUL. + If there is no such index yet[*], we can "emulate" an upsert with a SELECT + followed by either an INSERT or an UPDATE. This is unsafe unless *all* upserters + run at the SERIALIZABLE isolation level: we cannot make the same atomicity + guarantees that a native upsert can and are very vulnerable to races and + crashes. Therefore to upsert without an appropriate unique index, we acquire a + table-level lock before the emulated upsert. [*]: Some tables have unique indices added to them in the background. Those tables `T` are keys in the dictionary UNIQUE_INDEX_BACKGROUND_UPDATES, @@ -1189,7 +1179,6 @@ class DatabasePool: values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting desc: description of the transaction, for logging and metrics - lock: True to lock the table when doing the upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is not empty then this always returns True) @@ -1209,7 +1198,6 @@ class DatabasePool: keyvalues, values, insertion_values, - lock=lock, db_autocommit=autocommit, ) except self.engine.module.IntegrityError as e: @@ -1232,7 +1220,6 @@ class DatabasePool: values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, where_clause: Optional[str] = None, - lock: bool = True, ) -> bool: """ Pick the UPSERT method which works best on the platform. Either the @@ -1245,8 +1232,6 @@ class DatabasePool: values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting where_clause: An index predicate to apply to the upsert. - lock: True to lock the table when doing the upsert. Unused when performing - a native upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is not empty then this always returns True) @@ -1270,7 +1255,6 @@ class DatabasePool: values, insertion_values=insertion_values, where_clause=where_clause, - lock=lock, ) def simple_upsert_txn_emulated( @@ -1291,14 +1275,15 @@ class DatabasePool: insertion_values: additional key/values to use only when inserting where_clause: An index predicate to apply to the upsert. lock: True to lock the table when doing the upsert. + Must not be False unless the table has already been locked. Returns: Returns True if a row was inserted or updated (i.e. if `values` is not empty then this always returns True) """ insertion_values = insertion_values or {} - # We need to lock the table :(, unless we're *really* careful if lock: + # We need to lock the table :( self.engine.lock_table(txn, table) def _getwhere(key: str) -> str: @@ -1406,7 +1391,6 @@ class DatabasePool: value_names: Collection[str], value_values: Collection[Collection[Any]], desc: str, - lock: bool = True, ) -> None: """ Upsert, many times. @@ -1418,8 +1402,6 @@ class DatabasePool: value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. - lock: True to lock the table when doing the upsert. Unused when performing - a native upsert. """ # We can autocommit if it safe to upsert @@ -1433,7 +1415,6 @@ class DatabasePool: key_values, value_names, value_values, - lock=lock, db_autocommit=autocommit, ) @@ -1445,7 +1426,6 @@ class DatabasePool: key_values: Collection[Iterable[Any]], value_names: Collection[str], value_values: Iterable[Iterable[Any]], - lock: bool = True, ) -> None: """ Upsert, many times. @@ -1457,8 +1437,6 @@ class DatabasePool: value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. - lock: True to lock the table when doing the upsert. Unused when performing - a native upsert. """ if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( @@ -1466,7 +1444,12 @@ class DatabasePool: ) else: return self.simple_upsert_many_txn_emulated( - txn, table, key_names, key_values, value_names, value_values, lock=lock + txn, + table, + key_names, + key_values, + value_names, + value_values, ) def simple_upsert_many_txn_emulated( @@ -1477,7 +1460,6 @@ class DatabasePool: key_values: Collection[Iterable[Any]], value_names: Collection[str], value_values: Iterable[Iterable[Any]], - lock: bool = True, ) -> None: """ Upsert, many times, but without native UPSERT support or batching. @@ -1489,18 +1471,16 @@ class DatabasePool: value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. - lock: True to lock the table when doing the upsert. """ # No value columns, therefore make a blank list so that the following # zip() works correctly. if not value_names: value_values = [() for x in range(len(key_values))] - if lock: - # Lock the table just once, to prevent it being done once per row. - # Note that, according to Postgres' documentation, once obtained, - # the lock is held for the remainder of the current transaction. - self.engine.lock_table(txn, "user_ips") + # Lock the table just once, to prevent it being done once per row. + # Note that, according to Postgres' documentation, once obtained, + # the lock is held for the remainder of the current transaction. + self.engine.lock_table(txn, "user_ips") for keyv, valv in zip(key_values, value_values): _keys = {x: y for x, y in zip(key_names, keyv)} diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 282687ebce..07908c41d9 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -449,9 +449,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) content_json = json_encoder.encode(content) async with self._account_data_id_gen.get_next() as next_id: - # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so simple_upsert will - # retry if there is a conflict. await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", @@ -461,7 +458,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) "account_data_type": account_data_type, }, values={"stream_id": next_id, "content": content_json}, - lock=False, ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) @@ -517,15 +513,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) -> None: content_json = json_encoder.encode(content) - # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so simple_upsert will retry if - # there is a conflict. self.db_pool.simple_upsert_txn( txn, table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, values={"stream_id": next_id, "content": content_json}, - lock=False, ) # Ignored users get denormalized into a separate table as an optimisation. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 63046c0527..25da0c56c5 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -451,8 +451,6 @@ class ApplicationServiceTransactionWorkerStore( table="application_services_state", keyvalues={"as_id": service.id}, values={f"{stream_type}_stream_id": pos}, - # no need to lock when emulating upsert: as_id is a unique key - lock=False, desc="set_appservice_stream_type_pos", ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 05a193f889..534f7fc04a 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1744,9 +1744,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, values={"content": json_encoder.encode(content)}, - # we don't need to lock, because we assume we are the only thread - # updating this user's devices. - lock=False, ) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) @@ -1760,9 +1757,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, values={"stream_id": stream_id}, - # again, we can assume we are the only thread updating this user's - # extremity. - lock=False, ) async def update_remote_device_list_cache( @@ -1815,9 +1809,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, values={"stream_id": stream_id}, - # we don't need to lock, because we can assume we are the only thread - # updating this user's extremity. - lock=False, ) async def add_device_change_to_streams( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 309a4ba664..bbee02ab18 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1686,7 +1686,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas }, insertion_values={}, desc="insert_insertion_extremity", - lock=False, ) async def insert_received_event_to_staging( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index fee37b9ce4..40fd781a6a 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -325,14 +325,11 @@ class PusherWorkerStore(SQLBaseStore): async def set_throttle_params( self, pusher_id: str, room_id: str, params: ThrottleParams ) -> None: - # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so simple_upsert will retry await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms}, desc="set_throttle_params", - lock=False, ) async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int: @@ -589,8 +586,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): device_id: Optional[str] = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: - # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -609,7 +604,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): "device_id": device_id, }, desc="add_pusher", - lock=False, ) user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 52ad947c6c..1309bfd374 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1847,9 +1847,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "creator": room_creator, "has_auth_chain_index": has_auth_chain_index, }, - # rooms has a unique constraint on room_id, so no need to lock when doing an - # emulated upsert. - lock=False, ) async def store_partial_state_room( @@ -1970,9 +1967,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "creator": "", "has_auth_chain_index": has_auth_chain_index, }, - # rooms has a unique constraint on room_id, so no need to lock when doing an - # emulated upsert. - lock=False, ) async def set_room_is_public(self, room_id: str, is_public: bool) -> None: diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index 39e80f6f5b..131f357d04 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -44,6 +44,4 @@ class RoomBatchStore(SQLBaseStore): table="event_to_state_groups", keyvalues={"event_id": event_id}, values={"state_group": state_group_id, "event_id": event_id}, - # Unique constraint on event_id so we don't have to lock - lock=False, ) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 698d6f7515..044435deab 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -481,7 +481,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): table="user_directory", keyvalues={"user_id": user_id}, values={"display_name": display_name, "avatar_url": avatar_url}, - lock=False, # We're only inserter ) if isinstance(self.database_engine, PostgresEngine): @@ -511,7 +510,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): table="user_directory_search", keyvalues={"user_id": user_id}, values={"value": value}, - lock=False, # We're only inserter ) else: # This should be unreachable. -- cgit 1.5.1 From d748bbc8f8268d2e8457374d529adafb20b9f5f4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Nov 2022 09:40:17 -0500 Subject: Include thread information when sending receipts over federation. (#14466) Include the thread_id field when sending read receipts over federation. This might result in the same user having multiple read receipts per-room, meaning multiple EDUs must be sent to encapsulate those receipts. This restructures the PerDestinationQueue APIs to support multiple receipt EDUs, queue_read_receipt now becomes linear time in the number of queued threaded receipts in the room for the given user, it is expected this is a small number since receipt EDUs are sent as filler in transactions. --- changelog.d/14466.bugfix | 1 + synapse/federation/sender/per_destination_queue.py | 183 ++++++++++++++------- synapse/handlers/receipts.py | 1 - tests/federation/test_federation_sender.py | 77 +++++++++ 4 files changed, 198 insertions(+), 64 deletions(-) create mode 100644 changelog.d/14466.bugfix (limited to 'synapse') diff --git a/changelog.d/14466.bugfix b/changelog.d/14466.bugfix new file mode 100644 index 0000000000..82f6e6b68e --- /dev/null +++ b/changelog.d/14466.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 where a receipt's thread ID was not sent over federation. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 3ae5e8634c..5af2784f1e 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -35,7 +35,7 @@ from synapse.logging import issue9533_logger from synapse.logging.opentracing import SynapseTags, set_tag from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import ReadReceipt +from synapse.types import JsonDict, ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.visibility import filter_events_for_server @@ -136,8 +136,11 @@ class PerDestinationQueue: # destination self._pending_presence: Dict[str, UserPresenceState] = {} - # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {} + # List of room_id -> receipt_type -> user_id -> receipt_dict, + # + # Each receipt can only have a single receipt per + # (room ID, receipt type, user ID, thread ID) tuple. + self._pending_receipt_edus: List[Dict[str, Dict[str, Dict[str, dict]]]] = [] self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. @@ -202,17 +205,53 @@ class PerDestinationQueue: Args: receipt: receipt to be queued """ - self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( - receipt.receipt_type, {} - )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} + serialized_receipt: JsonDict = { + "event_ids": receipt.event_ids, + "data": receipt.data, + } + if receipt.thread_id is not None: + serialized_receipt["data"]["thread_id"] = receipt.thread_id + + # Find which EDU to add this receipt to. There's three situations depending + # on the (room ID, receipt type, user, thread ID) tuple: + # + # 1. If it fully matches, clobber the information. + # 2. If it is missing, add the information. + # 3. If the subset tuple of (room ID, receipt type, user) matches, check + # the next EDU (or add a new EDU). + for edu in self._pending_receipt_edus: + receipt_content = edu.setdefault(receipt.room_id, {}).setdefault( + receipt.receipt_type, {} + ) + # If this room ID, receipt type, user ID is not in this EDU, OR if + # the full tuple matches, use the current EDU. + if ( + receipt.user_id not in receipt_content + or receipt_content[receipt.user_id].get("thread_id") + == receipt.thread_id + ): + receipt_content[receipt.user_id] = serialized_receipt + break + + # If no matching EDU was found, create a new one. + else: + self._pending_receipt_edus.append( + { + receipt.room_id: { + receipt.receipt_type: {receipt.user_id: serialized_receipt} + } + } + ) def flush_read_receipts_for_room(self, room_id: str) -> None: - # if we don't have any read-receipts for this room, it may be that we've already - # sent them out, so we don't need to flush. - if room_id not in self._pending_rrs: - return - self._rrs_pending_flush = True - self.attempt_new_transaction() + # If there are any pending receipts for this room then force-flush them + # in a new transaction. + for edu in self._pending_receipt_edus: + if room_id in edu: + self._rrs_pending_flush = True + self.attempt_new_transaction() + # No use in checking remaining EDUs if the room was found. + break def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu @@ -351,7 +390,7 @@ class PerDestinationQueue: self._pending_edus = [] self._pending_edus_keyed = {} self._pending_presence = {} - self._pending_rrs = {} + self._pending_receipt_edus = [] self._start_catching_up() except FederationDeniedError as e: @@ -543,22 +582,27 @@ class PerDestinationQueue: self._destination, last_successful_stream_ordering ) - def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: - if not self._pending_rrs: + def _get_receipt_edus(self, force_flush: bool, limit: int) -> Iterable[Edu]: + if not self._pending_receipt_edus: return if not force_flush and not self._rrs_pending_flush: # not yet time for this lot return - edu = Edu( - origin=self._server_name, - destination=self._destination, - edu_type=EduTypes.RECEIPT, - content=self._pending_rrs, - ) - self._pending_rrs = {} - self._rrs_pending_flush = False - yield edu + # Send at most limit EDUs for receipts. + for content in self._pending_receipt_edus[:limit]: + yield Edu( + origin=self._server_name, + destination=self._destination, + edu_type=EduTypes.RECEIPT, + content=content, + ) + self._pending_receipt_edus = self._pending_receipt_edus[limit:] + + # If there are still pending read-receipts, don't reset the pending flush + # flag. + if not self._pending_receipt_edus: + self._rrs_pending_flush = False def _pop_pending_edus(self, limit: int) -> List[Edu]: pending_edus = self._pending_edus @@ -645,27 +689,61 @@ class _TransactionQueueManager: async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: # First we calculate the EDUs we want to send, if any. - # We start by fetching device related EDUs, i.e device updates and to - # device messages. We have to keep 2 free slots for presence and rr_edus. - device_edu_limit = MAX_EDUS_PER_TRANSACTION - 2 + # There's a maximum number of EDUs that can be sent with a transaction, + # generally device updates and to-device messages get priority, but we + # want to ensure that there's room for some other EDUs as well. + # + # This is done by: + # + # * Add a presence EDU, if one exists. + # * Add up-to a small limit of read receipt EDUs. + # * Add to-device EDUs, but leave some space for device list updates. + # * Add device list updates EDUs. + # * If there's any remaining room, add other EDUs. + pending_edus = [] + + # Add presence EDU. + if self.queue._pending_presence: + pending_edus.append( + Edu( + origin=self.queue._server_name, + destination=self.queue._destination, + edu_type=EduTypes.PRESENCE, + content={ + "push": [ + format_user_presence_state( + presence, self.queue._clock.time_msec() + ) + for presence in self.queue._pending_presence.values() + ] + }, + ) + ) + self.queue._pending_presence = {} - # We prioritize to-device messages so that existing encryption channels + # Add read receipt EDUs. + pending_edus.extend(self.queue._get_receipt_edus(force_flush=False, limit=5)) + edu_limit = MAX_EDUS_PER_TRANSACTION - len(pending_edus) + + # Next, prioritize to-device messages so that existing encryption channels # work. We also keep a few slots spare (by reducing the limit) so that # we can still trickle out some device list updates. ( to_device_edus, device_stream_id, - ) = await self.queue._get_to_device_message_edus(device_edu_limit - 10) + ) = await self.queue._get_to_device_message_edus(edu_limit - 10) if to_device_edus: self._device_stream_id = device_stream_id else: self.queue._last_device_stream_id = device_stream_id - device_edu_limit -= len(to_device_edus) + pending_edus.extend(to_device_edus) + edu_limit -= len(to_device_edus) + # Add device list update EDUs. device_update_edus, dev_list_id = await self.queue._get_device_update_edus( - device_edu_limit + edu_limit ) if device_update_edus: @@ -673,40 +751,17 @@ class _TransactionQueueManager: else: self.queue._last_device_list_stream_id = dev_list_id - pending_edus = device_update_edus + to_device_edus - - # Now add the read receipt EDU. - pending_edus.extend(self.queue._get_rr_edus(force_flush=False)) - - # And presence EDU. - if self.queue._pending_presence: - pending_edus.append( - Edu( - origin=self.queue._server_name, - destination=self.queue._destination, - edu_type=EduTypes.PRESENCE, - content={ - "push": [ - format_user_presence_state( - presence, self.queue._clock.time_msec() - ) - for presence in self.queue._pending_presence.values() - ] - }, - ) - ) - self.queue._pending_presence = {} + pending_edus.extend(device_update_edus) + edu_limit -= len(device_update_edus) # Finally add any other types of EDUs if there is room. - pending_edus.extend( - self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) - ) - while ( - len(pending_edus) < MAX_EDUS_PER_TRANSACTION - and self.queue._pending_edus_keyed - ): + other_edus = self.queue._pop_pending_edus(edu_limit) + pending_edus.extend(other_edus) + edu_limit -= len(other_edus) + while edu_limit > 0 and self.queue._pending_edus_keyed: _, val = self.queue._pending_edus_keyed.popitem() pending_edus.append(val) + edu_limit -= 1 # Now we look for any PDUs to send, by getting up to 50 PDUs from the # queue @@ -717,8 +772,10 @@ class _TransactionQueueManager: # if we've decided to send a transaction anyway, and we have room, we # may as well send any pending RRs - if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: - pending_edus.extend(self.queue._get_rr_edus(force_flush=True)) + if edu_limit: + pending_edus.extend( + self.queue._get_receipt_edus(force_flush=True, limit=edu_limit) + ) if self._pdus: self._last_stream_ordering = self._pdus[ diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index ac01582442..6a4fed1156 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -92,7 +92,6 @@ class ReceiptsHandler: continue # Check if these receipts apply to a thread. - thread_id = None data = user_values.get("data", {}) thread_id = data.get("thread_id") # If the thread ID is invalid, consider it missing. diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index f1e357764f..01f147418b 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -83,6 +83,83 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ], ) + @override_config({"send_federation": True}) + def test_send_receipts_thread(self): + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) + mock_send_transaction.return_value = make_awaitable({}) + + # Create receipts for: + # + # * The same room / user on multiple threads. + # * A different user in the same room. + sender = self.hs.get_federation_sender() + for user, thread in ( + ("alice", None), + ("alice", "thread"), + ("bob", None), + ("bob", "diff-thread"), + ): + receipt = ReadReceipt( + "room_id", + "m.read", + user, + ["event_id"], + thread_id=thread, + data={"ts": 1234}, + ) + self.successResultOf( + defer.ensureDeferred(sender.send_read_receipt(receipt)) + ) + + self.pump() + + # expect a call to send_transaction with two EDUs to separate threads. + mock_send_transaction.assert_called_once() + json_cb = mock_send_transaction.call_args[0][1] + data = json_cb() + # Note that the ordering of the EDUs doesn't matter. + self.assertCountEqual( + data["edus"], + [ + { + "edu_type": EduTypes.RECEIPT, + "content": { + "room_id": { + "m.read": { + "alice": { + "event_ids": ["event_id"], + "data": {"ts": 1234, "thread_id": "thread"}, + }, + "bob": { + "event_ids": ["event_id"], + "data": {"ts": 1234, "thread_id": "diff-thread"}, + }, + } + } + }, + }, + { + "edu_type": EduTypes.RECEIPT, + "content": { + "room_id": { + "m.read": { + "alice": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, + }, + "bob": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, + }, + } + } + }, + }, + ], + ) + @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but -- cgit 1.5.1 From d56f48038a07fd76d2ce08220a4061f85006bf3b Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 28 Nov 2022 15:25:18 +0000 Subject: Fix logging context warnings due to common usage metrics setup (#14574) `setup()` is run under the sentinel context manager, so we wrap the initial update in a background process. Before this change, Synapse would log two warnings on startup: Starting db txn 'count_daily_users' from sentinel context Starting db connection from sentinel context: metrics will be lost Signed-off-by: Sean Quah --- changelog.d/14574.bugfix | 1 + synapse/metrics/common_usage_metrics.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14574.bugfix (limited to 'synapse') diff --git a/changelog.d/14574.bugfix b/changelog.d/14574.bugfix new file mode 100644 index 0000000000..fac85ec9b0 --- /dev/null +++ b/changelog.d/14574.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.67.0 where two logging context warnings would be logged on startup. diff --git a/synapse/metrics/common_usage_metrics.py b/synapse/metrics/common_usage_metrics.py index 0a22ea3d92..6e05b043d3 100644 --- a/synapse/metrics/common_usage_metrics.py +++ b/synapse/metrics/common_usage_metrics.py @@ -54,7 +54,9 @@ class CommonUsageMetricsManager: async def setup(self) -> None: """Keep the gauges for common usage metrics up to date.""" - await self._update_gauges() + run_as_background_process( + desc="common_usage_metrics_update_gauges", func=self._update_gauges + ) self._clock.looping_call( run_as_background_process, 5 * 60 * 1000, -- cgit 1.5.1 From 1183c372fa9da01b2667f1b83dab958dad432c68 Mon Sep 17 00:00:00 2001 From: Andrew Ferrazzutti Date: Mon, 28 Nov 2022 11:17:29 -0500 Subject: Use `device_one_time_keys_count` to match MSC3202 (#14565) * Use `device_one_time_keys_count` to match MSC3202 Rename the `device_one_time_key_counts` key in responses to `device_one_time_keys_count` to match the name specified by MSC3202. Also change related variable/class names for consistency. Signed-off-by: Andrew Ferrazzutti * Update changelog.d/14565.misc * Revert name change for `one_time_key_counts` key as this is a different key altogether from `device_one_time_keys_count`, which is used for `/sync` instead of appservice transactions. Signed-off-by: Andrew Ferrazzutti --- changelog.d/14565.misc | 1 + synapse/appservice/__init__.py | 10 +++++----- synapse/appservice/api.py | 11 +++++++---- synapse/appservice/scheduler.py | 16 ++++++++-------- synapse/handlers/sync.py | 6 +++--- synapse/storage/databases/main/appservice.py | 10 +++++----- synapse/storage/databases/main/end_to_end_keys.py | 8 ++++---- tests/appservice/test_scheduler.py | 6 +++--- tests/handlers/test_appservice.py | 4 ++-- 9 files changed, 38 insertions(+), 34 deletions(-) create mode 100644 changelog.d/14565.misc (limited to 'synapse') diff --git a/changelog.d/14565.misc b/changelog.d/14565.misc new file mode 100644 index 0000000000..19a62b036c --- /dev/null +++ b/changelog.d/14565.misc @@ -0,0 +1 @@ +In application service transactions that include the experimental `org.matrix.msc3202.device_one_time_key_counts` key, include a duplicate key of `org.matrix.msc3202.device_one_time_keys_count` to match the name proposed by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/blob/travis/msc/otk-dl-appservice/proposals/3202-encrypted-appservices.md). diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 500bdde3a9..bf4e6c629b 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -32,9 +32,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Type for the `device_one_time_key_counts` field in an appservice transaction +# Type for the `device_one_time_keys_count` field in an appservice transaction # user ID -> {device ID -> {algorithm -> count}} -TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] +TransactionOneTimeKeysCount = Dict[str, Dict[str, Dict[str, int]]] # Type for the `device_unused_fallback_key_types` field in an appservice transaction # user ID -> {device ID -> [algorithm]} @@ -376,7 +376,7 @@ class AppServiceTransaction: events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ): @@ -385,7 +385,7 @@ class AppServiceTransaction: self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages - self.one_time_key_counts = one_time_key_counts + self.one_time_keys_count = one_time_keys_count self.unused_fallback_keys = unused_fallback_keys self.device_list_summary = device_list_summary @@ -402,7 +402,7 @@ class AppServiceTransaction: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, - one_time_key_counts=self.one_time_key_counts, + one_time_keys_count=self.one_time_keys_count, unused_fallback_keys=self.unused_fallback_keys, device_list_summary=self.device_list_summary, txn_id=self.id, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 60774b240d..edafd433cd 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.appservice import ( ApplicationService, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.events import EventBase @@ -262,7 +262,7 @@ class ApplicationServiceApi(SimpleHttpClient): events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, txn_id: Optional[int] = None, @@ -310,10 +310,13 @@ class ApplicationServiceApi(SimpleHttpClient): # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: - if one_time_key_counts: + if one_time_keys_count: body[ "org.matrix.msc3202.device_one_time_key_counts" - ] = one_time_key_counts + ] = one_time_keys_count + body[ + "org.matrix.msc3202.device_one_time_keys_count" + ] = one_time_keys_count if unused_fallback_keys: body[ "org.matrix.msc3202.device_unused_fallback_key_types" diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 430ffbcd1f..7b562795a3 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -64,7 +64,7 @@ from typing import ( from synapse.appservice import ( ApplicationService, ApplicationServiceState, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.appservice.api import ApplicationServiceApi @@ -258,7 +258,7 @@ class _ServiceQueuer: ): return - one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None if ( @@ -269,7 +269,7 @@ class _ServiceQueuer: # for the users which are mentioned in this transaction, # as well as the appservice's sender. ( - one_time_key_counts, + one_time_keys_count, unused_fallback_keys, ) = await self._compute_msc3202_otk_counts_and_fallback_keys( service, events, ephemeral, to_device_messages_to_send @@ -281,7 +281,7 @@ class _ServiceQueuer: events, ephemeral, to_device_messages_to_send, - one_time_key_counts, + one_time_keys_count, unused_fallback_keys, device_list_summary, ) @@ -296,7 +296,7 @@ class _ServiceQueuer: events: Iterable[EventBase], ephemerals: Iterable[JsonDict], to_device_messages: Iterable[JsonDict], - ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: """ Given a list of the events, ephemeral messages and to-device messages, - first computes a list of application services users that may have @@ -367,7 +367,7 @@ class _TransactionController: events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, - one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: @@ -380,7 +380,7 @@ class _TransactionController: events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. @@ -397,7 +397,7 @@ class _TransactionController: events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], - one_time_key_counts=one_time_key_counts or {}, + one_time_keys_count=one_time_keys_count or {}, unused_fallback_keys=unused_fallback_keys or {}, device_list_summary=device_list_summary or DeviceListUpdates(), ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 259456b55d..c8858b22dd 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1426,14 +1426,14 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_key_counts: JsonDict = {} + one_time_keys_count: JsonDict = {} unused_fallback_key_types: List[str] = [] if device_id: # TODO: We should have a way to let clients differentiate between the states of: # * no change in OTK count since the provided since token # * the server has zero OTKs left for this device # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 - one_time_key_counts = await self.store.count_e2e_one_time_keys( + one_time_keys_count = await self.store.count_e2e_one_time_keys( user_id, device_id ) unused_fallback_key_types = ( @@ -1463,7 +1463,7 @@ class SyncHandler: archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, - device_one_time_keys_count=one_time_key_counts, + device_one_time_keys_count=one_time_keys_count, device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 25da0c56c5..c2c8018ee2 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,7 +20,7 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices @@ -260,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: @@ -273,7 +273,7 @@ class ApplicationServiceTransactionWorkerStore( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. @@ -299,7 +299,7 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, - one_time_key_counts=one_time_key_counts, + one_time_keys_count=one_time_keys_count, unused_fallback_keys=unused_fallback_keys, device_list_summary=device_list_summary, ) @@ -379,7 +379,7 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index cf33e73e2b..643c47d608 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -33,7 +33,7 @@ from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.logging.opentracing import log_kv, set_tag, trace @@ -514,7 +514,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def count_bulk_e2e_one_time_keys_for_as( self, user_ids: Collection[str] - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: """ Counts, in bulk, the one-time keys for all the users specified. Intended to be used by application services for populating OTK counts in @@ -528,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker def _count_bulk_e2e_one_time_keys_txn( txn: LoggingTransaction, - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: user_in_where_clause, user_parameters = make_in_list_sql_clause( self.database_engine, "user_id", user_ids ) @@ -541,7 +541,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ txn.execute(sql, user_parameters) - result: TransactionOneTimeKeyCounts = {} + result: TransactionOneTimeKeysCount = {} for user_id, device_id, algorithm, count in txn: # We deliberately construct empty dictionaries for diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 0b22afdc75..0a1ae83a2b 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -69,7 +69,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -96,7 +96,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -125,7 +125,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 144e49d0fd..9ed26d87a7 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -25,7 +25,7 @@ import synapse.storage from synapse.api.constants import EduTypes, EventTypes from synapse.appservice import ( ApplicationService, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.handlers.appservice import ApplicationServicesHandler @@ -1123,7 +1123,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): # Capture what was sent as an AS transaction. self.send_mock.assert_called() last_args, _last_kwargs = self.send_mock.call_args - otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS] + otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS] unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ self.ARG_FALLBACK_KEYS ] -- cgit 1.5.1 From 8f10c8b054fc970838be9ae6f1f5aea95f166c98 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 28 Nov 2022 15:54:18 -0600 Subject: Move MSC3030 `/timestamp_to_event` endpoint to stable v1 location (#14471) Fix https://github.com/matrix-org/synapse/issues/14390 - Client API: `/_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir=` -> `/_matrix/client/v1/rooms//timestamp_to_event?ts=&dir=` - Federation API: `/_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir=` -> `/_matrix/federation/v1/timestamp_to_event/?ts=&dir=` Complement test changes: https://github.com/matrix-org/complement/pull/559 --- changelog.d/14471.feature | 1 + docker/complement/conf/workers-shared-extra.yaml.j2 | 2 -- docker/configure_workers_and_start.py | 2 ++ docs/workers.md | 2 ++ scripts-dev/complement.sh | 6 +++--- synapse/config/experimental.py | 3 --- synapse/federation/federation_client.py | 12 +++++++++++- synapse/federation/transport/client.py | 5 ++--- synapse/federation/transport/server/__init__.py | 8 -------- synapse/federation/transport/server/federation.py | 3 +-- synapse/rest/client/room.py | 10 +++------- synapse/rest/client/versions.py | 2 -- tests/rest/client/test_rooms.py | 7 +------ 13 files changed, 26 insertions(+), 37 deletions(-) create mode 100644 changelog.d/14471.feature (limited to 'synapse') diff --git a/changelog.d/14471.feature b/changelog.d/14471.feature new file mode 100644 index 0000000000..a0e0c74f1a --- /dev/null +++ b/changelog.d/14471.feature @@ -0,0 +1 @@ +Move MSC3030 `/timestamp_to_event` endpoints to stable `v1` location (`/_matrix/client/v1/rooms//timestamp_to_event?ts=&dir=`, `/_matrix/federation/v1/timestamp_to_event/?ts=&dir=`). diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 883a87159c..ca640c343b 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -100,8 +100,6 @@ experimental_features: # client-side support for partial state in /send_join responses faster_joins: true {% endif %} - # Enable jump to date endpoint - msc3030_enabled: true # Filtering /messages by relation type. msc3874_enabled: true diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index c1e1544536..58c62f2231 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -140,6 +140,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event", "^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms", "^/_matrix/client/(api/v1|r0|v3|unstable/.*)/rooms/.*/aliases", + "^/_matrix/client/v1/rooms/.*/timestamp_to_event$", "^/_matrix/client/(api/v1|r0|v3|unstable)/search", ], "shared_extra_conf": {}, @@ -163,6 +164,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/federation/(v1|v2)/invite/", "^/_matrix/federation/(v1|v2)/query_auth/", "^/_matrix/federation/(v1|v2)/event_auth/", + "^/_matrix/federation/v1/timestamp_to_event/", "^/_matrix/federation/(v1|v2)/exchange_third_party_invite/", "^/_matrix/federation/(v1|v2)/user/devices/", "^/_matrix/federation/(v1|v2)/get_groups_publicised$", diff --git a/docs/workers.md b/docs/workers.md index 27e54c5846..2b65acb5ed 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -191,6 +191,7 @@ information. ^/_matrix/federation/(v1|v2)/send_leave/ ^/_matrix/federation/(v1|v2)/invite/ ^/_matrix/federation/v1/event_auth/ + ^/_matrix/federation/v1/timestamp_to_event/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ ^/_matrix/key/v2/query @@ -218,6 +219,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ + ^/_matrix/client/v1/rooms/.*/timestamp_to_event$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ # Encryption requests diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 803c6ce92d..7744b47097 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -162,9 +162,9 @@ else # We only test faster room joins on monoliths, because they are purposefully # being developed without worker support to start with. # - # The tests for importing historical messages (MSC2716) and jump to date (MSC3030) - # also only pass with monoliths, currently. - test_tags="$test_tags,faster_joins,msc2716,msc3030" + # The tests for importing historical messages (MSC2716) also only pass with monoliths, + # currently. + test_tags="$test_tags,faster_joins,msc2716" fi diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d4b71d1673..a503abf364 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -53,9 +53,6 @@ class ExperimentalConfig(Config): # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) - # MSC3030 (Jump to date API endpoint) - self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) - # MSC2409 (this setting only relates to optionally sending to-device messages). # Presence, typing and read receipt EDUs are already sent to application services that # have opted in to receive them. If enabled, this adds to-device messages to that list. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c4c0bc7315..8bccc9c60d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1691,9 +1691,19 @@ class FederationClient(FederationBase): # to return events on *both* sides of the timestamp to # help reconcile the gap faster. _timestamp_to_event_from_destination, + # Since this endpoint is new, we should try other servers before giving up. + # We can safely remove this in a year (remove after 2023-11-16). + failover_on_unknown_endpoint=True, ) return timestamp_to_event_response - except SynapseError: + except SynapseError as e: + logger.warn( + "timestamp_to_event(room_id=%s, timestamp=%s, direction=%s): encountered error when trying to fetch from destinations: %s", + room_id, + timestamp, + direction, + e, + ) return None async def _timestamp_to_event_from_destination( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index a3cfc701cd..77f1f39cac 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -185,9 +185,8 @@ class TransportLayerClient: Raises: Various exceptions when the request fails """ - path = _create_path( - FEDERATION_UNSTABLE_PREFIX, - "/org.matrix.msc3030/timestamp_to_event/%s", + path = _create_v1_path( + "/timestamp_to_event/%s", room_id, ) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 50623cd385..2725f53cf6 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -25,7 +25,6 @@ from synapse.federation.transport.server._base import ( from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, FederationAccountStatusServlet, - FederationTimestampLookupServlet, ) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( @@ -291,13 +290,6 @@ def register_servlets( ) for servletclass in SERVLET_GROUPS[servlet_group]: - # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled - if ( - servletclass == FederationTimestampLookupServlet - and not hs.config.experimental.msc3030_enabled - ): - continue - # Only allow the `/account_status` servlet if msc3720 is enabled if ( servletclass == FederationAccountStatusServlet diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 205fd16daa..53e77b4bb6 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -218,14 +218,13 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet): `dir` can be `f` or `b` to indicate forwards and backwards in time from the given timestamp. - GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir= + GET /_matrix/federation/v1/timestamp_to_event/?ts=&dir= { "event_id": ... } """ PATH = "/timestamp_to_event/(?P[^/]*)/?" - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030" async def on_GET( self, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 91cb791139..636cc62877 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1284,17 +1284,14 @@ class TimestampLookupRestServlet(RestServlet): `dir` can be `f` or `b` to indicate forwards and backwards in time from the given timestamp. - GET /_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir= + GET /_matrix/client/v1/rooms//timestamp_to_event?ts=&dir= { "event_id": ... } """ PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc3030" - "/rooms/(?P[^/]*)/timestamp_to_event$" - ), + re.compile("^/_matrix/client/v1/rooms/(?P[^/]*)/timestamp_to_event$"), ) def __init__(self, hs: "HomeServer"): @@ -1421,8 +1418,7 @@ def register_servlets( RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) - if hs.config.experimental.msc3030_enabled: - TimestampLookupRestServlet(hs).register(http_server) + TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if not is_worker: diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 180a11ef88..3c0a90010b 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,8 +101,6 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3827.stable": True, # Adds support for importing historical messages as per MSC2716 "org.matrix.msc2716": self.config.experimental.msc2716_enabled, - # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 - "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Support for thread read receipts & notification counts. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e919e089cb..b4daace556 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -3546,11 +3546,6 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self) -> JsonDict: - config = super().default_config() - config["experimental_features"] = {"msc3030_enabled": True} - return config - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._storage_controllers = self.hs.get_storage_controllers() @@ -3592,7 +3587,7 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", + f"/_matrix/client/v1/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}", access_token=self.room_owner_tok, ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) -- cgit 1.5.1 From 3da645032722fbf09c1e5efbc51d8c5c78d8a2cd Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 28 Nov 2022 16:29:53 -0700 Subject: Initial support for MSC3931: Room version push rule feature flags (#14520) * Add support for MSC3931: Room Version Supports push rule condition * Create experimental flag for future work, and use it to gate MSC3931 * Changelog entry --- changelog.d/14520.feature | 1 + rust/src/push/evaluator.rs | 26 ++++++++++++++++++++++++++ rust/src/push/mod.rs | 16 ++++++++++++++++ stubs/synapse/synapse_rust/push.pyi | 2 ++ synapse/api/room_versions.py | 21 ++++++++++++++++++++- synapse/config/experimental.py | 3 +++ synapse/push/bulk_push_rule_evaluator.py | 6 ++++++ tests/push/test_push_rule_evaluator.py | 2 ++ 8 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14520.feature (limited to 'synapse') diff --git a/changelog.d/14520.feature b/changelog.d/14520.feature new file mode 100644 index 0000000000..210acaa8ee --- /dev/null +++ b/changelog.d/14520.feature @@ -0,0 +1 @@ +Add unstable support for an Extensible Events room version (`org.matrix.msc1767.10`) via [MSC1767](https://github.com/matrix-org/matrix-spec-proposals/pull/1767), [MSC3931](https://github.com/matrix-org/matrix-spec-proposals/pull/3931), [MSC3932](https://github.com/matrix-org/matrix-spec-proposals/pull/3932), and [MSC3933](https://github.com/matrix-org/matrix-spec-proposals/pull/3933). \ No newline at end of file diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index cedd42c54d..e8e3d604ee 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -29,6 +29,10 @@ use super::{ lazy_static! { /// Used to parse the `is` clause in the room member count condition. static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex"); + + /// Used to determine which MSC3931 room version feature flags are actually known to + /// the push evaluator. + static ref KNOWN_RVER_FLAGS: Vec = vec![]; } /// Allows running a set of push rules against a particular event. @@ -57,6 +61,13 @@ pub struct PushRuleEvaluator { /// If msc3664, push rules for related events, is enabled. related_event_match_enabled: bool, + + /// If MSC3931 is applicable, the feature flags for the room version. + room_version_feature_flags: Vec, + + /// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same + /// flag as MSC1767 (extensible events core). + msc3931_enabled: bool, } #[pymethods] @@ -70,6 +81,8 @@ impl PushRuleEvaluator { notification_power_levels: BTreeMap, related_events_flattened: BTreeMap>, related_event_match_enabled: bool, + room_version_feature_flags: Vec, + msc3931_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -84,6 +97,8 @@ impl PushRuleEvaluator { sender_power_level, related_events_flattened, related_event_match_enabled, + room_version_feature_flags, + msc3931_enabled, }) } @@ -204,6 +219,15 @@ impl PushRuleEvaluator { false } } + KnownCondition::RoomVersionSupports { feature } => { + if !self.msc3931_enabled { + false + } else { + let flag = feature.to_string(); + KNOWN_RVER_FLAGS.contains(&flag) + && self.room_version_feature_flags.contains(&flag) + } + } }; Ok(result) @@ -362,6 +386,8 @@ fn push_rule_evaluator() { BTreeMap::new(), BTreeMap::new(), true, + vec![], + true, ) .unwrap(); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index d57800aa4a..eef39f6472 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -277,6 +277,10 @@ pub enum KnownCondition { SenderNotificationPermission { key: Cow<'static, str>, }, + #[serde(rename = "org.matrix.msc3931.room_version_supports")] + RoomVersionSupports { + feature: Cow<'static, str>, + }, } impl IntoPy for Condition { @@ -491,6 +495,18 @@ fn test_deserialize_unstable_msc3664_condition() { )); } +#[test] +fn test_deserialize_unstable_msc3931_condition() { + let json = + r#"{"kind":"org.matrix.msc3931.room_version_supports","feature":"org.example.feature"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index ceade65ef9..cbeb49663c 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -41,6 +41,8 @@ class PushRuleEvaluator: notification_power_levels: Mapping[str, int], related_events_flattened: Mapping[str, Mapping[str, str]], related_event_match_enabled: bool, + room_version_feature_flags: list[str], + msc3931_enabled: bool, ): ... def run( self, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index e37acb0f1e..1bd1ef3e2b 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional +from typing import Callable, Dict, List, Optional import attr @@ -91,6 +91,12 @@ class RoomVersion: msc3787_knock_restricted_join_rule: bool # MSC3667: Enforce integer power levels msc3667_int_only_power_levels: bool + # MSC3931: Adds a push rule condition for "room version feature flags", making + # some push rules room version dependent. Note that adding a flag to this list + # is not enough to mark it "supported": the push rule evaluator also needs to + # support the flag. Unknown flags are ignored by the evaluator, making conditions + # fail if used. + msc3931_push_features: List[str] class RoomVersions: @@ -111,6 +117,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V2 = RoomVersion( "2", @@ -129,6 +136,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V3 = RoomVersion( "3", @@ -147,6 +155,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V4 = RoomVersion( "4", @@ -165,6 +174,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V5 = RoomVersion( "5", @@ -183,6 +193,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V6 = RoomVersion( "6", @@ -201,6 +212,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -219,6 +231,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V7 = RoomVersion( "7", @@ -237,6 +250,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V8 = RoomVersion( "8", @@ -255,6 +269,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V9 = RoomVersion( "9", @@ -273,6 +288,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) MSC3787 = RoomVersion( "org.matrix.msc3787", @@ -291,6 +307,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) V10 = RoomVersion( "10", @@ -309,6 +326,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=True, + msc3931_push_features=[], ) MSC2716v4 = RoomVersion( "org.matrix.msc2716v4", @@ -327,6 +345,7 @@ class RoomVersions: msc2716_redactions=True, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, + msc3931_push_features=[], ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index a503abf364..b3f51fc57d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): # MSC3912: Relation-based redactions. self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) + + # MSC1767 and friends: Extensible Events + self.msc1767_enabled: bool = experimental.get("msc1767_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 75b7e126ca..9cc3da6d91 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -338,6 +338,10 @@ class BulkPushRuleEvaluator: for user_id, level in notification_levels.items(): notification_levels[user_id] = int(level) + room_version_features = event.room_version.msc3931_push_features + if not room_version_features: + room_version_features = [] + evaluator = PushRuleEvaluator( _flatten_dict(event), room_member_count, @@ -345,6 +349,8 @@ class BulkPushRuleEvaluator: notification_levels, related_events, self._related_event_match_enabled, + room_version_features, + self.hs.config.experimental.msc1767_enabled, # MSC3931 flag ) users = rules_by_user.keys() diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index fe7c145840..5ababe6a39 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -62,6 +62,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): power_levels.get("notifications", {}), {} if related_events is None else related_events, True, + event.room_version.msc3931_push_features, + True, ) def test_display_name(self) -> None: -- cgit 1.5.1 From dd518281208d2fc446f9995ad78949e807d8f5b8 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 28 Nov 2022 17:22:34 -0700 Subject: Create MSC1767 (extensible events) room version; Implement MSC3932 (#14521) * Add MSC1767's dedicated room version, based on v10 * Only enable MSC1767 room version if the config flag is on Using a similar technique to knocking: https://github.com/matrix-org/synapse/pull/6739/files#diff-3af529eedb0e00279bafb7369370c9654b37792af8eafa0925400e9281d57f0a * Support MSC3932: Extensible events room version feature flag * Changelog entry --- changelog.d/14521.feature | 1 + rust/src/push/evaluator.rs | 97 +++++++++++++++++++++++++++++++++++++++++- synapse/api/room_versions.py | 29 ++++++++++++- synapse/config/experimental.py | 5 +++ 4 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14521.feature (limited to 'synapse') diff --git a/changelog.d/14521.feature b/changelog.d/14521.feature new file mode 100644 index 0000000000..210acaa8ee --- /dev/null +++ b/changelog.d/14521.feature @@ -0,0 +1 @@ +Add unstable support for an Extensible Events room version (`org.matrix.msc1767.10`) via [MSC1767](https://github.com/matrix-org/matrix-spec-proposals/pull/1767), [MSC3931](https://github.com/matrix-org/matrix-spec-proposals/pull/3931), [MSC3932](https://github.com/matrix-org/matrix-spec-proposals/pull/3932), and [MSC3933](https://github.com/matrix-org/matrix-spec-proposals/pull/3933). \ No newline at end of file diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index e8e3d604ee..b4c3039aba 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::borrow::Cow; use std::collections::BTreeMap; +use crate::push::{PushRule, PushRules}; use anyhow::{Context, Error}; use lazy_static::lazy_static; use log::warn; @@ -32,7 +34,30 @@ lazy_static! { /// Used to determine which MSC3931 room version feature flags are actually known to /// the push evaluator. - static ref KNOWN_RVER_FLAGS: Vec = vec![]; + static ref KNOWN_RVER_FLAGS: Vec = vec![ + RoomVersionFeatures::ExtensibleEvents.as_str().to_string(), + ]; + + /// The "safe" rule IDs which are not affected by MSC3932's behaviour (room versions which + /// declare Extensible Events support ultimately *disable* push rules which do not declare + /// *any* MSC3931 room_version_supports condition). + static ref SAFE_EXTENSIBLE_EVENTS_RULE_IDS: Vec = vec![ + "global/override/.m.rule.master".to_string(), + "global/override/.m.rule.roomnotif".to_string(), + "global/content/.m.rule.contains_user_name".to_string(), + ]; +} + +enum RoomVersionFeatures { + ExtensibleEvents, +} + +impl RoomVersionFeatures { + fn as_str(&self) -> &'static str { + match self { + RoomVersionFeatures::ExtensibleEvents => "org.matrix.msc3932.extensible_events", + } + } } /// Allows running a set of push rules against a particular event. @@ -121,7 +146,22 @@ impl PushRuleEvaluator { continue; } + let rule_id = &push_rule.rule_id().to_string(); + let extev_flag = &RoomVersionFeatures::ExtensibleEvents.as_str().to_string(); + let supports_extensible_events = self.room_version_feature_flags.contains(extev_flag); + let safe_from_rver_condition = SAFE_EXTENSIBLE_EVENTS_RULE_IDS.contains(rule_id); + let mut has_rver_condition = false; + for condition in push_rule.conditions.iter() { + has_rver_condition = has_rver_condition + || match condition { + Condition::Known(known) => match known { + // per MSC3932, we just need *any* room version condition to match + KnownCondition::RoomVersionSupports { feature: _ } => true, + _ => false, + }, + _ => false, + }; match self.match_condition(condition, user_id, display_name) { Ok(true) => {} Ok(false) => continue 'outer, @@ -132,6 +172,13 @@ impl PushRuleEvaluator { } } + // MSC3932: Disable push rules in extensible event-supporting room versions if they + // don't describe *any* MSC3931 room version condition, unless the rule is on the + // safe list. + if !has_rver_condition && !safe_from_rver_condition && supports_extensible_events { + continue; + } + let actions = push_rule .actions .iter() @@ -394,3 +441,51 @@ fn push_rule_evaluator() { let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); } + +#[test] +fn test_requires_room_version_supports_condition() { + let mut flattened_keys = BTreeMap::new(); + flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + false, + flags, + true, + ) + .unwrap(); + + // first test: are the master and contains_user_name rules excluded from the "requires room + // version condition" check? + let mut result = evaluator.run( + &FilteredPushRules::default(), + Some("@bob:example.org"), + None, + ); + assert_eq!(result.len(), 3); + + // second test: if an appropriate push rule is in play, does it get handled? + let custom_rule = PushRule { + rule_id: Cow::from("global/underride/.org.example.extensible"), + priority_class: 1, // underride + conditions: Cow::from(vec![Condition::Known( + KnownCondition::RoomVersionSupports { + feature: Cow::from(RoomVersionFeatures::ExtensibleEvents.as_str().to_string()), + }, + )]), + actions: Cow::from(vec![Action::Notify]), + default: false, + default_enabled: true, + }; + let rules = PushRules::new(vec![custom_rule]); + result = evaluator.run( + &FilteredPushRules::py_new(rules, BTreeMap::new(), true), + None, + None, + ); + assert_eq!(result.len(), 1); +} diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 1bd1ef3e2b..ac62011c9f 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -51,6 +51,13 @@ class RoomDisposition: UNSTABLE = "unstable" +class PushRuleRoomFlag: + """Enum for listing possible MSC3931 room version feature flags, for push rules""" + + # MSC3932: Room version supports MSC1767 Extensible Events. + EXTENSIBLE_EVENTS = "org.matrix.msc3932.extensible_events" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class RoomVersion: """An object which describes the unique attributes of a room version.""" @@ -96,7 +103,7 @@ class RoomVersion: # is not enough to mark it "supported": the push rule evaluator also needs to # support the flag. Unknown flags are ignored by the evaluator, making conditions # fail if used. - msc3931_push_features: List[str] + msc3931_push_features: List[str] # values from PushRuleRoomFlag class RoomVersions: @@ -347,6 +354,26 @@ class RoomVersions: msc3667_int_only_power_levels=False, msc3931_push_features=[], ) + MSC1767v10 = RoomVersion( + # MSC1767 (Extensible Events) based on room version "10" + "org.matrix.msc1767.10", + RoomDisposition.UNSTABLE, + EventFormatVersions.ROOM_V4_PLUS, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, + msc3375_redaction_rules=True, + msc2403_knocking=True, + msc2716_historical=False, + msc2716_redactions=False, + msc3787_knock_restricted_join_rule=True, + msc3667_int_only_power_levels=True, + msc3931_push_features=[PushRuleRoomFlag.EXTENSIBLE_EVENTS], + ) KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b3f51fc57d..573fa0386f 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -16,6 +16,7 @@ from typing import Any, Optional import attr +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config._base import Config from synapse.types import JsonDict @@ -131,3 +132,7 @@ class ExperimentalConfig(Config): # MSC1767 and friends: Extensible Events self.msc1767_enabled: bool = experimental.get("msc1767_enabled", False) + if self.msc1767_enabled: + # Enable room version (and thus applicable push rules from MSC3931/3932) + version_id = RoomVersions.MSC1767v10.identifier + KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10 -- cgit 1.5.1 From 9ccc09fe9e332a71b8cf5bf42b16f6acf5a6887d Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 28 Nov 2022 18:02:41 -0700 Subject: Support MSC1767's `content.body` behaviour; Add base rules from MSC3933 (#14524) * Support MSC1767's `content.body` behaviour in push rules * Add the base rules from MSC3933 * Changelog entry * Flip condition around for finding `m.markup` * Remove forgotten import --- changelog.d/14524.feature | 1 + rust/src/push/base_rules.rs | 270 ++++++++++++++++++++++++++++ rust/src/push/evaluator.rs | 2 +- rust/src/push/mod.rs | 7 + stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/push/bulk_push_rule_evaluator.py | 29 ++- synapse/storage/databases/main/push_rule.py | 5 +- 7 files changed, 316 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14524.feature (limited to 'synapse') diff --git a/changelog.d/14524.feature b/changelog.d/14524.feature new file mode 100644 index 0000000000..210acaa8ee --- /dev/null +++ b/changelog.d/14524.feature @@ -0,0 +1 @@ +Add unstable support for an Extensible Events room version (`org.matrix.msc1767.10`) via [MSC1767](https://github.com/matrix-org/matrix-spec-proposals/pull/1767), [MSC3931](https://github.com/matrix-org/matrix-spec-proposals/pull/3931), [MSC3932](https://github.com/matrix-org/matrix-spec-proposals/pull/3932), and [MSC3933](https://github.com/matrix-org/matrix-spec-proposals/pull/3933). \ No newline at end of file diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 49802fa4eb..35129691ca 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -274,6 +274,156 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.encrypted_room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.encrypted")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.message.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.message")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.file.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.file")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.image.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.image")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.video.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.video")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed( + "global/underride/.org.matrix.msc3933.rule.extensible.audio.room_one_to_one", + ), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("org.matrix.msc1767.audio")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.message"), priority_class: 1, @@ -302,6 +452,126 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.encrypted"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.encrypted")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.message"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.message")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.file"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.file")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.image"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.image")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.video"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.video")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc1767.rule.extensible.audio"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + // MSC3933: Type changed from template rule - see MSC. + pattern: Some(Cow::Borrowed("m.audio")), + pattern_type: None, + })), + // MSC3933: Add condition on top of template rule - see MSC. + Condition::Known(KnownCondition::RoomVersionSupports { + // RoomVersionFeatures::ExtensibleEvents.as_str(), ideally + feature: Cow::Borrowed("org.matrix.msc3932.extensible_events"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/underride/.im.vector.jitsi"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index b4c3039aba..1cd54f7e2c 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -483,7 +483,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, true), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index eef39f6472..2e9d3e38a1 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -412,6 +412,7 @@ pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, msc3664_enabled: bool, + msc1767_enabled: bool, } #[pymethods] @@ -421,11 +422,13 @@ impl FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, msc3664_enabled: bool, + msc1767_enabled: bool, ) -> Self { Self { push_rules, enabled_map, msc3664_enabled, + msc1767_enabled, } } @@ -450,6 +453,10 @@ impl FilteredPushRules { return false; } + if !self.msc1767_enabled && rule.rule_id.contains("org.matrix.msc1767") { + return false; + } + true }) .map(|r| { diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index cbeb49663c..a6a586a0b5 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -26,7 +26,11 @@ class PushRules: class FilteredPushRules: def __init__( - self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool + self, + push_rules: PushRules, + enabled_map: Dict[str, bool], + msc3664_enabled: bool, + msc1767_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9cc3da6d91..d6b377860f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -29,6 +29,7 @@ from typing import ( from prometheus_client import Counter from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes +from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -343,7 +344,7 @@ class BulkPushRuleEvaluator: room_version_features = [] evaluator = PushRuleEvaluator( - _flatten_dict(event), + _flatten_dict(event, room_version=event.room_version), room_member_count, sender_power_level, notification_levels, @@ -426,6 +427,7 @@ StateGroup = Union[object, int] def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], + room_version: Optional[RoomVersion] = None, prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -437,6 +439,31 @@ def _flatten_dict( if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() elif isinstance(value, Mapping): + # do not set `room_version` due to recursion considerations below _flatten_dict(value, prefix=(prefix + [key]), result=result) + # `room_version` should only ever be set when looking at the top level of an event + if ( + room_version is not None + and PushRuleRoomFlag.EXTENSIBLE_EVENTS in room_version.msc3931_push_features + and isinstance(d, EventBase) + ): + # Room supports extensible events: replace `content.body` with the plain text + # representation from `m.markup`, as per MSC1767. + markup = d.get("content").get("m.markup") + if room_version.identifier.startswith("org.matrix.msc1767."): + markup = d.get("content").get("org.matrix.msc1767.markup") + if markup is not None and isinstance(markup, list): + text = "" + for rep in markup: + if not isinstance(rep, dict): + # invalid markup - skip all processing + break + if rep.get("mimetype", "text/plain") == "text/plain": + rep_text = rep.get("body") + if rep_text is not None and isinstance(rep_text, str): + text = rep_text.lower() + break + result["content.body"] = text + return result diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 12ad44dbb3..d4c64c46ad 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -84,7 +84,10 @@ def _load_rules( push_rules = PushRules(ruleslist) filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled + push_rules, + enabled_map, + msc3664_enabled=experimental_config.msc3664_enabled, + msc1767_enabled=experimental_config.msc1767_enabled, ) return filtered_rules -- cgit 1.5.1 From 72f3e381375ba10d576a23025ca312397114de6b Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 28 Nov 2022 19:18:12 -0800 Subject: Fix possible variable shadow in `create_new_client_event` (#14575) --- changelog.d/14575.misc | 1 + synapse/handlers/message.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14575.misc (limited to 'synapse') diff --git a/changelog.d/14575.misc b/changelog.d/14575.misc new file mode 100644 index 0000000000..f6fa54eaa2 --- /dev/null +++ b/changelog.d/14575.misc @@ -0,0 +1 @@ +Fix a possible variable shadow in `create_new_client_event`. \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4cf593cfdc..5cbe89f4fd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1135,11 +1135,13 @@ class EventCreationHandler: ) state_events = await self.store.get_events_as_list(state_event_ids) # Create a StateMap[str] - state_map = {(e.type, e.state_key): e.event_id for e in state_events} + current_state_ids = { + (e.type, e.state_key): e.event_id for e in state_events + } # Actually strip down and only use the necessary auth events auth_event_ids = self._event_auth_handler.compute_auth_events( event=temp_event, - current_state_ids=state_map, + current_state_ids=current_state_ids, for_verification=False, ) -- cgit 1.5.1 From c7e29ca277cf60bfdc488b93f4321b046fa6b46f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Nov 2022 10:36:41 +0000 Subject: POC delete stale non-e2e devices for users (#14038) This should help reduce the number of devices e.g. simple bots the repeatedly login rack up. We only delete non-e2e devices as they should be safe to delete, whereas if we delete e2e devices for a user we may accidentally break their ability to receive e2e keys for a message. Co-authored-by: Patrick Cloke Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14038.misc | 1 + synapse/handlers/device.py | 13 +++++- synapse/storage/databases/main/devices.py | 67 ++++++++++++++++++++++++++++++- tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 5 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14038.misc (limited to 'synapse') diff --git a/changelog.d/14038.misc b/changelog.d/14038.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/14038.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b1e55e1b9e..7c4dd8cf5a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -421,6 +421,9 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) + # Prune the user's device list if they already have a lot of devices. + await self._prune_too_many_devices(user_id) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -452,6 +455,14 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") + async def _prune_too_many_devices(self, user_id: str) -> None: + """Delete any excess old devices this user may have.""" + device_ids = await self.store.check_too_many_devices_for_user(user_id) + if not device_ids: + return + + await self.delete_devices(user_id, device_ids) + async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. @@ -481,7 +492,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 534f7fc04a..1e83c62753 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1533,6 +1533,70 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def check_too_many_devices_for_user(self, user_id: str) -> Collection[str]: + """Check if the user has a lot of devices, and if so return the set of + devices we can prune. + + This does *not* return hidden devices or devices with E2E keys. + """ + + num_devices = await self.db_pool.simple_select_one_onecol( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcol="COALESCE(COUNT(*), 0)", + desc="count_devices", + ) + + # We let users have up to ten devices without pruning. + if num_devices <= 10: + return () + + # We prune everything older than N days. + max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 + + if num_devices > 50: + # If the user has more than 50 devices, then we chose a last seen + # that ensures we keep at most 50 devices. + sql = """ + SELECT last_seen FROM devices + WHERE + user_id = ? + AND NOT hidden + AND last_seen IS NOT NULL + AND key_json IS NULL + ORDER BY last_seen DESC + LIMIT 1 + OFFSET 50 + """ + + rows = await self.db_pool.execute( + "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) + ) + if rows: + max_last_seen = max(rows[0][0], max_last_seen) + + # Now fetch the devices to delete. + sql = """ + SELECT DISTINCT device_id FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen < ? + AND key_json IS NULL + """ + + def check_too_many_devices_for_user_txn( + txn: LoggingTransaction, + ) -> Collection[str]: + txn.execute(sql, (user_id, max_last_seen)) + return {device_id for device_id, in txn} + + return await self.db_pool.runInteraction( + "check_too_many_devices_for_user", + check_too_many_devices_for_user_txn, + ) + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1591,6 +1655,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, + "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1636,7 +1701,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Deletes several devices. Args: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ce7525e29c..a456bffd63 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": None, + "last_seen_ts": 1000000, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 49ad3c1324..a9af1babed 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,6 +169,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) + last_seen = self.clock.time_msec() + if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -189,7 +191,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": None, + "last_seen": last_seen, }, ], ) -- cgit 1.5.1 From e860316818da4bd643d567708adb8d104f4a3351 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 29 Nov 2022 13:05:07 +0000 Subject: Fix `UndefinedColumn: column "key_json" does not exist` errors when handling users with more than 50 non-E2E devices (#14580) --- synapse/storage/databases/main/devices.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse') diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1e83c62753..0378035cff 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1559,6 +1559,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): # that ensures we keep at most 50 devices. sql = """ SELECT last_seen FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) WHERE user_id = ? AND NOT hidden -- cgit 1.5.1 From 13aa29db1ddc925beb35f5f1da8fd1a1bcc91373 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 29 Nov 2022 10:49:23 -0500 Subject: Advertise support for Matrix v1.5. (#14576) All features of Matrix v1.5 were already supported: this was mostly a maintenance release. --- changelog.d/14576.feature | 1 + synapse/rest/client/versions.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/14576.feature (limited to 'synapse') diff --git a/changelog.d/14576.feature b/changelog.d/14576.feature new file mode 100644 index 0000000000..4fe8cb2667 --- /dev/null +++ b/changelog.d/14576.feature @@ -0,0 +1 @@ +Advertise support for Matrix 1.5 on `/_matrix/client/versions`. diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 3c0a90010b..e19c0946c0 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -77,6 +77,7 @@ class VersionsRestServlet(RestServlet): "v1.2", "v1.3", "v1.4", + "v1.5", ], # as per MSC1497: "unstable_features": { -- cgit 1.5.1 From c29e2c630624beb0b5557aa0f7ccdcedbe62def1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 29 Nov 2022 17:48:48 +0000 Subject: Revert "POC delete stale non-e2e devices for users (#14038)" (#14582) --- changelog.d/14582.bugfix | 1 + synapse/handlers/device.py | 13 +----- synapse/storage/databases/main/devices.py | 68 +------------------------------ tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 5 files changed, 5 insertions(+), 83 deletions(-) create mode 100644 changelog.d/14582.bugfix (limited to 'synapse') diff --git a/changelog.d/14582.bugfix b/changelog.d/14582.bugfix new file mode 100644 index 0000000000..caad468e70 --- /dev/null +++ b/changelog.d/14582.bugfix @@ -0,0 +1 @@ +Fix a regression in Synapse 1.73.0rc1 where Synapse's main process would stop responding to HTTP requests when a user with a large number of devices logs in. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7c4dd8cf5a..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -421,9 +421,6 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) - # Prune the user's device list if they already have a lot of devices. - await self._prune_too_many_devices(user_id) - if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -455,14 +452,6 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") - async def _prune_too_many_devices(self, user_id: str) -> None: - """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id) - if not device_ids: - return - - await self.delete_devices(user_id, device_ids) - async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. @@ -492,7 +481,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 0378035cff..534f7fc04a 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1533,71 +1533,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user(self, user_id: str) -> Collection[str]: - """Check if the user has a lot of devices, and if so return the set of - devices we can prune. - - This does *not* return hidden devices or devices with E2E keys. - """ - - num_devices = await self.db_pool.simple_select_one_onecol( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcol="COALESCE(COUNT(*), 0)", - desc="count_devices", - ) - - # We let users have up to ten devices without pruning. - if num_devices <= 10: - return () - - # We prune everything older than N days. - max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 - - if num_devices > 50: - # If the user has more than 50 devices, then we chose a last seen - # that ensures we keep at most 50 devices. - sql = """ - SELECT last_seen FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen IS NOT NULL - AND key_json IS NULL - ORDER BY last_seen DESC - LIMIT 1 - OFFSET 50 - """ - - rows = await self.db_pool.execute( - "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) - ) - if rows: - max_last_seen = max(rows[0][0], max_last_seen) - - # Now fetch the devices to delete. - sql = """ - SELECT DISTINCT device_id FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen < ? - AND key_json IS NULL - """ - - def check_too_many_devices_for_user_txn( - txn: LoggingTransaction, - ) -> Collection[str]: - txn.execute(sql, (user_id, max_last_seen)) - return {device_id for device_id, in txn} - - return await self.db_pool.runInteraction( - "check_too_many_devices_for_user", - check_too_many_devices_for_user_txn, - ) - class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1656,7 +1591,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, - "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1702,7 +1636,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a456bffd63..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": 1000000, + "last_seen_ts": None, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index a9af1babed..49ad3c1324 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,8 +169,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) - last_seen = self.clock.time_msec() - if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -191,7 +189,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": last_seen, + "last_seen": None, }, ], ) -- cgit 1.5.1 From ecb6fe9d9cf8375b760eb727be0e1dec3612e026 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 30 Nov 2022 11:59:57 +0000 Subject: Stop using deprecated `keyIds` param on /key/v2/server (#14525) Fixes #14523. --- changelog.d/14490.feature | 1 + changelog.d/14490.misc | 1 - changelog.d/14525.feature | 1 + synapse/crypto/keyring.py | 107 +++++++++++--------------- tests/crypto/test_keyring.py | 14 +--- tests/rest/key/v2/test_remote_key_resource.py | 5 +- 6 files changed, 47 insertions(+), 82 deletions(-) create mode 100644 changelog.d/14490.feature delete mode 100644 changelog.d/14490.misc create mode 100644 changelog.d/14525.feature (limited to 'synapse') diff --git a/changelog.d/14490.feature b/changelog.d/14490.feature new file mode 100644 index 0000000000..c7cb571294 --- /dev/null +++ b/changelog.d/14490.feature @@ -0,0 +1 @@ +Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`. diff --git a/changelog.d/14490.misc b/changelog.d/14490.misc deleted file mode 100644 index c0a4daa885..0000000000 --- a/changelog.d/14490.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 0.9 where it would fail to fetch server keys whose IDs contain a forward slash. diff --git a/changelog.d/14525.feature b/changelog.d/14525.feature new file mode 100644 index 0000000000..c7cb571294 --- /dev/null +++ b/changelog.d/14525.feature @@ -0,0 +1 @@ +Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index ed15f88350..69310d9035 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -14,7 +14,6 @@ import abc import logging -import urllib from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple import attr @@ -813,31 +812,27 @@ class ServerKeyFetcher(BaseV2KeyFetcher): results = {} - async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None: + async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None: server_name = key_to_fetch_item.server_name - key_ids = key_to_fetch_item.key_ids try: - keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) + keys = await self.get_server_verify_keys_v2_direct(server_name) results[server_name] = keys except KeyLookupError as e: - logger.warning( - "Error looking up keys %s from %s: %s", key_ids, server_name, e - ) + logger.warning("Error looking up keys from %s: %s", server_name, e) except Exception: - logger.exception("Error getting keys %s from %s", key_ids, server_name) + logger.exception("Error getting keys from %s", server_name) - await yieldable_gather_results(get_key, keys_to_fetch) + await yieldable_gather_results(get_keys, keys_to_fetch) return results - async def get_server_verify_key_v2_direct( - self, server_name: str, key_ids: Iterable[str] + async def get_server_verify_keys_v2_direct( + self, server_name: str ) -> Dict[str, FetchKeyResult]: """ Args: - server_name: - key_ids: + server_name: Server to request keys from Returns: Map from key ID to lookup result @@ -845,57 +840,41 @@ class ServerKeyFetcher(BaseV2KeyFetcher): Raises: KeyLookupError if there was a problem making the lookup """ - keys: Dict[str, FetchKeyResult] = {} - - for requested_key_id in key_ids: - # we may have found this key as a side-effect of asking for another. - if requested_key_id in keys: - continue - - time_now_ms = self.clock.time_msec() - try: - response = await self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id, safe=""), - ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an - # easy request to handle, so if it doesn't reply within 10s, it's - # probably not going to. - # - # Furthermore, when we are acting as a notary server, we cannot - # wait all day for all of the origin servers, as the requesting - # server will otherwise time out before we can respond. - # - # (Note that get_json may make 4 attempts, so this can still take - # almost 45 seconds to fetch the headers, plus up to another 60s to - # read the response). - timeout=10000, - ) - except (NotRetryingDestination, RequestSendFailed) as e: - # these both have str() representations which we can't really improve - # upon - raise KeyLookupError(str(e)) - except HttpResponseException as e: - raise KeyLookupError("Remote server returned an error: %s" % (e,)) - - assert isinstance(response, dict) - if response["server_name"] != server_name: - raise KeyLookupError( - "Expected a response for server %r not %r" - % (server_name, response["server_name"]) - ) - - response_keys = await self.process_v2_response( - from_server=server_name, - response_json=response, - time_added_ms=time_now_ms, + time_now_ms = self.clock.time_msec() + try: + response = await self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server", + ignore_backoff=True, + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). + timeout=10000, ) - await self.store.store_server_verify_keys( - server_name, - time_now_ms, - ((server_name, key_id, key) for key_id, key in response_keys.items()), + except (NotRetryingDestination, RequestSendFailed) as e: + # these both have str() representations which we can't really improve + # upon + raise KeyLookupError(str(e)) + except HttpResponseException as e: + raise KeyLookupError("Remote server returned an error: %s" % (e,)) + + assert isinstance(response, dict) + if response["server_name"] != server_name: + raise KeyLookupError( + "Expected a response for server %r not %r" + % (server_name, response["server_name"]) ) - keys.update(response_keys) - return keys + return await self.process_v2_response( + from_server=server_name, + response_json=response, + time_added_ms=time_now_ms, + ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 63628aa6b0..f7c309cad0 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -433,7 +433,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) - self.assertEqual(path, "/_matrix/key/v2/server/key1") + self.assertEqual(path, "/_matrix/key/v2/server") return response self.http_client.get_json.side_effect = get_json @@ -469,18 +469,6 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) - def test_keyid_containing_forward_slash(self) -> None: - """We should url-encode any url unsafe chars in key ids. - - Detects https://github.com/matrix-org/synapse/issues/14488. - """ - fetcher = ServerKeyFetcher(self.hs) - self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0)) - - self.http_client.get_json.assert_called_once() - args, kwargs = self.http_client.get_json.call_args - self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato") - class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 7f1fba1086..2bb6e27d94 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import urllib.parse from io import BytesIO, StringIO from typing import Any, Dict, Optional, Union from unittest.mock import Mock @@ -65,9 +64,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) - self.assertEqual( - path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),) - ) + self.assertEqual(path, "/_matrix/key/v2/server") response = { "server_name": server_name, -- cgit 1.5.1 From 4569eda94423a10abb69e0f4d5f37eb723ed764b Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Wed, 30 Nov 2022 13:39:47 +0100 Subject: Use servers list approx to send read receipts when in partial state (#14549) Signed-off-by: Mathieu Velten --- changelog.d/14549.misc | 1 + synapse/federation/sender/__init__.py | 2 +- tests/federation/test_federation_sender.py | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14549.misc (limited to 'synapse') diff --git a/changelog.d/14549.misc b/changelog.d/14549.misc new file mode 100644 index 0000000000..d9d863dd20 --- /dev/null +++ b/changelog.d/14549.misc @@ -0,0 +1 @@ +Faster joins: use servers list approximation to send read receipts when in partial state instead of waiting for the full state of the room. \ No newline at end of file diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index fc1d8c88a7..30ebd62883 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -647,7 +647,7 @@ class FederationSender(AbstractFederationSender): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains_set = await self._storage_controllers.state.get_current_hosts_in_room( + domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) domains = [ diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 01f147418b..cbc99d30b9 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -38,6 +38,10 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): return_value=make_awaitable({"test", "host2"}) ) + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( + hs.get_storage_controllers().state.get_current_hosts_in_room + ) + return hs @override_config({"send_federation": True}) -- cgit 1.5.1 From e8bce8999f21d30affc459755e304a1f4732165c Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 30 Nov 2022 13:45:06 +0000 Subject: Aggregate unread notif count query for badge count calculation (#14255) Fetch the unread notification counts used by the badge counts in push notifications for all rooms at once (instead of fetching them per room). --- changelog.d/14255.misc | 1 + synapse/push/push_tools.py | 28 ++-- .../storage/databases/main/event_push_actions.py | 149 +++++++++++++++++++++ tests/storage/test_event_push_actions.py | 47 +++++-- 4 files changed, 198 insertions(+), 27 deletions(-) create mode 100644 changelog.d/14255.misc (limited to 'synapse') diff --git a/changelog.d/14255.misc b/changelog.d/14255.misc new file mode 100644 index 0000000000..39924659c7 --- /dev/null +++ b/changelog.d/14255.misc @@ -0,0 +1 @@ +Optimise push badge count calculations. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index edeba27a45..7ee07e4bee 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,7 +17,6 @@ from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.util.async_helpers import concurrently_execute async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: @@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge = len(invites) - room_notifs = [] - - async def get_room_unread_count(room_id: str) -> None: - room_notifs.append( - await store.get_unread_event_push_actions_by_room_for_user( - room_id, - user_id, - ) - ) - - await concurrently_execute(get_room_unread_count, joins, 10) - - for notifs in room_notifs: - # Combine the counts from all the threads. - notify_count = notifs.main_timeline.notify_count + sum( - n.notify_count for n in notifs.threads.values() - ) + room_to_count = await store.get_unread_counts_by_room_for_user(user_id) + for room_id, notify_count in room_to_count.items(): + # room_to_count may include rooms which the user has left, + # ignore those. + if room_id not in joins: + continue if notify_count == 0: continue @@ -51,8 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - # return one badge count per conversation badge += 1 else: - # increment the badge count by the number of unread messages in the room + # Increase badge by number of notifications in room + # NOTE: this includes threaded and unthreaded notifications. badge += notify_count + return badge diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b283ab0f9c..7ebe34f773 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -74,6 +74,7 @@ receipt. """ import logging +from collections import defaultdict from typing import ( TYPE_CHECKING, Collection, @@ -95,6 +96,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + PostgresEngine, ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore @@ -463,6 +465,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return result + async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]: + """Get the notification count by room for a user. Only considers notifications, + not highlight or unread counts, and threads are currently aggregated under their room. + + This function is intentionally not cached because it is called to calculate the + unread badge for push notifications and thus the result is expected to change. + + Note that this function assumes the user is a member of the room. Because + summary rows are not removed when a user leaves a room, the caller must + filter out those results from the result. + + Returns: + A map of room ID to notification counts for the given user. + """ + return await self.db_pool.runInteraction( + "get_unread_counts_by_room_for_user", + self._get_unread_counts_by_room_for_user_txn, + user_id, + ) + + def _get_unread_counts_by_room_for_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Dict[str, int]: + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + ) + args.extend([user_id, user_id]) + + receipts_cte = f""" + WITH all_receipts AS ( + SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + {receipt_types_clause} + AND user_id = ? + GROUP BY room_id, thread_id + ) + """ + + receipts_joins = """ + LEFT JOIN ( + SELECT room_id, thread_id, + max_receipt_stream_ordering AS threaded_receipt_stream_ordering + FROM all_receipts + WHERE thread_id IS NOT NULL + ) AS threaded_receipts USING (room_id, thread_id) + LEFT JOIN ( + SELECT room_id, thread_id, + max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering + FROM all_receipts + WHERE thread_id IS NULL + ) AS unthreaded_receipts USING (room_id) + """ + + # First get summary counts by room / thread for the user. We use the max receipt + # stream ordering of both threaded & unthreaded receipts to compare against the + # summary table. + # + # PostgreSQL and SQLite differ in comparing scalar numerics. + if isinstance(self.database_engine, PostgresEngine): + # GREATEST ignores NULLs. + max_clause = """GREATEST( + threaded_receipt_stream_ordering, + unthreaded_receipt_stream_ordering + )""" + else: + # MAX returns NULL if any are NULL, so COALESCE to 0 first. + max_clause = """MAX( + COALESCE(threaded_receipt_stream_ordering, 0), + COALESCE(unthreaded_receipt_stream_ordering, 0) + )""" + + sql = f""" + {receipts_cte} + SELECT eps.room_id, eps.thread_id, notif_count + FROM event_push_summary AS eps + {receipts_joins} + WHERE user_id = ? + AND notif_count != 0 + AND ( + (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause}) + OR last_receipt_stream_ordering = {max_clause} + ) + """ + txn.execute(sql, args) + + seen_thread_ids = set() + room_to_count: Dict[str, int] = defaultdict(int) + + for room_id, thread_id, notif_count in txn: + room_to_count[room_id] += notif_count + seen_thread_ids.add(thread_id) + + # Now get any event push actions that haven't been rotated using the same OR + # join and filter by receipt and event push summary rotated up to stream ordering. + sql = f""" + {receipts_cte} + SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count + FROM event_push_actions AS epa + {receipts_joins} + WHERE user_id = ? + AND epa.notif = 1 + AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering) + AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) + AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) + GROUP BY epa.room_id, epa.thread_id + """ + txn.execute(sql, args) + + for room_id, thread_id, notif_count in txn: + # Note: only count push actions we have valid summaries for with up to date receipt. + if thread_id not in seen_thread_ids: + continue + room_to_count[room_id] += notif_count + + thread_id_clause, thread_ids_args = make_in_list_sql_clause( + self.database_engine, "epa.thread_id", seen_thread_ids + ) + + # Finally re-check event_push_actions for any rooms not in the summary, ignoring + # the rotated up-to position. This handles the case where a read receipt has arrived + # but not been rotated meaning the summary table is out of date, so we go back to + # the push actions table. + sql = f""" + {receipts_cte} + SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count + FROM event_push_actions AS epa + {receipts_joins} + WHERE user_id = ? + AND NOT {thread_id_clause} + AND epa.notif = 1 + AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) + AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) + GROUP BY epa.room_id + """ + + args.extend(thread_ids_args) + txn.execute(sql, args) + + for room_id, notif_count in txn: + room_to_count[room_id] += notif_count + + return room_to_count + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index ee48920f84..5fa8bd2d98 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -156,7 +156,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): last_event_id: str - def _assert_counts(noitf_count: int, highlight_count: int) -> None: + def _assert_counts(notif_count: int, highlight_count: int) -> None: counts = self.get_success( self.store.db_pool.runInteraction( "get-unread-counts", @@ -168,13 +168,22 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual( counts.main_timeline, NotifCounts( - notify_count=noitf_count, + notify_count=notif_count, unread_count=0, highlight_count=highlight_count, ), ) self.assertEqual(counts.threads, {}) + aggregate_counts = self.get_success( + self.store.db_pool.runInteraction( + "get-aggregate-unread-counts", + self.store._get_unread_counts_by_room_for_user_txn, + user_id, + ) + ) + self.assertEqual(aggregate_counts[room_id], notif_count) + def _create_event(highlight: bool = False) -> str: result = self.helper.send_event( room_id, @@ -283,7 +292,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): last_event_id: str def _assert_counts( - noitf_count: int, + notif_count: int, highlight_count: int, thread_notif_count: int, thread_highlight_count: int, @@ -299,7 +308,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual( counts.main_timeline, NotifCounts( - notify_count=noitf_count, + notify_count=notif_count, unread_count=0, highlight_count=highlight_count, ), @@ -318,6 +327,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): else: self.assertEqual(counts.threads, {}) + aggregate_counts = self.get_success( + self.store.db_pool.runInteraction( + "get-aggregate-unread-counts", + self.store._get_unread_counts_by_room_for_user_txn, + user_id, + ) + ) + self.assertEqual( + aggregate_counts[room_id], notif_count + thread_notif_count + ) + def _create_event( highlight: bool = False, thread_id: Optional[str] = None ) -> str: @@ -454,7 +474,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): last_event_id: str def _assert_counts( - noitf_count: int, + notif_count: int, highlight_count: int, thread_notif_count: int, thread_highlight_count: int, @@ -470,7 +490,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual( counts.main_timeline, NotifCounts( - notify_count=noitf_count, + notify_count=notif_count, unread_count=0, highlight_count=highlight_count, ), @@ -489,6 +509,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): else: self.assertEqual(counts.threads, {}) + aggregate_counts = self.get_success( + self.store.db_pool.runInteraction( + "get-aggregate-unread-counts", + self.store._get_unread_counts_by_room_for_user_txn, + user_id, + ) + ) + self.assertEqual( + aggregate_counts[room_id], notif_count + thread_notif_count + ) + def _create_event( highlight: bool = False, thread_id: Optional[str] = None ) -> str: @@ -646,7 +677,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) return result["event_id"] - def _assert_counts(noitf_count: int, thread_notif_count: int) -> None: + def _assert_counts(notif_count: int, thread_notif_count: int) -> None: counts = self.get_success( self.store.db_pool.runInteraction( "get-unread-counts", @@ -658,7 +689,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual( counts.main_timeline, NotifCounts( - notify_count=noitf_count, unread_count=0, highlight_count=0 + notify_count=notif_count, unread_count=0, highlight_count=0 ), ) if thread_notif_count: -- cgit 1.5.1 From 71f3e53ad010ba8c219f1076d40915b985760ed9 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Thu, 1 Dec 2022 13:46:24 +0000 Subject: Add `push.enabled` option to disable push notification calculation (#14551) * Add initial option * changelog * Some more linting --- changelog.d/14551.feature | 1 + docs/usage/configuration/config_documentation.md | 5 +++ synapse/config/push.py | 1 + synapse/push/bulk_push_rule_evaluator.py | 3 ++ tests/push/test_bulk_push_rule_evaluator.py | 45 ++++++++++++++++++++++-- 5 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14551.feature (limited to 'synapse') diff --git a/changelog.d/14551.feature b/changelog.d/14551.feature new file mode 100644 index 0000000000..43b91d2e57 --- /dev/null +++ b/changelog.d/14551.feature @@ -0,0 +1 @@ +Add new `push.enabled` config option to allow opting out of push notification calculation. \ No newline at end of file diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 749af12aac..b9bde8f47e 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3355,6 +3355,10 @@ Configuration settings related to push notifications This setting defines options for push notifications. This option has a number of sub-options. They are as follows: +* `enable_push`: Enables or disables push notification calculation. Note, disabling this will also + stop unread counts being calculated for rooms. This mode of operation is intended + for homeservers which may only have bots or appservice users connected, or are otherwise + not interested in push/unread counters. This is enabled by default. * `include_content`: Clients requesting push notifications can either have the body of the message sent in the notification poke along with other details like the sender, or just the event ID and room ID (`event_id_only`). @@ -3375,6 +3379,7 @@ This option has a number of sub-options. They are as follows: Example configuration: ```yaml push: + enable_push: true include_content: false group_unread_count_by_room: false ``` diff --git a/synapse/config/push.py b/synapse/config/push.py index 979b128eae..3b5378e6ea 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -26,6 +26,7 @@ class PushConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: push_config = config.get("push") or {} self.push_include_content = push_config.get("include_content", True) + self.enable_push = push_config.get("enabled", True) self.push_group_unread_count_by_room = push_config.get( "group_unread_count_by_room", True ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d6b377860f..9ed35d8461 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -106,6 +106,7 @@ class BulkPushRuleEvaluator: self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() + self.should_calculate_push_rules = self.hs.config.push.enable_push self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled @@ -269,6 +270,8 @@ class BulkPushRuleEvaluator: for each event, check if the message should increment the unread count, and insert the results into the event_push_actions_staging table. """ + if not self.should_calculate_push_rules: + return # For batched events the power level events may not have been persisted yet, # so we pass in the batched events. Thus if the event cannot be found in the # database we can check in the batch. diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 594e7937a8..1cd453248e 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -6,10 +6,11 @@ from synapse.rest import admin from synapse.rest.client import login, register, room from synapse.types import create_requester -from tests import unittest +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase, override_config -class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): +class TestBulkPushRuleEvaluator(HomeserverTestCase): servlets = [ admin.register_servlets_for_client_rest_resource, @@ -72,3 +73,43 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) + + @override_config({"push": {"enabled": False}}) + def test_action_for_event_by_user_disabled_by_config(self) -> None: + """Ensure that push rules are not calculated when disabled in the config""" + # Create a new user and room. + alice = self.register_user("alice", "pass") + token = self.login(alice, "pass") + + room_id = self.helper.create_room_as( + alice, room_version=RoomVersions.V9.identifier, tok=token + ) + + # Alter the power levels in that room to include stringy and floaty levels. + # We need to suppress the validation logic or else it will reject these dodgy + # values. (Presumably this validation was not always present.) + event_creation_handler = self.hs.get_event_creation_handler() + requester = create_requester(alice) + + # Create a new message event, and try to evaluate it under the dodgy + # power level event. + event, context = self.get_success( + event_creation_handler.create_event( + requester, + { + "type": "m.room.message", + "room_id": room_id, + "content": { + "msgtype": "m.text", + "body": "helo", + }, + "sender": alice, + }, + ) + ) + + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] + # should not raise + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) + bulk_evaluator._action_for_event_by_user.assert_not_called() -- cgit 1.5.1 From fac8a38525387e344e3595a092578e0ffedd49ae Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 2 Dec 2022 10:28:41 -0500 Subject: Properly handle unknown results for the stream change cache. (#14592) StreamChangeCache.get_all_changed_entities can return None to signify it does not have information at the given stream position. Two callers (related to device lists and presence) were treating this response the same as an empty list (i.e. there being no updates). --- changelog.d/14592.bugfix | 1 + synapse/handlers/presence.py | 4 ++-- synapse/storage/databases/main/devices.py | 33 ++++++++++++++++++------------- 3 files changed, 22 insertions(+), 16 deletions(-) create mode 100644 changelog.d/14592.bugfix (limited to 'synapse') diff --git a/changelog.d/14592.bugfix b/changelog.d/14592.bugfix new file mode 100644 index 0000000000..149ee99dd7 --- /dev/null +++ b/changelog.d/14592.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index cf08737d11..1799174c2f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1764,14 +1764,14 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): Returns: A list of presence states for the given user to receive. """ + updated_users = None if from_key: # Only return updates since the last sync updated_users = self.store.presence_stream_cache.get_all_entities_changed( from_key ) - if not updated_users: - updated_users = [] + if updated_users is not None: # Get the actual presence update for each change users_to_state = await self.get_presence_handler().current_state_for_users( updated_users diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 534f7fc04a..8ba995df3b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -842,12 +842,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_ids, from_key ) - if not user_ids_to_check: + # If an empty set was returned, there's nothing to do. + if user_ids_to_check is not None and not user_ids_to_check: return set() def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: - changes: Set[str] = set() - stream_id_where_clause = "stream_id > ?" sql_args = [from_key] @@ -858,19 +857,25 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): sql = f""" SELECT DISTINCT user_id FROM device_lists_stream WHERE {stream_id_where_clause} - AND """ - # Query device changes with a batch of users at a time - # Assertion for mypy's benefit; see also - # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions - assert user_ids_to_check is not None - for chunk in batch_iter(user_ids_to_check, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "user_id", chunk - ) - txn.execute(sql + clause, sql_args + args) - changes.update(user_id for user_id, in txn) + # If the stream change cache gave us no information, fetch *all* + # users between the stream IDs. + if user_ids_to_check is None: + txn.execute(sql, sql_args) + return {user_id for user_id, in txn} + + # Otherwise, fetch changes for the given users. + else: + changes: Set[str] = set() + + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql + " AND " + clause, sql_args + args) + changes.update(user_id for user_id, in txn) return changes -- cgit 1.5.1 From f685318c2aa5d4a54239f7fc444bdaca6ba975bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 2 Dec 2022 13:10:05 -0500 Subject: Use ClientRestResource on both the main process and workers. (#14528) Add logic to ClientRestResource to decide whether to mount servlets or not based on whether the current process is a worker. This is clearer to see what a worker runs than the completely separate / copy & pasted list of servlets being mounted for workers. --- changelog.d/14528.misc | 1 + synapse/app/generic_worker.py | 74 ++--------------------------------------- synapse/rest/__init__.py | 59 ++++++++++++++++++++------------ synapse/rest/client/account.py | 26 ++++++++------- synapse/rest/client/devices.py | 10 +++--- synapse/rest/client/keys.py | 5 +-- synapse/rest/client/register.py | 9 ++--- synapse/rest/client/room.py | 6 ++-- 8 files changed, 71 insertions(+), 119 deletions(-) create mode 100644 changelog.d/14528.misc (limited to 'synapse') diff --git a/changelog.d/14528.misc b/changelog.d/14528.misc new file mode 100644 index 0000000000..4f233feab6 --- /dev/null +++ b/changelog.d/14528.misc @@ -0,0 +1 @@ +Share the `ClientRestResource` for both workers and the main process. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 46dc731696..bcc8abe20c 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -44,40 +44,8 @@ from synapse.http.server import JsonResource, OptionsResource from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource +from synapse.rest import ClientRestResource from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client import ( - account_data, - events, - initial_sync, - login, - presence, - profile, - push_rule, - read_marker, - receipts, - relations, - room, - room_batch, - room_keys, - sendtodevice, - sync, - tags, - user_directory, - versions, - voip, -) -from synapse.rest.client.account import ThreepidRestServlet, WhoamiRestServlet -from synapse.rest.client.devices import DevicesRestServlet -from synapse.rest.client.keys import ( - KeyChangesServlet, - KeyQueryServlet, - KeyUploadServlet, - OneTimeKeyServlet, -) -from synapse.rest.client.register import ( - RegisterRestServlet, - RegistrationTokenValidityRestServlet, -) from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -200,45 +168,7 @@ class GenericWorkerServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "client": - resource = JsonResource(self, canonical_json=False) - - RegisterRestServlet(self).register(resource) - RegistrationTokenValidityRestServlet(self).register(resource) - login.register_servlets(self, resource) - ThreepidRestServlet(self).register(resource) - WhoamiRestServlet(self).register(resource) - DevicesRestServlet(self).register(resource) - - # Read-only - KeyUploadServlet(self).register(resource) - KeyQueryServlet(self).register(resource) - KeyChangesServlet(self).register(resource) - OneTimeKeyServlet(self).register(resource) - - voip.register_servlets(self, resource) - push_rule.register_servlets(self, resource) - versions.register_servlets(self, resource) - - profile.register_servlets(self, resource) - - sync.register_servlets(self, resource) - events.register_servlets(self, resource) - room.register_servlets(self, resource, is_worker=True) - relations.register_servlets(self, resource) - room.register_deprecated_servlets(self, resource) - initial_sync.register_servlets(self, resource) - room_batch.register_servlets(self, resource) - room_keys.register_servlets(self, resource) - tags.register_servlets(self, resource) - account_data.register_servlets(self, resource) - receipts.register_servlets(self, resource) - read_marker.register_servlets(self, resource) - - sendtodevice.register_servlets(self, resource) - - user_directory.register_servlets(self, resource) - - presence.register_servlets(self, resource) + resource: Resource = ClientRestResource(self) resources[CLIENT_API_PREFIX] = resource diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 28542cd774..14c4e6ebbb 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -29,7 +29,7 @@ from synapse.rest.client import ( initial_sync, keys, knock, - login as v1_login, + login, login_token_request, logout, mutual_rooms, @@ -82,6 +82,10 @@ class ClientRestResource(JsonResource): @staticmethod def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None: + # Some servlets are only registered on the main process (and not worker + # processes). + is_main_process = hs.config.worker.worker_app is None + versions.register_servlets(hs, client_resource) # Deprecated in r0 @@ -92,45 +96,58 @@ class ClientRestResource(JsonResource): events.register_servlets(hs, client_resource) room.register_servlets(hs, client_resource) - v1_login.register_servlets(hs, client_resource) + login.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource) presence.register_servlets(hs, client_resource) - directory.register_servlets(hs, client_resource) + if is_main_process: + directory.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource) - pusher.register_servlets(hs, client_resource) + if is_main_process: + pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) - logout.register_servlets(hs, client_resource) + if is_main_process: + logout.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource) - filter.register_servlets(hs, client_resource) + if is_main_process: + filter.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource) - auth.register_servlets(hs, client_resource) + if is_main_process: + auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) read_marker.register_servlets(hs, client_resource) room_keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) - tokenrefresh.register_servlets(hs, client_resource) + if is_main_process: + tokenrefresh.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource) account_data.register_servlets(hs, client_resource) - report_event.register_servlets(hs, client_resource) - openid.register_servlets(hs, client_resource) - notifications.register_servlets(hs, client_resource) + if is_main_process: + report_event.register_servlets(hs, client_resource) + openid.register_servlets(hs, client_resource) + notifications.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource) - thirdparty.register_servlets(hs, client_resource) + if is_main_process: + thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) user_directory.register_servlets(hs, client_resource) - room_upgrade_rest_servlet.register_servlets(hs, client_resource) + if is_main_process: + room_upgrade_rest_servlet.register_servlets(hs, client_resource) room_batch.register_servlets(hs, client_resource) - capabilities.register_servlets(hs, client_resource) - account_validity.register_servlets(hs, client_resource) + if is_main_process: + capabilities.register_servlets(hs, client_resource) + account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) - password_policy.register_servlets(hs, client_resource) - knock.register_servlets(hs, client_resource) + if is_main_process: + password_policy.register_servlets(hs, client_resource) + knock.register_servlets(hs, client_resource) # moving to /_synapse/admin - admin.register_servlets_for_client_rest_resource(hs, client_resource) + if is_main_process: + admin.register_servlets_for_client_rest_resource(hs, client_resource) # unstable - mutual_rooms.register_servlets(hs, client_resource) - login_token_request.register_servlets(hs, client_resource) - rendezvous.register_servlets(hs, client_resource) + if is_main_process: + mutual_rooms.register_servlets(hs, client_resource) + login_token_request.register_servlets(hs, client_resource) + rendezvous.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 44f622bcce..b4b92f0c99 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -875,19 +875,21 @@ class AccountStatusRestServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - EmailPasswordRequestTokenRestServlet(hs).register(http_server) - PasswordRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) - EmailThreepidRequestTokenRestServlet(hs).register(http_server) - MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) - AddThreepidEmailSubmitTokenServlet(hs).register(http_server) - AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) + if hs.config.worker.worker_app is None: + EmailPasswordRequestTokenRestServlet(hs).register(http_server) + PasswordRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + EmailThreepidRequestTokenRestServlet(hs).register(http_server) + MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) + AddThreepidEmailSubmitTokenServlet(hs).register(http_server) + AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) - ThreepidAddRestServlet(hs).register(http_server) - ThreepidBindRestServlet(hs).register(http_server) - ThreepidUnbindRestServlet(hs).register(http_server) - ThreepidDeleteRestServlet(hs).register(http_server) + if hs.config.worker.worker_app is None: + ThreepidAddRestServlet(hs).register(http_server) + ThreepidBindRestServlet(hs).register(http_server) + ThreepidUnbindRestServlet(hs).register(http_server) + ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) - if hs.config.experimental.msc3720_enabled: + if hs.config.worker.worker_app is None and hs.config.experimental.msc3720_enabled: AccountStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 69b803f9f8..486c6dbbc5 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -342,8 +342,10 @@ class ClaimDehydratedDeviceServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - DeleteDevicesRestServlet(hs).register(http_server) + if hs.config.worker.worker_app is None: + DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DehydratedDeviceServlet(hs).register(http_server) - ClaimDehydratedDeviceServlet(hs).register(http_server) + if hs.config.worker.worker_app is None: + DeviceRestServlet(hs).register(http_server) + DehydratedDeviceServlet(hs).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index ee038c7192..7873b363c0 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -376,5 +376,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) - SigningKeyUploadServlet(hs).register(http_server) - SignaturesUploadServlet(hs).register(http_server) + if hs.config.worker.worker_app is None: + SigningKeyUploadServlet(hs).register(http_server) + SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index de810ae3ec..3cb1e7e375 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -949,9 +949,10 @@ def _calculate_registration_flows( def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - EmailRegisterRequestTokenRestServlet(hs).register(http_server) - MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) - UsernameAvailabilityRestServlet(hs).register(http_server) - RegistrationSubmitTokenServlet(hs).register(http_server) + if hs.config.worker.worker_app is None: + EmailRegisterRequestTokenRestServlet(hs).register(http_server) + MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) + UsernameAvailabilityRestServlet(hs).register(http_server) + RegistrationSubmitTokenServlet(hs).register(http_server) RegistrationTokenValidityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 636cc62877..e70aa381f3 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1395,9 +1395,7 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): ) -def register_servlets( - hs: "HomeServer", http_server: HttpServer, is_worker: bool = False -) -> None: +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -1421,7 +1419,7 @@ def register_servlets( TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. - if not is_worker: + if hs.config.worker.worker_app is None: RoomForgetRestServlet(hs).register(http_server) -- cgit 1.5.1 From 93ac3c197ebcb56f4e68a93da5bd63b4a96b18f1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 5 Dec 2022 11:30:41 +0000 Subject: Suppress empty body warnings in room servelets (#14600) * Suppress empty body warnings in room servelets We've already decided to allow empty bodies for backwards compat. The change here stops us from emitting a misleading warning; see also https://github.com/matrix-org/synapse/issues/14478#issuecomment-1319157105 * Changelog --- changelog.d/14600.bugfix | 1 + synapse/rest/client/room.py | 14 ++------------ 2 files changed, 3 insertions(+), 12 deletions(-) create mode 100644 changelog.d/14600.bugfix (limited to 'synapse') diff --git a/changelog.d/14600.bugfix b/changelog.d/14600.bugfix new file mode 100644 index 0000000000..c4bf405684 --- /dev/null +++ b/changelog.d/14600.bugfix @@ -0,0 +1 @@ +Suppress a spurious warning when `POST /rooms///`, `POST /join//` receive an empty HTTP request body. diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index e70aa381f3..514eb6afc8 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -396,12 +396,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - try: - content = parse_json_object_from_request(request) - except Exception: - # Turns out we used to ignore the body entirely, and some clients - # cheekily send invalid bodies. - content = {} + content = parse_json_object_from_request(request, allow_empty_body=True) # twisted.web.server.Request.args is incorrectly defined as Optional[Any] args: Dict[bytes, List[bytes]] = request.args # type: ignore @@ -952,12 +947,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): }: raise AuthError(403, "Guest access not allowed") - try: - content = parse_json_object_from_request(request) - except Exception: - # Turns out we used to ignore the body entirely, and some clients - # cheekily send invalid bodies. - content = {} + content = parse_json_object_from_request(request, allow_empty_body=True) if membership_action == "invite" and all( key in content for key in ("medium", "address") -- cgit 1.5.1 From 501f62d1a62296f79e46e1bd60dc5d1a8b28847d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 5 Dec 2022 13:07:55 +0000 Subject: Faster remote room joins: stream the un-partial-stating of rooms over replication. [rei:frrj/streams/unpsr] (#14473) --- changelog.d/14473.misc | 1 + synapse/handlers/device.py | 2 +- synapse/handlers/federation.py | 4 + synapse/replication/tcp/streams/__init__.py | 3 + synapse/replication/tcp/streams/partial_state.py | 48 +++++ synapse/storage/databases/main/room.py | 237 +++++++++++++++------ .../delta/73/20_un_partial_stated_room_stream.sql | 32 +++ ..._un_partial_stated_room_stream_seq.sql.postgres | 20 ++ 8 files changed, 280 insertions(+), 67 deletions(-) create mode 100644 changelog.d/14473.misc create mode 100644 synapse/replication/tcp/streams/partial_state.py create mode 100644 synapse/storage/schema/main/delta/73/20_un_partial_stated_room_stream.sql create mode 100644 synapse/storage/schema/main/delta/73/21_un_partial_stated_room_stream_seq.sql.postgres (limited to 'synapse') diff --git a/changelog.d/14473.misc b/changelog.d/14473.misc new file mode 100644 index 0000000000..deccd4e91a --- /dev/null +++ b/changelog.d/14473.misc @@ -0,0 +1 @@ +Faster remote room joins: stream the un-partial-stating of rooms over replication. \ No newline at end of file diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b1e55e1b9e..d4750a32e6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -996,7 +996,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): # Check if we are partially joining any rooms. If so we need to store # all device list updates so that we can handle them correctly once we # know who is in the room. - # TODO(faster joins): this fetches and processes a bunch of data that we don't + # TODO(faster_joins): this fetches and processes a bunch of data that we don't # use. Could be replaced by a tighter query e.g. # SELECT EXISTS(SELECT 1 FROM partial_state_rooms) partial_rooms = await self.store.get_partial_state_room_resync_info() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d92582fd5c..3398fcaf7d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -152,6 +152,7 @@ class FederationHandler: self._federation_event_handler = hs.get_federation_event_handler() self._device_handler = hs.get_device_handler() self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() + self._notifier = hs.get_notifier() self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs @@ -1692,6 +1693,9 @@ class FederationHandler: self._storage_controllers.state.notify_room_un_partial_stated( room_id ) + # Poke the notifier so that other workers see the write to + # the un-partial-stated rooms stream. + self._notifier.notify_replication() # TODO(faster_joins) update room stats and user directory? # https://github.com/matrix-org/synapse/issues/12814 diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index b1cd55bf6f..8575666d9c 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -42,6 +42,7 @@ from synapse.replication.tcp.streams._base import ( ) from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.federation import FederationStream +from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream STREAMS_MAP = { stream.NAME: stream @@ -61,6 +62,7 @@ STREAMS_MAP = { TagAccountDataStream, AccountDataStream, UserSignatureStream, + UnPartialStatedRoomStream, ) } @@ -80,4 +82,5 @@ __all__ = [ "TagAccountDataStream", "AccountDataStream", "UserSignatureStream", + "UnPartialStatedRoomStream", ] diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py new file mode 100644 index 0000000000..18f087ffa2 --- /dev/null +++ b/synapse/replication/tcp/streams/partial_state.py @@ -0,0 +1,48 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +import attr + +from synapse.replication.tcp.streams import Stream +from synapse.replication.tcp.streams._base import current_token_without_instance + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UnPartialStatedRoomStreamRow: + # ID of the room that has been un-partial-stated. + room_id: str + + +class UnPartialStatedRoomStream(Stream): + """ + Stream to notify about rooms becoming un-partial-stated; + that is, when the background sync finishes such that we now have full state for + the room. + """ + + NAME = "un_partial_stated_room" + ROW_TYPE = UnPartialStatedRoomStreamRow + + def __init__(self, hs: "HomeServer"): + store = hs.get_datastores().main + super().__init__( + hs.get_instance_name(), + # TODO(faster_joins, multiple writers): we need to account for instance names + current_token_without_instance(store.get_un_partial_stated_rooms_token), + store.get_un_partial_stated_rooms_from_stream, + ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 1309bfd374..78906a5e1d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1,5 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019, 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,8 +50,14 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import IdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + IdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -114,6 +120,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): self.config: HomeServerConfig = hs.config + self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator + + if isinstance(database.engine, PostgresEngine): + self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="un_partial_stated_room_stream", + instance_name=self._instance_name, + tables=[ + ("un_partial_stated_room_stream", "instance_name", "stream_id") + ], + sequence_name="un_partial_stated_room_stream_sequence", + # TODO(faster_joins, multiple writers) Support multiple writers. + writers=["master"], + ) + else: + self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator( + db_conn, "un_partial_stated_room_stream", "stream_id" + ) + async def store_room( self, room_id: str, @@ -1216,70 +1242,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return room_servers - async def clear_partial_state_room(self, room_id: str) -> bool: - """Clears the partial state flag for a room. - - Args: - room_id: The room whose partial state flag is to be cleared. - - Returns: - `True` if the partial state flag has been cleared successfully. - - `False` if the partial state flag could not be cleared because the room - still contains events with partial state. - """ - try: - await self.db_pool.runInteraction( - "clear_partial_state_room", self._clear_partial_state_room_txn, room_id - ) - return True - except self.db_pool.engine.module.IntegrityError as e: - # Assume that any `IntegrityError`s are due to partial state events. - logger.info( - "Exception while clearing lazy partial-state-room %s, retrying: %s", - room_id, - e, - ) - return False - - def _clear_partial_state_room_txn( - self, txn: LoggingTransaction, room_id: str - ) -> None: - DatabasePool.simple_delete_txn( - txn, - table="partial_state_rooms_servers", - keyvalues={"room_id": room_id}, - ) - DatabasePool.simple_delete_one_txn( - txn, - table="partial_state_rooms", - keyvalues={"room_id": room_id}, - ) - self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) - self._invalidate_cache_and_stream( - txn, self.get_partial_state_servers_at_join, (room_id,) - ) - - # We now delete anything from `device_lists_remote_pending` with a - # stream ID less than the minimum - # `partial_state_rooms.device_lists_stream_id`, as we no longer need them. - device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn( - txn, - table="partial_state_rooms", - keyvalues={}, - retcol="MIN(device_lists_stream_id)", - allow_none=True, - ) - if device_lists_stream_id is None: - # There are no rooms being currently partially joined, so we delete everything. - txn.execute("DELETE FROM device_lists_remote_pending") - else: - sql = """ - DELETE FROM device_lists_remote_pending - WHERE stream_id <= ? - """ - txn.execute(sql, (device_lists_stream_id,)) - @cached() async def is_partial_state_room(self, room_id: str) -> bool: """Checks if this room has partial state. @@ -1315,6 +1277,66 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) return result["join_event_id"], result["device_lists_stream_id"] + def get_un_partial_stated_rooms_token(self) -> int: + # TODO(faster_joins, multiple writers): This is inappropriate if there + # are multiple writers because workers that don't write often will + # hold all readers up. + # (See `MultiWriterIdGenerator.get_persisted_upto_position` for an + # explanation.) + return self._un_partial_stated_rooms_stream_id_gen.get_current_token() + + async def get_un_partial_stated_rooms_from_stream( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: + """Get updates for caches replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_un_partial_stated_rooms_from_stream_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: + sql = """ + SELECT stream_id, room_id + FROM un_partial_stated_room_stream + WHERE ? < stream_id AND stream_id <= ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, instance_name, limit)) + updates = [(row[0], (row[1],)) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_un_partial_stated_rooms_from_stream", + get_un_partial_stated_rooms_from_stream_txn, + ) + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1806,6 +1828,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") + self._instance_name = hs.get_instance_name() + async def upsert_room_on_join( self, room_id: str, room_version: RoomVersion, state_events: List[EventBase] ) -> None: @@ -2270,3 +2294,84 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self.is_room_blocked, (room_id,), ) + + async def clear_partial_state_room(self, room_id: str) -> bool: + """Clears the partial state flag for a room. + + Args: + room_id: The room whose partial state flag is to be cleared. + + Returns: + `True` if the partial state flag has been cleared successfully. + + `False` if the partial state flag could not be cleared because the room + still contains events with partial state. + """ + try: + async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id: + await self.db_pool.runInteraction( + "clear_partial_state_room", + self._clear_partial_state_room_txn, + room_id, + un_partial_state_room_stream_id, + ) + return True + except self.db_pool.engine.module.IntegrityError as e: + # Assume that any `IntegrityError`s are due to partial state events. + logger.info( + "Exception while clearing lazy partial-state-room %s, retrying: %s", + room_id, + e, + ) + return False + + def _clear_partial_state_room_txn( + self, + txn: LoggingTransaction, + room_id: str, + un_partial_state_room_stream_id: int, + ) -> None: + DatabasePool.simple_delete_txn( + txn, + table="partial_state_rooms_servers", + keyvalues={"room_id": room_id}, + ) + DatabasePool.simple_delete_one_txn( + txn, + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + ) + self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) + self._invalidate_cache_and_stream( + txn, self.get_partial_state_servers_at_join, (room_id,) + ) + + DatabasePool.simple_insert_txn( + txn, + "un_partial_stated_room_stream", + { + "stream_id": un_partial_state_room_stream_id, + "instance_name": self._instance_name, + "room_id": room_id, + }, + ) + + # We now delete anything from `device_lists_remote_pending` with a + # stream ID less than the minimum + # `partial_state_rooms.device_lists_stream_id`, as we no longer need them. + device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn( + txn, + table="partial_state_rooms", + keyvalues={}, + retcol="MIN(device_lists_stream_id)", + allow_none=True, + ) + if device_lists_stream_id is None: + # There are no rooms being currently partially joined, so we delete everything. + txn.execute("DELETE FROM device_lists_remote_pending") + else: + sql = """ + DELETE FROM device_lists_remote_pending + WHERE stream_id <= ? + """ + txn.execute(sql, (device_lists_stream_id,)) diff --git a/synapse/storage/schema/main/delta/73/20_un_partial_stated_room_stream.sql b/synapse/storage/schema/main/delta/73/20_un_partial_stated_room_stream.sql new file mode 100644 index 0000000000..743196cfe3 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/20_un_partial_stated_room_stream.sql @@ -0,0 +1,32 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stream for notifying that a room has become un-partial-stated. +CREATE TABLE un_partial_stated_room_stream( + -- Position in the stream + stream_id BIGINT PRIMARY KEY NOT NULL, + + -- Which instance wrote this entry. + instance_name TEXT NOT NULL, + + -- Which room has been un-partial-stated. + room_id TEXT NOT NULL REFERENCES rooms(room_id) ON DELETE CASCADE +); + +-- We want an index here because of the foreign key constraint: +-- upon deleting a room, the database needs to be able to check here. +-- This index is not unique because we can join a room multiple times in a server's lifetime, +-- so the same room could be un-partial-stated multiple times! +CREATE INDEX un_partial_stated_room_stream_room_id ON un_partial_stated_room_stream (room_id); diff --git a/synapse/storage/schema/main/delta/73/21_un_partial_stated_room_stream_seq.sql.postgres b/synapse/storage/schema/main/delta/73/21_un_partial_stated_room_stream_seq.sql.postgres new file mode 100644 index 0000000000..c1aac0b385 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/21_un_partial_stated_room_stream_seq.sql.postgres @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS un_partial_stated_room_stream_sequence; + +SELECT setval('un_partial_stated_room_stream_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM un_partial_stated_room_stream +)); -- cgit 1.5.1 From 6a8310f3dfe77acf59df2fe3e88a71b85b9b3ecc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 5 Dec 2022 09:00:59 -0500 Subject: Compare to the earliest known stream pos in the stream change cache. (#14435) The internal methods of the StreamChangeCache were inconsistently treating the earliest known stream position as valid. It is now treated as invalid, meaning the cache cannot determine if an entity at the earliest known stream position has changed or not. --- changelog.d/14435.bugfix | 1 + poetry.lock | 2 +- pyproject.toml | 3 +- synapse/util/caches/stream_change_cache.py | 142 +++++++++++++++++++++++------ tests/util/test_stream_change_cache.py | 38 +++----- 5 files changed, 133 insertions(+), 53 deletions(-) create mode 100644 changelog.d/14435.bugfix (limited to 'synapse') diff --git a/changelog.d/14435.bugfix b/changelog.d/14435.bugfix new file mode 100644 index 0000000000..149ee99dd7 --- /dev/null +++ b/changelog.d/14435.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances. diff --git a/poetry.lock b/poetry.lock index 8c63134578..90b363a548 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1639,7 +1639,7 @@ url-preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "27811bd21d56ceeb0f68ded5a00375efcd1a004928f0736f5b02927ce8594cb0" +content-hash = "8c44ceeb9df5c3ab43040400e0a6b895de49417e61293a1ba027640b34f03263" [metadata.files] attrs = [ diff --git a/pyproject.toml b/pyproject.toml index af5ce2aa03..1368e4e688 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,7 +141,8 @@ pyasn1 = ">=0.1.9" pyasn1-modules = ">=0.0.7" bcrypt = ">=3.1.7" Pillow = ">=5.4.0" -sortedcontainers = ">=1.4.4" +# We use SortedDict.peekitem(), which was added in sortedcontainers 1.5.2. +sortedcontainers = ">=1.5.2" pymacaroons = ">=0.13.0" msgpack = ">=0.5.2" phonenumbers = ">=8.2.0" diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 666f4b6895..042de8d7c8 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -27,13 +27,17 @@ EntityType = str class StreamChangeCache: - """Keeps track of the stream positions of the latest change in a set of entities. + """ + Keeps track of the stream positions of the latest change in a set of entities. + + The entity will is typically a room ID or user ID, but can be any string. - Typically the entity will be a room or user id. + Can be queried for whether a specific entity has changed after a stream position + or for a list of changed entities after a stream position. See the individual + methods for more information. - Given a list of entities and a stream position, it will give a subset of - entities that may have changed since that position. If position key is too - old then the cache will simply return all given entities. + Only tracks to a maximum cache size, any position earlier than the earliest + known stream position must be treated as unknown. """ def __init__( @@ -45,16 +49,20 @@ class StreamChangeCache: ) -> None: self._original_max_size: int = max_size self._max_size = math.floor(max_size) - self._entity_to_key: Dict[EntityType, int] = {} - # map from stream id to the a set of entities which changed at that stream id. + # map from stream id to the set of entities which changed at that stream id. self._cache: SortedDict[int, Set[EntityType]] = SortedDict() + # map from entity to the stream ID of the latest change for that entity. + # + # Must be kept in sync with _cache. + self._entity_to_key: Dict[EntityType, int] = {} # the earliest stream_pos for which we can reliably answer # get_all_entities_changed. In other words, one less than the earliest # stream_pos for which we know _cache is valid. # self._earliest_known_stream_pos = current_stream_pos + self.name = name self.metrics = caches.register_cache( "cache", self.name, self._cache, resize_callback=self.set_cache_factor @@ -82,22 +90,46 @@ class StreamChangeCache: return False def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: - """Returns True if the entity may have been updated since stream_pos""" + """ + Returns True if the entity may have been updated after stream_pos. + + Args: + entity: The entity to check for changes. + stream_pos: The stream position to check for changes after. + + Return: + True if the entity may have been updated, this happens if: + * The given stream position is at or earlier than the earliest + known stream position. + * The given stream position is earlier than the latest change for + the entity. + + False otherwise: + * The entity is unknown. + * The given stream position is at or later than the latest change + for the entity. + """ assert isinstance(stream_pos, int) - if stream_pos < self._earliest_known_stream_pos: + # _cache is not valid at or before the earliest known stream position, so + # return that the entity has changed. + if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + # If the entity is unknown, it hasn't changed. latest_entity_change_pos = self._entity_to_key.get(entity, None) if latest_entity_change_pos is None: self.metrics.inc_hits() return False + # This is a known entity, return true if the stream position is earlier + # than the last change. if stream_pos < latest_entity_change_pos: self.metrics.inc_misses() return True + # Otherwise, the stream position is after the latest change: return false. self.metrics.inc_hits() return False @@ -105,15 +137,27 @@ class StreamChangeCache: self, entities: Collection[EntityType], stream_pos: int ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ - Returns subset of entities that have had new things since the given - position. Entities unknown to the cache will be returned. If the - position is too old it will just return the given list. + Returns the subset of the given entities that have had changes after the given position. + + Entities unknown to the cache will be returned. + + If the position is too old it will just return the given list. + + Args: + entities: Entities to check for changes. + stream_pos: The stream position to check for changes after. + + Return: + A subset of entities which have changed after the given stream position. + + This will be all entities if the given stream position is at or earlier + than the earliest known stream position. """ changed_entities = self.get_all_entities_changed(stream_pos) if changed_entities is not None: # We now do an intersection, trying to do so in the most efficient # way possible (some of these sets are *large*). First check in the - # given iterable is already set that we can reuse, otherwise we + # given iterable is already a set that we can reuse, otherwise we # create a set of the *smallest* of the two iterables and call # `intersection(..)` on it (this can be twice as fast as the reverse). if isinstance(entities, (set, frozenset)): @@ -130,29 +174,57 @@ class StreamChangeCache: return result def has_any_entity_changed(self, stream_pos: int) -> bool: - """Returns if any entity has changed""" - assert type(stream_pos) is int + """ + Returns true if any entity has changed after the given stream position. + + Args: + stream_pos: The stream position to check for changes after. + + Return: + True if any entity has changed after the given stream position or + if the given stream position is at or earlier than the earliest + known stream position. + + False otherwise. + """ + assert isinstance(stream_pos, int) if not self._cache: # If the cache is empty, nothing can have changed. return False - if stream_pos >= self._earliest_known_stream_pos: - self.metrics.inc_hits() - return self._cache.bisect_right(stream_pos) < len(self._cache) - else: + # _cache is not valid at or before the earliest known stream position, so + # return that an entity has changed. + if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + self.metrics.inc_hits() + return stream_pos < self._cache.peekitem()[0] + def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: - """Returns all entities that have had new things since the given - position. If the position is too old it will return None. + """ + Returns all entities that have had changes after the given position. + + If the stream change cache does not go far enough back, i.e. the position + is too old, it will return None. Returns the entities in the order that they were changed. + + Args: + stream_pos: The stream position to check for changes after. + + Return: + Entities which have changed after the given stream position. + + None if the given stream position is at or earlier than the earliest + known stream position. """ - assert type(stream_pos) is int + assert isinstance(stream_pos, int) - if stream_pos < self._earliest_known_stream_pos: + # _cache is not valid at or before the earliest known stream position, so + # return None to mark that it is unknown if an entity has changed. + if stream_pos <= self._earliest_known_stream_pos: return None changed_entities: List[EntityType] = [] @@ -162,11 +234,17 @@ class StreamChangeCache: return changed_entities def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: - """Informs the cache that the entity has been changed at the given - position. """ - assert type(stream_pos) is int + Informs the cache that the entity has been changed at the given position. + + Args: + entity: The entity to mark as changed. + stream_pos: The stream position to update the entity to. + """ + assert isinstance(stream_pos, int) + # For a change before _cache is valid (e.g. at or before the earliest known + # stream position) there's nothing to do. if stream_pos <= self._earliest_known_stream_pos: return @@ -189,6 +267,11 @@ class StreamChangeCache: self._evict() def _evict(self) -> None: + """ + Ensure the cache has not exceeded the maximum size. + + Evicts entries until it is at the maximum size. + """ # if the cache is too big, remove entries while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) @@ -199,5 +282,12 @@ class StreamChangeCache: def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an entity. + + Args: + entity: The entity to check. + + Return: + The stream position of the latest change for the given entity or + the earliest known stream position if the entitiy is unknown. """ return self._entity_to_key.get(entity, self._earliest_known_stream_pos) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 1b0fa52ad1..a29cc872f9 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -51,6 +51,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # return True, whether it's a known entity or not. self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) + self.assertTrue(cache.has_entity_changed("user@foo.com", 3)) + self.assertTrue(cache.has_entity_changed("not@here.website", 3)) def test_entity_has_changed_pops_off_start(self) -> None: """ @@ -65,15 +67,14 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # The cache is at the max size, 2 self.assertEqual(len(cache._cache), 2) + # The cache's earliest known position is 2. + self.assertEqual(cache._earliest_known_stream_pos, 2) # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) - self.assertEqual( - cache.get_all_entities_changed(2), - ["bar@baz.net", "user@elsewhere.org"], - ) - self.assertIsNone(cache.get_all_entities_changed(1)) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) + self.assertIsNone(cache.get_all_entities_changed(2)) # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) @@ -81,10 +82,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(2), + cache.get_all_entities_changed(3), ["user@elsewhere.org", "bar@baz.net"], ) - self.assertIsNone(cache.get_all_entities_changed(1)) + self.assertIsNone(cache.get_all_entities_changed(2)) def test_get_all_entities_changed(self) -> None: """ @@ -99,28 +100,15 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): cache.entity_has_changed("anotheruser@foo.com", 3) cache.entity_has_changed("user@elsewhere.org", 4) - r = cache.get_all_entities_changed(1) + r = cache.get_all_entities_changed(2) - # either of these are valid - ok1 = [ - "user@foo.com", - "bar@baz.net", - "anotheruser@foo.com", - "user@elsewhere.org", - ] - ok2 = [ - "user@foo.com", - "anotheruser@foo.com", - "bar@baz.net", - "user@elsewhere.org", - ] + # Results are ordered so either of these are valid. + ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"] + ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"] self.assertTrue(r == ok1 or r == ok2) - r = cache.get_all_entities_changed(2) - self.assertTrue(r == ok1[1:] or r == ok2[1:]) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertEqual(cache.get_all_entities_changed(0), None) + self.assertEqual(cache.get_all_entities_changed(1), None) # ... later, things gest more updates cache.entity_has_changed("user@foo.com", 5) -- cgit 1.5.1 From cee9445884eb62c070fb0b03a112a862e8dea7c4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 5 Dec 2022 20:19:14 +0000 Subject: Better return type for `get_all_entities_changed` (#14604) Help callers from using the return value incorrectly by ensuring that callers explicitly check if there was a cache hit or not. --- changelog.d/14604.bugfix | 1 + synapse/handlers/appservice.py | 4 +- synapse/handlers/presence.py | 12 ++-- synapse/handlers/sync.py | 6 +- synapse/handlers/typing.py | 8 +-- synapse/storage/databases/main/devices.py | 111 ++++++++++++++++++----------- synapse/util/caches/stream_change_cache.py | 52 ++++++++++---- tests/util/test_stream_change_cache.py | 20 +++--- 8 files changed, 138 insertions(+), 76 deletions(-) create mode 100644 changelog.d/14604.bugfix (limited to 'synapse') diff --git a/changelog.d/14604.bugfix b/changelog.d/14604.bugfix new file mode 100644 index 0000000000..149ee99dd7 --- /dev/null +++ b/changelog.d/14604.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 66f5b8d108..f68027aaed 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -615,8 +615,8 @@ class ApplicationServicesHandler: ) # Fetch the users who have modified their device list since then. - users_with_changed_device_lists = ( - await self.store.get_users_whose_devices_changed(from_key, to_key=new_key) + users_with_changed_device_lists = await self.store.get_all_devices_changed( + from_key, to_key=new_key ) # Filter out any users the application service is not interested in diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1799174c2f..2af90b25a3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): if from_key is not None: # First get all users that have had a presence update - updated_users = stream_change_cache.get_all_entities_changed(from_key) + result = stream_change_cache.get_all_entities_changed(from_key) # Cross-reference users we're interested in with those that have had updates. - if updated_users is not None: + if result.hit: + updated_users = result.entities + # If we have the full list of changes for presence we can # simply check which ones share a room with the user. get_updates_counter.labels("stream").inc() @@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): updated_users = None if from_key: # Only return updates since the last sync - updated_users = self.store.presence_stream_cache.get_all_entities_changed( - from_key - ) + result = self.store.presence_stream_cache.get_all_entities_changed(from_key) + if result.hit: + updated_users = result.entities if updated_users is not None: # Get the actual presence update for each change diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c8858b22dd..0b395a104d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1528,10 +1528,12 @@ class SyncHandler: # # If we don't have that info cached then we get all the users that # share a room with our user and check if those users have changed. - changed_users = self.store.get_cached_device_list_changes( + cache_result = self.store.get_cached_device_list_changes( since_token.device_list_key ) - if changed_users is not None: + if cache_result.hit: + changed_users = cache_result.entities + result = await self.store.get_rooms_for_users(changed_users) for changed_user_id, entries in result.items(): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a0ea719430..3f656ea4f5 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler): if last_id == current_id: return [], current_id, False - changed_rooms: Optional[ - Iterable[str] - ] = self._typing_stream_change_cache.get_all_entities_changed(last_id) + result = self._typing_stream_change_cache.get_all_entities_changed(last_id) - if changed_rooms is None: + if result.hit: + changed_rooms: Iterable[str] = result.entities + else: changed_rooms = self._room_serials rows = [] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8ba995df3b..a5bb4d404e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.caches.stream_change_cache import ( + AllEntitiesChangedResult, + StreamChangeCache, +) from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_cached_device_list_changes( self, from_key: int, - ) -> Optional[List[str]]: + ) -> AllEntitiesChangedResult: """Get set of users whose devices have changed since `from_key`, or None if that information is not in our cache. """ return self._device_list_stream_cache.get_all_entities_changed(from_key) + @cancellable + async def get_all_devices_changed( + self, + from_key: int, + to_key: int, + ) -> Set[str]: + """Get all users whose devices have changed in the given range. + + Args: + from_key: The minimum device lists stream token to query device list + changes for, exclusive. + to_key: The maximum device lists stream token to query device list + changes for, inclusive. + + Returns: + The set of user_ids whose devices have changed since `from_key` + (exclusive) until `to_key` (inclusive). + """ + + result = self._device_list_stream_cache.get_all_entities_changed(from_key) + + if result.hit: + # We know which users might have changed devices. + if not result.entities: + # If no users then we can return early. + return set() + + # Otherwise we need to filter down the list + return await self.get_users_whose_devices_changed( + from_key, result.entities, to_key + ) + + # If the cache didn't tell us anything, we just need to query the full + # range. + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE ? < stream_id AND stream_id <= ? + """ + + rows = await self.db_pool.execute( + "get_all_devices_changed", + None, + sql, + from_key, + to_key, + ) + return {u for u, in rows} + @cancellable async def get_users_whose_devices_changed( self, from_key: int, - user_ids: Optional[Collection[str]] = None, + user_ids: Collection[str], to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that @@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - user_ids_to_check: Optional[Collection[str]] - if user_ids is None: - # Get set of all users that have had device list changes since 'from_key' - user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( - from_key - ) - else: - # The same as above, but filter results to only those users in 'user_ids' - user_ids_to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) # If an empty set was returned, there's nothing to do. - if user_ids_to_check is not None and not user_ids_to_check: + if not user_ids_to_check: return set() - def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: - stream_id_where_clause = "stream_id > ?" - sql_args = [from_key] - - if to_key: - stream_id_where_clause += " AND stream_id <= ?" - sql_args.append(to_key) + if to_key is None: + to_key = self._device_list_id_gen.get_current_token() - sql = f""" + def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ SELECT DISTINCT user_id FROM device_lists_stream - WHERE {stream_id_where_clause} + WHERE ? < stream_id AND stream_id <= ? AND %s """ - # If the stream change cache gave us no information, fetch *all* - # users between the stream IDs. - if user_ids_to_check is None: - txn.execute(sql, sql_args) - return {user_id for user_id, in txn} + changes: Set[str] = set() - # Otherwise, fetch changes for the given users. - else: - changes: Set[str] = set() - - # Query device changes with a batch of users at a time - for chunk in batch_iter(user_ids_to_check, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "user_id", chunk - ) - txn.execute(sql + " AND " + clause, sql_args + args) - changes.update(user_id for user_id, in txn) + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql % (clause,), [from_key, to_key] + args) + changes.update(user_id for user_id, in txn) return changes diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 042de8d7c8..c8b17acb59 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -16,6 +16,7 @@ import logging import math from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union +import attr from sortedcontainers import SortedDict from synapse.util import caches @@ -26,6 +27,29 @@ logger = logging.getLogger(__name__) EntityType = str +@attr.s(auto_attribs=True, frozen=True, slots=True) +class AllEntitiesChangedResult: + """Return type of `get_all_entities_changed`. + + Callers must check that there was a cache hit, via `result.hit`, before + using the entities in `result.entities`. + + This specifically does *not* implement helpers such as `__bool__` to ensure + that callers do the correct checks. + """ + + _entities: Optional[List[EntityType]] + + @property + def hit(self) -> bool: + return self._entities is not None + + @property + def entities(self) -> List[EntityType]: + assert self._entities is not None + return self._entities + + class StreamChangeCache: """ Keeps track of the stream positions of the latest change in a set of entities. @@ -153,19 +177,19 @@ class StreamChangeCache: This will be all entities if the given stream position is at or earlier than the earliest known stream position. """ - changed_entities = self.get_all_entities_changed(stream_pos) - if changed_entities is not None: + cache_result = self.get_all_entities_changed(stream_pos) + if cache_result.hit: # We now do an intersection, trying to do so in the most efficient # way possible (some of these sets are *large*). First check in the # given iterable is already a set that we can reuse, otherwise we # create a set of the *smallest* of the two iterables and call # `intersection(..)` on it (this can be twice as fast as the reverse). if isinstance(entities, (set, frozenset)): - result = entities.intersection(changed_entities) - elif len(changed_entities) < len(entities): - result = set(changed_entities).intersection(entities) + result = entities.intersection(cache_result.entities) + elif len(cache_result.entities) < len(entities): + result = set(cache_result.entities).intersection(entities) else: - result = set(entities).intersection(changed_entities) + result = set(entities).intersection(cache_result.entities) self.metrics.inc_hits() else: result = set(entities) @@ -202,12 +226,12 @@ class StreamChangeCache: self.metrics.inc_hits() return stream_pos < self._cache.peekitem()[0] - def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: + def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult: """ Returns all entities that have had changes after the given position. - If the stream change cache does not go far enough back, i.e. the position - is too old, it will return None. + If the stream change cache does not go far enough back, i.e. the + position is too old, it will return None. Returns the entities in the order that they were changed. @@ -215,23 +239,21 @@ class StreamChangeCache: stream_pos: The stream position to check for changes after. Return: - Entities which have changed after the given stream position. - - None if the given stream position is at or earlier than the earliest - known stream position. + A class indicating if we have the requested data cached, and if so + includes the entities in the order they were changed. """ assert isinstance(stream_pos, int) # _cache is not valid at or before the earliest known stream position, so # return None to mark that it is unknown if an entity has changed. if stream_pos <= self._earliest_known_stream_pos: - return None + return AllEntitiesChangedResult(None) changed_entities: List[EntityType] = [] for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): changed_entities.extend(self._cache[k]) - return changed_entities + return AllEntitiesChangedResult(changed_entities) def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """ diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index a29cc872f9..0305741c99 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -73,8 +73,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertIsNone(cache.get_all_entities_changed(2)) + self.assertEqual( + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"] + ) + self.assertFalse(cache.get_all_entities_changed(2).hit) # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) @@ -82,10 +84,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(3), + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org", "bar@baz.net"], ) - self.assertIsNone(cache.get_all_entities_changed(2)) + self.assertFalse(cache.get_all_entities_changed(2).hit) def test_get_all_entities_changed(self) -> None: """ @@ -105,10 +107,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # Results are ordered so either of these are valid. ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"] ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"] - self.assertTrue(r == ok1 or r == ok2) + self.assertTrue(r.entities == ok1 or r.entities == ok2) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertEqual(cache.get_all_entities_changed(1), None) + self.assertEqual( + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"] + ) + self.assertFalse(cache.get_all_entities_changed(1).hit) # ... later, things gest more updates cache.entity_has_changed("user@foo.com", 5) @@ -128,7 +132,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "anotheruser@foo.com", ] r = cache.get_all_entities_changed(3) - self.assertTrue(r == ok1 or r == ok2) + self.assertTrue(r.entities == ok1 or r.entities == ok2) def test_has_any_entity_changed(self) -> None: """ -- cgit 1.5.1 From cb59e080627745d089d073d9dac276362d9abaf6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 6 Dec 2022 09:52:55 +0000 Subject: Improve logging and opentracing for to-device message handling (#14598) A batch of changes intended to make it easier to trace to-device messages through the system. The intention here is that a client can set a property org.matrix.msgid in any to-device message it sends. That ID is then included in any tracing or logging related to the message. (Suggestions as to where this field should be documented welcome. I'm not enthusiastic about speccing it - it's very much an optional extra to help with debugging.) I've also generally improved the data we send to opentracing for these messages. --- changelog.d/14598.feature | 1 + synapse/api/constants.py | 3 + synapse/federation/sender/per_destination_queue.py | 2 +- synapse/handlers/appservice.py | 3 - synapse/handlers/devicemessage.py | 36 +++++---- synapse/handlers/sync.py | 26 ++++-- synapse/logging/opentracing.py | 11 ++- synapse/rest/client/sendtodevice.py | 1 - synapse/storage/databases/main/deviceinbox.py | 92 ++++++++++++++++++---- tests/handlers/test_appservice.py | 7 +- 10 files changed, 136 insertions(+), 46 deletions(-) create mode 100644 changelog.d/14598.feature (limited to 'synapse') diff --git a/changelog.d/14598.feature b/changelog.d/14598.feature new file mode 100644 index 0000000000..88d561e286 --- /dev/null +++ b/changelog.d/14598.feature @@ -0,0 +1 @@ +Improve opentracing and logging for to-device message handling. \ No newline at end of file diff --git a/synapse/api/constants.py b/synapse/api/constants.py index bc04a0755b..89723d24fa 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -230,6 +230,9 @@ class EventContentFields: # The authorising user for joining a restricted room. AUTHORISING_USER: Final = "join_authorised_via_users_server" + # an unspecced field added to to-device messages to identify them uniquely-ish + TO_DEVICE_MSGID: Final = "org.matrix.msgid" + class RoomTypes: """Understood values of the room_type field of m.room.create events.""" diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 5af2784f1e..ffc9d95ee7 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -641,7 +641,7 @@ class PerDestinationQueue: if not message_id: continue - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id) edus = [ Edu( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index f68027aaed..5d1d21cdc8 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -578,9 +578,6 @@ class ApplicationServicesHandler: device_id, ), messages in recipient_device_to_messages.items(): for message_json in messages: - # Remove 'message_id' from the to-device message, as it's an internal ID - message_json.pop("message_id", None) - message_payload.append( { "to_user_id": user_id, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 444c08bc2e..75e89850f5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict -from synapse.api.constants import EduTypes, ToDeviceEventTypes +from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes from synapse.api.errors import SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background @@ -216,14 +216,24 @@ class DeviceMessageHandler: """ sender_user_id = requester.user.to_string() - message_id = random_string(16) - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) - - log_kv({"number_of_to_device_messages": len(messages)}) - set_tag("sender", sender_user_id) + set_tag(SynapseTags.TO_DEVICE_TYPE, message_type) + set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id) local_messages = {} remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): + # add an opentracing log entry for each message + for device_id, message_content in by_device.items(): + log_kv( + { + "event": "send_to_device_message", + "user_id": user_id, + "device_id": device_id, + EventContentFields.TO_DEVICE_MSGID: message_content.get( + EventContentFields.TO_DEVICE_MSGID + ), + } + ) + # Ratelimit local cross-user key requests by the sending device. if ( message_type == ToDeviceEventTypes.RoomKeyRequest @@ -233,6 +243,7 @@ class DeviceMessageHandler: requester, (sender_user_id, requester.device_id) ) if not allowed: + log_kv({"message": f"dropping key requests to {user_id}"}) logger.info( "Dropping room_key_request from %s to %s due to rate limit", sender_user_id, @@ -247,18 +258,11 @@ class DeviceMessageHandler: "content": message_content, "type": message_type, "sender": sender_user_id, - "message_id": message_id, } for device_id, message_content in by_device.items() } if messages_by_device: local_messages[user_id] = messages_by_device - log_kv( - { - "user_id": user_id, - "device_id": list(messages_by_device), - } - ) else: destination = get_domain_from_id(user_id) remote_messages.setdefault(destination, {})[user_id] = by_device @@ -267,7 +271,11 @@ class DeviceMessageHandler: remote_edu_contents = {} for destination, messages in remote_messages.items(): - log_kv({"destination": destination}) + # The EDU contains a "message_id" property which is used for + # idempotence. Make up a random one. + message_id = random_string(16) + log_kv({"destination": destination, "message_id": message_id}) + remote_edu_contents[destination] = { "messages": messages, "sender": sender_user_id, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0b395a104d..dace9b606f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,14 +31,20 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context -from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span +from synapse.logging.opentracing import ( + SynapseTags, + log_kv, + set_tag, + start_active_span, + trace, +) from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary @@ -1586,6 +1592,7 @@ class SyncHandler: else: return DeviceListUpdates() + @trace async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" ) -> None: @@ -1605,11 +1612,16 @@ class SyncHandler: ) for message in messages: - # We pop here as we shouldn't be sending the message ID down - # `/sync` - message_id = message.pop("message_id", None) - if message_id: - set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + log_kv( + { + "event": "to_device_message", + "sender": message["sender"], + "type": message["type"], + EventContentFields.TO_DEVICE_MSGID: message["content"].get( + EventContentFields.TO_DEVICE_MSGID + ), + } + ) logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index b69060854f..a705af8356 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -292,8 +292,15 @@ logger = logging.getLogger(__name__) class SynapseTags: - # The message ID of any to_device message processed - TO_DEVICE_MESSAGE_ID = "to_device.message_id" + # The message ID of any to_device EDU processed + TO_DEVICE_EDU_ID = "to_device.edu_id" + + # Details about to-device messages + TO_DEVICE_TYPE = "to_device.type" + TO_DEVICE_SENDER = "to_device.sender" + TO_DEVICE_RECIPIENT = "to_device.recipient" + TO_DEVICE_RECIPIENT_DEVICE = "to_device.recipient_device" + TO_DEVICE_MSGID = "to_device.msgid" # client-generated ID # Whether the sync response has new data to be returned to the client. SYNC_RESULT = "sync.new_data" diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py index 46a8b03829..55d52f0b28 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py @@ -46,7 +46,6 @@ class SendToDeviceRestServlet(servlet.RestServlet): def on_PUT( self, request: SynapseRequest, message_type: str, txn_id: str ) -> Awaitable[Tuple[int, JsonDict]]: - set_tag("message_type", message_type) set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( request, self._put, request, message_type, txn_id diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 73c95ffb6f..48a54d9cb8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -26,8 +26,15 @@ from typing import ( cast, ) +from synapse.api.constants import EventContentFields from synapse.logging import issue9533_logger -from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.logging.opentracing import ( + SynapseTags, + log_kv, + set_tag, + start_active_span, + trace, +) from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -397,6 +404,17 @@ class DeviceInboxWorkerStore(SQLBaseStore): (recipient_user_id, recipient_device_id), [] ).append(message_dict) + # start a new span for each message, so that we can tag each separately + with start_active_span("get_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, recipient_user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, recipient_device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + if limit is not None and rowcount == limit: # We ended up bumping up against the message limit. There may be more messages # to retrieve. Return what we have, as well as the last stream position that @@ -678,12 +696,35 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - if remote_messages_by_destination: - issue9533_logger.debug( - "Queued outgoing to-device messages with stream_id %i for %s", - stream_id, - list(remote_messages_by_destination.keys()), - ) + for destination, edu in remote_messages_by_destination.items(): + if issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Queued outgoing to-device messages with " + "stream_id %i, EDU message_id %s, type %s for %s: %s", + stream_id, + edu["message_id"], + edu["type"], + destination, + [ + f"{user_id}/{device_id} (msgid " + f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})" + for (user_id, messages_by_device) in edu["messages"].items() + for (device_id, msg) in messages_by_device.items() + ], + ) + + for (user_id, messages_by_device) in edu["messages"].items(): + for (device_id, msg) in messages_by_device.items(): + with start_active_span("store_outgoing_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) + set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) + set_tag(SynapseTags.TO_DEVICE_TYPE, edu["type"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + msg.get(EventContentFields.TO_DEVICE_MSGID), + ) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self._clock.time_msec() @@ -801,7 +842,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): # Only insert into the local inbox if the device exists on # this server device_id = row["device_id"] - message_json = json_encoder.encode(messages_by_device[device_id]) + + with start_active_span("serialise_to_device_message"): + msg = messages_by_device[device_id] + set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, msg["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + msg["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + message_json = json_encoder.encode(msg) + messages_json_for_user[device_id] = message_json if messages_json_for_user: @@ -821,15 +874,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - issue9533_logger.debug( - "Stored to-device messages with stream_id %i for %s", - stream_id, - [ - (user_id, device_id) - for (user_id, messages_by_device) in local_by_user_then_device.items() - for device_id in messages_by_device.keys() - ], - ) + if issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Stored to-device messages with stream_id %i: %s", + stream_id, + [ + f"{user_id}/{device_id} (msgid " + f"{msg['content'].get(EventContentFields.TO_DEVICE_MSGID)})" + for ( + user_id, + messages_by_device, + ) in messages_by_user_then_device.items() + for (device_id, msg) in messages_by_device.items() + ], + ) class DeviceInboxBackgroundUpdateStore(SQLBaseStore): diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 9ed26d87a7..57bfbd7734 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -765,7 +765,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)] messages = { self.exclusive_as_user: { - device_id: to_device_message_content for device_id in fake_device_ids + device_id: { + "type": "test_to_device_message", + "sender": "@some:sender", + "content": to_device_message_content, + } + for device_id in fake_device_ids } } -- cgit 1.5.1 From 9b6224577e7a387bf94f2332301f21e9514286ff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 6 Dec 2022 07:23:03 -0500 Subject: Failover on proper error responses. (#14620) When querying a remote server handle a 404/405 with an errcode of M_UNRECOGNIZED as an unimplemented endpoint. --- changelog.d/14620.bugfix | 1 + synapse/federation/federation_client.py | 29 ++++++++++++++++++++--------- 2 files changed, 21 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14620.bugfix (limited to 'synapse') diff --git a/changelog.d/14620.bugfix b/changelog.d/14620.bugfix new file mode 100644 index 0000000000..cb95a87d92 --- /dev/null +++ b/changelog.d/14620.bugfix @@ -0,0 +1 @@ +Return spec-compliant JSON errors when unknown endpoints are requested. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 8bccc9c60d..137cfb3346 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -771,17 +771,28 @@ class FederationClient(FederationBase): """ if synapse_error is None: synapse_error = e.to_synapse_error() - # There is no good way to detect an "unknown" endpoint. + # MSC3743 specifies that servers should return a 404 or 405 with an errcode + # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or + # to an unknown method, respectively. # - # Dendrite returns a 404 (with a body of "404 page not found"); - # Conduit returns a 404 (with no body); and Synapse returns a 400 - # with M_UNRECOGNIZED. - # - # This needs to be rather specific as some endpoints truly do return 404 - # errors. + # Older versions of servers don't properly handle this. This needs to be + # rather specific as some endpoints truly do return 404 errors. return ( - e.code == 404 and (not e.response or e.response == b"404 page not found") - ) or (e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED) + # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method. + (e.code == 404 or e.code == 405) + and ( + # Older Dendrites returned a text or empty body. + # Older Conduit returned an empty body. + not e.response + or e.response == b"404 page not found" + # The proper response JSON with M_UNRECOGNIZED errcode. + or synapse_error.errcode == Codes.UNRECOGNIZED + ) + ) or ( + # Older Synapses returned a 400 error. + e.code == 400 + and synapse_error.errcode == Codes.UNRECOGNIZED + ) async def _try_destination_list( self, -- cgit 1.5.1 From 9e82caac45cd8eccd7b22c60c2cdbeec9aab3a2d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 6 Dec 2022 15:48:42 +0000 Subject: Faster remote room joins: unblock tasks waiting for full room state when the un-partial-stating of that room is received over the replication stream. [rei:frrj/streams/unpsr] (#14474) --- changelog.d/14474.misc | 1 + synapse/replication/tcp/client.py | 11 ++++ .../replication/tcp/streams/test_partial_state.py | 65 ++++++++++++++++++++++ 3 files changed, 77 insertions(+) create mode 100644 changelog.d/14474.misc create mode 100644 tests/replication/tcp/streams/test_partial_state.py (limited to 'synapse') diff --git a/changelog.d/14474.misc b/changelog.d/14474.misc new file mode 100644 index 0000000000..deccd4e91a --- /dev/null +++ b/changelog.d/14474.misc @@ -0,0 +1 @@ +Faster remote room joins: stream the un-partial-stating of rooms over replication. \ No newline at end of file diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 18252a2958..b4dad47b45 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -36,12 +36,14 @@ from synapse.replication.tcp.streams import ( TagAccountDataStream, ToDeviceStream, TypingStream, + UnPartialStatedRoomStream, ) from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, EventsStreamRow, ) +from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStreamRow from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID from synapse.util.async_helpers import Linearizer, timeout_deferred from synapse.util.metrics import Measure @@ -117,6 +119,7 @@ class ReplicationDataHandler: self._streams = hs.get_replication_streams() self._instance_name = hs.get_instance_name() self._typing_handler = hs.get_typing_handler() + self._state_storage_controller = hs.get_storage_controllers().state self._notify_pushers = hs.config.worker.start_pushers self._pusher_pool = hs.get_pusherpool() @@ -236,6 +239,14 @@ class ReplicationDataHandler: self.notifier.notify_user_joined_room( row.data.event_id, row.data.room_id ) + elif stream_name == UnPartialStatedRoomStream.NAME: + for row in rows: + assert isinstance(row, UnPartialStatedRoomStreamRow) + + # Wake up any tasks waiting for the room to be un-partial-stated. + self._state_storage_controller.notify_room_un_partial_stated( + row.room_id + ) await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py new file mode 100644 index 0000000000..2c10eab4db --- /dev/null +++ b/tests/replication/tcp/streams/test_partial_state.py @@ -0,0 +1,65 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet.defer import ensureDeferred + +from synapse.rest.client import room + +from tests.replication._base import BaseMultiWorkerStreamTestCase + + +class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase): + servlets = [room.register_servlets] + hijack_auth = True + user_id = "@bob:test" + + def setUp(self): + super().setUp() + self.store = self.hs.get_datastores().main + + def test_un_partial_stated_room_unblocks_over_replication(self) -> None: + """ + Tests that, when a room is un-partial-stated on another worker, + pending calls to `await_full_state` get unblocked. + """ + + # Make a room. + room_id = self.helper.create_room_as("@bob:test") + # Mark the room as partial-stated. + self.get_success( + self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1") + ) + + worker = self.make_worker_hs("synapse.app.generic_worker") + + # On the worker, attempt to get the current hosts in the room + d = ensureDeferred( + worker.get_storage_controllers().state.get_current_hosts_in_room(room_id) + ) + + self.reactor.advance(0.1) + + # This should block + self.assertFalse( + d.called, "get_current_hosts_in_room/await_full_state did not block" + ) + + # On the master, clear the partial state flag. + self.get_success(self.store.clear_partial_state_room(room_id)) + + self.reactor.advance(0.1) + + # The worker should have unblocked + self.assertTrue( + d.called, "get_current_hosts_in_room/await_full_state did not unblock" + ) -- cgit 1.5.1 From cf1059d045640485a5a0b1e3d945b796b0e6f228 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 7 Dec 2022 11:19:43 +0000 Subject: Fix a long-standing bug where the user directory would return 1 more row than requested. (#14631) --- changelog.d/14631.bugfix | 1 + synapse/rest/client/user_directory.py | 4 ++-- synapse/storage/databases/main/user_directory.py | 2 +- tests/storage/test_user_directory.py | 6 ++++++ 4 files changed, 10 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14631.bugfix (limited to 'synapse') diff --git a/changelog.d/14631.bugfix b/changelog.d/14631.bugfix new file mode 100644 index 0000000000..c5376bab9f --- /dev/null +++ b/changelog.d/14631.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory would return 1 more row than requested. \ No newline at end of file diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index 116c982ce6..4670fad608 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -63,8 +63,8 @@ class UserDirectorySearchRestServlet(RestServlet): body = parse_json_object_from_request(request) - limit = body.get("limit", 10) - limit = min(limit, 50) + limit = int(body.get("limit", 10)) + limit = max(min(limit, 50), 0) try: search_term = body["search_term"] diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 044435deab..af9952f513 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -886,7 +886,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): limited = len(results) > limit - return {"limited": limited, "results": results} + return {"limited": limited, "results": results[0:limit]} def _parse_query_sqlite(search_term: str) -> str: diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 5b60cf5285..88c7d5fec0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -448,6 +448,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None}, ) + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_limit_correct(self) -> None: + r = self.get_success(self.store.search_user_dir(ALICE, "bob", 1)) + self.assertTrue(r["limited"]) + self.assertEqual(1, len(r["results"])) + @override_config({"user_directory": {"search_all_users": True}}) def test_search_user_dir_stop_words(self) -> None: """Tests that a user can look up another user by searching for the start if its -- cgit 1.5.1 From 96251af50d621ef1250dc22e447669c69f89b3bb Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 7 Dec 2022 13:39:27 +0000 Subject: Fix a bug introduced in v1.67.0 where not specifying a config file or a server URL would lead to the `register_new_matrix_user` script failing. (#14637) --- changelog.d/14637.bugfix | 1 + synapse/_scripts/register_new_matrix_user.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14637.bugfix (limited to 'synapse') diff --git a/changelog.d/14637.bugfix b/changelog.d/14637.bugfix new file mode 100644 index 0000000000..ab6db383c6 --- /dev/null +++ b/changelog.d/14637.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.67.0 where not specifying a config file or a server URL would lead to the `register_new_matrix_user` script failing. \ No newline at end of file diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 0c4504d5d8..2b74a40166 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -222,6 +222,7 @@ def main() -> None: args = parser.parse_args() + config: Optional[Dict[str, Any]] = None if "config" in args and args.config: config = yaml.safe_load(args.config) @@ -229,7 +230,7 @@ def main() -> None: secret = args.shared_secret else: # argparse should check that we have either config or shared secret - assert config + assert config is not None secret = config.get("registration_shared_secret") secret_file = config.get("registration_shared_secret_path") @@ -244,7 +245,7 @@ def main() -> None: if args.server_url: server_url = args.server_url - elif config: + elif config is not None: server_url = _find_client_listener(config) if not server_url: server_url = _DEFAULT_SERVER_URL -- cgit 1.5.1 From 60c3fea3271468dd1f9e9c5fae2d22dd9778293b Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 7 Dec 2022 17:35:41 +0000 Subject: Reject receipt requests with invalid room or event IDs. (#14632) If the room or event IDs are empty or of an invalid form they should be rejected. --- changelog.d/14632.bugfix | 1 + synapse/rest/client/receipts.py | 5 ++- tests/rest/client/test_receipts.py | 76 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14632.bugfix create mode 100644 tests/rest/client/test_receipts.py (limited to 'synapse') diff --git a/changelog.d/14632.bugfix b/changelog.d/14632.bugfix new file mode 100644 index 0000000000..323d10f1b0 --- /dev/null +++ b/changelog.d/14632.bugfix @@ -0,0 +1 @@ +Reject invalid read receipt requests with empty room or event IDs. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 18a282b22c..28b7d30ea8 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -20,7 +20,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict +from synapse.types import EventID, JsonDict, RoomID from ._base import client_patterns @@ -56,6 +56,9 @@ class ReceiptRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) + if not RoomID.is_valid(room_id) or not event_id.startswith(EventID.SIGIL): + raise SynapseError(400, "A valid room ID and event ID must be specified") + if receipt_type not in self._known_receipt_types: raise SynapseError( 400, diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py new file mode 100644 index 0000000000..2a7fcea386 --- /dev/null +++ b/tests/rest/client/test_receipts.py @@ -0,0 +1,76 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, receipts, register +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest + + +class ReceiptsTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + register.register_servlets, + receipts.register_servlets, + synapse.rest.admin.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + + def test_send_receipt(self) -> None: + channel = self.make_request( + "POST", + "/rooms/!abc:beep/receipt/m.read/$def", + content={}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + def test_send_receipt_invalid_room_id(self) -> None: + channel = self.make_request( + "POST", + "/rooms/not-a-room-id/receipt/m.read/$def", + content={}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["error"], "A valid room ID and event ID must be specified" + ) + + def test_send_receipt_invalid_event_id(self) -> None: + channel = self.make_request( + "POST", + "/rooms/!abc:beep/receipt/m.read/not-an-event-id", + content={}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["error"], "A valid room ID and event ID must be specified" + ) + + def test_send_receipt_invalid_receipt_type(self) -> None: + channel = self.make_request( + "POST", + "/rooms/!abc:beep/receipt/invalid-receipt-type/$def", + content={}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) -- cgit 1.5.1 From da777207528513c858395758bf4c023da2c2c1a3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 8 Dec 2022 11:35:49 -0500 Subject: Check the stream position before checking if the cache is empty. (#14639) An empty cache does not mean the entity has no changed, if it is earlier than the earliest known stream position return that the entity *has* changed since the cache cannot accurately answer that query. --- changelog.d/14639.bugfix | 1 + synapse/util/caches/stream_change_cache.py | 9 +++++---- tests/util/test_stream_change_cache.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14639.bugfix (limited to 'synapse') diff --git a/changelog.d/14639.bugfix b/changelog.d/14639.bugfix new file mode 100644 index 0000000000..8730b10afe --- /dev/null +++ b/changelog.d/14639.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory and room/user stats might be out of sync. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index c8b17acb59..1657459549 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -213,16 +213,17 @@ class StreamChangeCache: """ assert isinstance(stream_pos, int) - if not self._cache: - # If the cache is empty, nothing can have changed. - return False - # _cache is not valid at or before the earliest known stream position, so # return that an entity has changed. if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + # If the cache is empty, nothing can have changed. + if not self._cache: + self.metrics.inc_misses() + return False + self.metrics.inc_hits() return stream_pos < self._cache.peekitem()[0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 0305741c99..3df053493b 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -144,9 +144,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): """ cache = StreamChangeCache("#test", 1) - # With no entities, it returns False for the past, present, and future. - self.assertFalse(cache.has_any_entity_changed(0)) - self.assertFalse(cache.has_any_entity_changed(1)) + # With no entities, it returns True for the past, present, and False for + # the future. + self.assertTrue(cache.has_any_entity_changed(0)) + self.assertTrue(cache.has_any_entity_changed(1)) self.assertFalse(cache.has_any_entity_changed(2)) # We add an entity -- cgit 1.5.1 From 9d8a3234ba1d3ff831a7647f45c67946773d88a7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 8 Dec 2022 11:37:05 -0500 Subject: Respond with proper error responses on unknown paths. (#14621) Returns a proper 404 with an errcode of M_RECOGNIZED for unknown endpoints per MSC3743. --- changelog.d/14621.bugfix | 1 + synapse/api/errors.py | 6 ++---- synapse/http/server.py | 19 ++++++++++++++++++- synapse/rest/media/v1/media_repository.py | 4 ++-- synapse/util/httpresourcetree.py | 6 ++++-- tests/rest/admin/test_user.py | 2 +- tests/rest/client/test_login_token_request.py | 4 ++-- tests/rest/client/test_rendezvous.py | 2 +- tests/test_server.py | 2 +- 9 files changed, 32 insertions(+), 14 deletions(-) create mode 100644 changelog.d/14621.bugfix (limited to 'synapse') diff --git a/changelog.d/14621.bugfix b/changelog.d/14621.bugfix new file mode 100644 index 0000000000..cb95a87d92 --- /dev/null +++ b/changelog.d/14621.bugfix @@ -0,0 +1 @@ +Return spec-compliant JSON errors when unknown endpoints are requested. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e2cfcea0f2..76ef12ed3a 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -300,10 +300,8 @@ class InteractiveAuthIncompleteError(Exception): class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" - def __init__( - self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED - ): - super().__init__(400, msg, errcode) + def __init__(self, msg: str = "Unrecognized request", code: int = 400): + super().__init__(code, msg, Codes.UNRECOGNIZED) class NotFoundError(SynapseError): diff --git a/synapse/http/server.py b/synapse/http/server.py index 051a1899a0..2563858f3c 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -577,7 +577,24 @@ def _unrecognised_request_handler(request: Request) -> NoReturn: Args: request: Unused, but passed in to match the signature of ServletCallback. """ - raise UnrecognizedRequestError() + raise UnrecognizedRequestError(code=404) + + +class UnrecognizedRequestResource(resource.Resource): + """ + Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an + errcode of M_UNRECOGNIZED. + """ + + def render(self, request: SynapseRequest) -> int: + f = failure.Failure(UnrecognizedRequestError(code=404)) + return_json_error(f, request, None) + # A response has already been sent but Twisted requires either NOT_DONE_YET + # or the response bytes as a return value. + return NOT_DONE_YET + + def getChild(self, name: str, request: Request) -> resource.Resource: + return self class RootRedirect(resource.Resource): diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 40b0d39eb2..c70e1837af 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -24,7 +24,6 @@ from matrix_common.types.mxc_uri import MXCUri import twisted.internet.error import twisted.web.http from twisted.internet.defer import Deferred -from twisted.web.resource import Resource from synapse.api.errors import ( FederationDeniedError, @@ -35,6 +34,7 @@ from synapse.api.errors import ( ) from synapse.config._base import ConfigError from synapse.config.repository import ThumbnailRequirement +from synapse.http.server import UnrecognizedRequestResource from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process @@ -1046,7 +1046,7 @@ class MediaRepository: return removed_media, len(removed_media) -class MediaRepositoryResource(Resource): +class MediaRepositoryResource(UnrecognizedRequestResource): """File uploading and downloading. Uploads are POSTed to a resource which returns a token which is used to GET diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index a0606851f7..39fab4fe06 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -15,7 +15,9 @@ import logging from typing import Dict -from twisted.web.resource import NoResource, Resource +from twisted.web.resource import Resource + +from synapse.http.server import UnrecognizedRequestResource logger = logging.getLogger(__name__) @@ -49,7 +51,7 @@ def create_resource_tree( for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource: Resource = NoResource() + child_resource: Resource = UnrecognizedRequestResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e8c9457794..5c1ced355f 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -3994,7 +3994,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): """ Tests that shadow-banning for a user that is not a local returns a 400 """ - url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/shadow_ban" channel = self.make_request(method, url, access_token=self.admin_user_tok) self.assertEqual(400, channel.code, msg=channel.json_body) diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index c2e1e08811..6aedc1a11c 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -48,13 +48,13 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): def test_disabled(self) -> None: channel = self.make_request("POST", endpoint, {}, access_token=None) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) self.register_user(self.user, self.password) token = self.login(self.user, self.password) channel = self.make_request("POST", endpoint, {}, access_token=token) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) @override_config({"experimental_features": {"msc3882_enabled": True}}) def test_require_auth(self) -> None: diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index ad00a476e1..c0eb5d01a6 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -36,7 +36,7 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase): def test_disabled(self) -> None: channel = self.make_request("POST", endpoint, {}, access_token=None) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) def test_redirect(self) -> None: diff --git a/tests/test_server.py b/tests/test_server.py index 2d9a0257d4..d67d7722a4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -174,7 +174,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar" ) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") -- cgit 1.5.1 From c369e956918333c19cfb21def41c8a54f9d09c90 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 8 Dec 2022 11:40:20 -0500 Subject: Rebuild the user directory and stats tables. (#14643) Due to the various fixes to the StreamChangeCache it is not safe to trust the information in the user directory or room/user stats tables. Rebuild them as background jobs. In particular see da777207528513c858395758bf4c023da2c2c1a3 (#14639), and 6a8310f3dfe77acf59df2fe3e88a71b85b9b3ecc (#14435). Maybe also be related to fac8a38525387e344e3595a092578e0ffedd49ae (#14592). --- changelog.d/14643.bugfix | 1 + .../main/delta/73/22_rebuild_user_dir_stats.sql | 29 ++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 changelog.d/14643.bugfix create mode 100644 synapse/storage/schema/main/delta/73/22_rebuild_user_dir_stats.sql (limited to 'synapse') diff --git a/changelog.d/14643.bugfix b/changelog.d/14643.bugfix new file mode 100644 index 0000000000..8730b10afe --- /dev/null +++ b/changelog.d/14643.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory and room/user stats might be out of sync. diff --git a/synapse/storage/schema/main/delta/73/22_rebuild_user_dir_stats.sql b/synapse/storage/schema/main/delta/73/22_rebuild_user_dir_stats.sql new file mode 100644 index 0000000000..afab1e4bb7 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/22_rebuild_user_dir_stats.sql @@ -0,0 +1,29 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + -- Set up user directory staging tables. + (7322, 'populate_user_directory_createtables', '{}', NULL), + -- Run through each room and update the user directory according to who is in it. + (7322, 'populate_user_directory_process_rooms', '{}', 'populate_user_directory_createtables'), + -- Insert all users into the user directory, if search_all_users is on. + (7322, 'populate_user_directory_process_users', '{}', 'populate_user_directory_process_rooms'), + -- Clean up user directory staging tables. + (7322, 'populate_user_directory_cleanup', '{}', 'populate_user_directory_process_users'), + -- Rebuild the room_stats_current and room_stats_state tables. + (7322, 'populate_stats_process_rooms', '{}', NULL), + -- Update the user_stats_current table. + (7322, 'populate_stats_process_users', '{}', NULL) +ON CONFLICT (update_name) DO NOTHING; -- cgit 1.5.1 From a58b550eac9606bf6bba030abe9d1020c893ca02 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Thu, 8 Dec 2022 21:28:02 +0400 Subject: Fix html templates to load images only on HTTPS (#14625) This PR changes http-based image URLs to be https in html templates. This impacts the Synapse SSO error page, where browsers report mixed media content warnings. Also, https://matrix.org/img/vector-logo-email.png is currently broken but the URL has been updated to be https anyway. Signed-off-by: Ashish Kumar --- changelog.d/14625.bugfix | 1 + synapse/res/templates/_base.html | 6 +++--- synapse/res/templates/notice_expiry.html | 6 +++--- synapse/res/templates/notif_mail.html | 6 +++--- 4 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14625.bugfix (limited to 'synapse') diff --git a/changelog.d/14625.bugfix b/changelog.d/14625.bugfix new file mode 100644 index 0000000000..a4d1216690 --- /dev/null +++ b/changelog.d/14625.bugfix @@ -0,0 +1 @@ +Fix html templates to load images only on HTTPS. Contributed by @ashfame. diff --git a/synapse/res/templates/_base.html b/synapse/res/templates/_base.html index 46439fce6a..4b5cc7bcb6 100644 --- a/synapse/res/templates/_base.html +++ b/synapse/res/templates/_base.html @@ -13,13 +13,13 @@
{% if app_name == "Riot" %} - [Riot] + [Riot] {% elif app_name == "Vector" %} - [Vector] + [Vector] {% elif app_name == "Element" %} [Element] {% else %} - [matrix] + [matrix] {% endif %}
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html index 406397aaca..f62038e111 100644 --- a/synapse/res/templates/notice_expiry.html +++ b/synapse/res/templates/notice_expiry.html @@ -21,13 +21,13 @@ {% if app_name == "Riot" %} - [Riot] + [Riot] {% elif app_name == "Vector" %} - [Vector] + [Vector] {% elif app_name == "Element" %} [Element] {% else %} - [matrix] + [matrix] {% endif %} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html index 2add9dd859..7da0fff5e9 100644 --- a/synapse/res/templates/notif_mail.html +++ b/synapse/res/templates/notif_mail.html @@ -22,13 +22,13 @@ {%- if app_name == "Riot" %} - [Riot] + [Riot] {%- elif app_name == "Vector" %} - [Vector] + [Vector] {%- elif app_name == "Element" %} [Element] {%- else %} - [matrix] + [matrix] {%- endif %} -- cgit 1.5.1 From c2de2ca63060324cf2f80ddf3289b0fd7a4d861b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Dec 2022 09:37:07 +0000 Subject: Delete stale non-e2e devices for users, take 2 (#14595) This should help reduce the number of devices e.g. simple bots the repeatedly login rack up. We only delete non-e2e devices as they should be safe to delete, whereas if we delete e2e devices for a user we may accidentally break their ability to receive e2e keys for a message. --- changelog.d/14595.misc | 1 + synapse/handlers/device.py | 31 +++++++++++- synapse/storage/databases/main/devices.py | 79 ++++++++++++++++++++++++++++++- tests/handlers/test_device.py | 2 +- tests/storage/test_client_ips.py | 4 +- 5 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14595.misc (limited to 'synapse') diff --git a/changelog.d/14595.misc b/changelog.d/14595.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/14595.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d4750a32e6..7674c187ef 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -52,6 +52,7 @@ from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable +from synapse.util.iterutils import batch_iter from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -421,6 +422,9 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) + # Prune the user's device list if they already have a lot of devices. + await self._prune_too_many_devices(user_id) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -452,6 +456,31 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") + async def _prune_too_many_devices(self, user_id: str) -> None: + """Delete any excess old devices this user may have.""" + device_ids = await self.store.check_too_many_devices_for_user(user_id) + if not device_ids: + return + + # We don't want to block and try and delete tonnes of devices at once, + # so we cap the number of devices we delete synchronously. + first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] + await self.delete_devices(user_id, first_batch) + + if not remaining_device_ids: + return + + # Now spawn a background loop that deletes the rest. + async def _prune_too_many_devices_loop() -> None: + for batch in batch_iter(remaining_device_ids, 10): + await self.delete_devices(user_id, batch) + + await self.clock.sleep(1) + + run_as_background_process( + "_prune_too_many_devices_loop", _prune_too_many_devices_loop + ) + async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. @@ -481,7 +510,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a5bb4d404e..08ccd46a2b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,6 +1569,72 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def check_too_many_devices_for_user(self, user_id: str) -> List[str]: + """Check if the user has a lot of devices, and if so return the set of + devices we can prune. + + This does *not* return hidden devices or devices with E2E keys. + """ + + num_devices = await self.db_pool.simple_select_one_onecol( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcol="COALESCE(COUNT(*), 0)", + desc="count_devices", + ) + + # We let users have up to ten devices without pruning. + if num_devices <= 10: + return [] + + # We prune everything older than N days. + max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 + + if num_devices > 50: + # If the user has more than 50 devices, then we chose a last seen + # that ensures we keep at most 50 devices. + sql = """ + SELECT last_seen FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen IS NOT NULL + AND key_json IS NULL + ORDER BY last_seen DESC + LIMIT 1 + OFFSET 50 + """ + + rows = await self.db_pool.execute( + "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) + ) + if rows: + max_last_seen = max(rows[0][0], max_last_seen) + + # Now fetch the devices to delete. + sql = """ + SELECT DISTINCT device_id FROM devices + LEFT JOIN e2e_device_keys_json USING (user_id, device_id) + WHERE + user_id = ? + AND NOT hidden + AND last_seen < ? + AND key_json IS NULL + ORDER BY last_seen + """ + + def check_too_many_devices_for_user_txn( + txn: LoggingTransaction, + ) -> List[str]: + txn.execute(sql, (user_id, max_last_seen)) + return [device_id for device_id, in txn] + + return await self.db_pool.runInteraction( + "check_too_many_devices_for_user", + check_too_many_devices_for_user_txn, + ) + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1627,6 +1693,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, + "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1672,7 +1739,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + @cached(max_entries=0) + async def delete_device(self, user_id: str, device_id: str) -> None: + raise NotImplementedError() + + # Note: sometimes deleting rows out of `device_inbox` can take a long time, + # so we use a cache so that we deduplicate in flight requests to delete + # devices. + @cachedList(cached_method_name="delete_device", list_name="device_ids") + async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> dict: """Deletes several devices. Args: @@ -1709,6 +1784,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) + return {} + async def update_device( self, user_id: str, device_id: str, new_display_name: Optional[str] = None ) -> None: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ce7525e29c..a456bffd63 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -115,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": None, + "last_seen_ts": 1000000, }, device_map["xyz"], ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 49ad3c1324..a9af1babed 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -169,6 +169,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) + last_seen = self.clock.time_msec() + if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -189,7 +191,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": None, + "last_seen": last_seen, }, ], ) -- cgit 1.5.1 From 94bc21e69f89ad873ad7a0deb6d9c4ff3cb480ef Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Dec 2022 13:31:32 +0000 Subject: Limit the number of devices we delete at once (#14649) --- changelog.d/14649.misc | 1 + synapse/handlers/device.py | 4 +++- synapse/storage/databases/main/devices.py | 11 ++++++++--- tests/handlers/test_device.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14649.misc (limited to 'synapse') diff --git a/changelog.d/14649.misc b/changelog.d/14649.misc new file mode 100644 index 0000000000..f9bfc581ad --- /dev/null +++ b/changelog.d/14649.misc @@ -0,0 +1 @@ +Prune user's old devices on login if they have too many. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7674c187ef..c935c7be90 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -458,10 +458,12 @@ class DeviceHandler(DeviceWorkerHandler): async def _prune_too_many_devices(self, user_id: str) -> None: """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id) + device_ids = await self.store.check_too_many_devices_for_user(user_id, 100) if not device_ids: return + logger.info("Pruning %d old devices for user %s", len(device_ids), user_id) + # We don't want to block and try and delete tonnes of devices at once, # so we cap the number of devices we delete synchronously. first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 08ccd46a2b..95d4c0622d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,11 +1569,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user(self, user_id: str) -> List[str]: + async def check_too_many_devices_for_user( + self, user_id: str, limit: int + ) -> List[str]: """Check if the user has a lot of devices, and if so return the set of devices we can prune. This does *not* return hidden devices or devices with E2E keys. + + Returns at most `limit` number of devices, ordered by last seen. """ num_devices = await self.db_pool.simple_select_one_onecol( @@ -1614,7 +1618,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): # Now fetch the devices to delete. sql = """ - SELECT DISTINCT device_id FROM devices + SELECT device_id FROM devices LEFT JOIN e2e_device_keys_json USING (user_id, device_id) WHERE user_id = ? @@ -1622,12 +1626,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): AND last_seen < ? AND key_json IS NULL ORDER BY last_seen + LIMIT ? """ def check_too_many_devices_for_user_txn( txn: LoggingTransaction, ) -> List[str]: - txn.execute(sql, (user_id, max_last_seen)) + txn.execute(sql, (user_id, max_last_seen, limit)) return [device_id for device_id, in txn] return await self.db_pool.runInteraction( diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a456bffd63..e51cac9b33 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -20,6 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler +from synapse.rest import admin +from synapse.rest.client import account, login from synapse.server import HomeServer from synapse.util import Clock @@ -30,6 +32,12 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + account.register_servlets, + ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) handler = hs.get_device_handler() @@ -229,6 +237,29 @@ class DeviceTestCase(unittest.HomeserverTestCase): NotFoundError, ) + def test_login_delete_old_devices(self) -> None: + """Delete old devices if the user already has too many.""" + + user_id = self.register_user("user", "pass") + + # Create a bunch of devices + for _ in range(50): + self.login("user", "pass") + self.reactor.advance(1) + + # Advance the clock for ages (as we only delete old devices) + self.reactor.advance(60 * 60 * 24 * 300) + + # Log in again to start the pruning + self.login("user", "pass") + + # Give the background job time to do its thing + self.reactor.pump([1.0] * 100) + + # We should now only have the most recent device. + devices = self.get_success(self.handler.get_devices_by_user(user_id)) + self.assertEqual(len(devices), 1) + def _record_users(self) -> None: # check this works for both devices which have a recorded client_ip, # and those which don't. -- cgit 1.5.1 From 3ac412b4e2f8c5ba11dc962b8a9d871c1efdce9b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 9 Dec 2022 12:36:32 -0500 Subject: Require types in tests.storage. (#14646) Adds missing type hints to `tests.storage` package and does not allow untyped definitions. --- changelog.d/14646.misc | 1 + mypy.ini | 14 +-- synapse/storage/databases/main/end_to_end_keys.py | 2 +- tests/storage/databases/main/test_deviceinbox.py | 10 +- tests/storage/databases/main/test_events_worker.py | 27 ++--- tests/storage/databases/main/test_lock.py | 18 +-- tests/storage/databases/main/test_receipts.py | 8 +- tests/storage/databases/main/test_room.py | 10 +- tests/storage/test__base.py | 2 +- tests/storage/test_account_data.py | 12 +- tests/storage/test_appservice.py | 22 ++-- tests/storage/test_base.py | 30 ++--- tests/storage/test_cleanup_extrems.py | 37 +++--- tests/storage/test_client_ips.py | 58 +++++----- tests/storage/test_database.py | 2 +- tests/storage/test_devices.py | 35 ++++-- tests/storage/test_directory.py | 12 +- tests/storage/test_e2e_room_keys.py | 8 +- tests/storage/test_end_to_end_keys.py | 15 ++- tests/storage/test_event_chain.py | 29 +++-- tests/storage/test_event_federation.py | 71 ++++++------ tests/storage/test_event_metrics.py | 2 +- tests/storage/test_events.py | 39 ++++--- tests/storage/test_keys.py | 9 +- tests/storage/test_monthly_active_users.py | 30 ++--- tests/storage/test_purge.py | 15 ++- tests/storage/test_receipts.py | 12 +- tests/storage/test_redaction.py | 125 ++++++++++++--------- tests/storage/test_rollback_worker.py | 15 ++- tests/storage/test_room.py | 24 ++-- tests/storage/test_room_search.py | 10 +- tests/storage/test_state.py | 46 +++++--- tests/storage/test_stream.py | 18 ++- tests/storage/test_transactions.py | 18 ++- tests/storage/test_txn_limit.py | 14 ++- .../util/test_partial_state_events_tracker.py | 30 ++--- 36 files changed, 489 insertions(+), 341 deletions(-) create mode 100644 changelog.d/14646.misc (limited to 'synapse') diff --git a/changelog.d/14646.misc b/changelog.d/14646.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14646.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index c3fbd1a955..a4a1e4511a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -88,6 +88,9 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False +[mypy-tests.handlers.test_sso] +disallow_untyped_defs = True + [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True @@ -103,16 +106,7 @@ disallow_untyped_defs = True [mypy-tests.state.test_profile] disallow_untyped_defs = True -[mypy-tests.storage.test_id_generators] -disallow_untyped_defs = True - -[mypy-tests.storage.test_profile] -disallow_untyped_defs = True - -[mypy-tests.handlers.test_sso] -disallow_untyped_defs = True - -[mypy-tests.storage.test_user_directory] +[mypy-tests.storage.*] disallow_untyped_defs = True [mypy-tests.rest.*] diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 643c47d608..4c691642e2 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -140,7 +140,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cancellable async def get_e2e_device_keys_for_cs_api( self, - query_list: List[Tuple[str, Optional[str]]], + query_list: Collection[Tuple[str, Optional[str]]], include_displaynames: bool = True, ) -> Dict[str, Dict[str, JsonDict]]: """Fetch a list of device keys, formatted suitably for the C/S API. diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 50c20c5b92..373707b275 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest import admin from synapse.rest.client import devices +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): devices.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") - def test_background_remove_deleted_devices_from_device_inbox(self): + def test_background_remove_deleted_devices_from_device_inbox(self) -> None: """Test that the background task to delete old device_inboxes works properly.""" # create a valid device @@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): self.assertEqual(1, len(res)) self.assertEqual(res[0], "cur_device") - def test_background_remove_hidden_devices_from_device_inbox(self): + def test_background_remove_hidden_devices_from_device_inbox(self) -> None: """Test that the background task to delete hidden devices from device_inboxes works properly.""" diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5773172ab8..9f33afcca0 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs self.store: EventsWorkerStore = hs.get_datastores().main @@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.event_ids.append(event.event_id) - def test_simple(self): + def test_simple(self) -> None: with LoggingContext(name="test") as ctx: res = self.get_success( self.store.have_seen_events( @@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) - def test_persisting_event_invalidates_cache(self): + def test_persisting_event_invalidates_cache(self) -> None: """ Test to make sure that the `have_seen_event` cache is invalidated after we persist an event and returns @@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # That should result in a single db query to lookup self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) - def test_invalidate_cache_by_room_id(self): + def test_invalidate_cache_by_room_id(self) -> None: """ Test to make sure that all events associated with the given `(room_id,)` are invalidated in the `have_seen_event` cache. @@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") @@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # Reset the event cache so the tests start with it empty self.get_success(self.store._get_event_cache.clear()) - def test_simple(self): + def test_simple(self) -> None: """Test that we cache events that we pull from the DB.""" with LoggingContext("test") as ctx: @@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # We should have fetched the event from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) - def test_event_ref(self): + def test_event_ref(self) -> None: """Test that we reuse events that are still in memory but have fallen out of the cache, rather than requesting them from the DB. """ @@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0) - def test_dedupe(self): + def test_dedupe(self) -> None: """Test that if we request the same event multiple times we only pull it out once. """ @@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): class DatabaseOutageTestCase(unittest.HomeserverTestCase): """Test event fetching during a database outage.""" - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" @@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") @@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): unblock: "Deferred[None]" = Deferred() original_runWithConnection = self.store.db_pool.runWithConnection - async def runWithConnection(*args, **kwargs): + # Don't bother with the types here, we just pass into the original function. + async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def] await unblock return await original_runWithConnection(*args, **kwargs) @@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1) self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0) - def test_first_get_event_cancelled(self): + def test_first_get_event_cancelled(self) -> None: """Test cancellation of the first `get_event` call sharing a database fetch. The first `get_event` call is the one which initiates the fetch. We expect the @@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): # The second `get_event` call should complete successfully. self.get_success(get_event2) - def test_second_get_event_cancelled(self): + def test_second_get_event_cancelled(self) -> None: """Test cancellation of the second `get_event` call sharing a database fetch.""" with self.blocking_get_event_calls() as (unblock, get_event1, get_event2): # Cancel the second `get_event` call. diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 3cc2a58d8d..56cb49d9b5 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -15,18 +15,20 @@ from twisted.internet import defer, reactor from twisted.internet.base import ReactorBase from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS +from synapse.util import Clock from tests import unittest class LockTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def test_acquire_contention(self): + def test_acquire_contention(self) -> None: # Track the number of tasks holding the lock. # Should be at most 1. in_lock = 0 @@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase): release_lock: "Deferred[None]" = Deferred() - async def task(): + async def task() -> None: nonlocal in_lock nonlocal max_in_lock @@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase): # At most one task should have held the lock at a time. self.assertEqual(max_in_lock, 1) - def test_simple_lock(self): + def test_simple_lock(self) -> None: """Test that we can take out a lock and that while we hold it nobody else can take it out. """ @@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase): self.get_success(lock3.__aenter__()) self.get_success(lock3.__aexit__(None, None, None)) - def test_maintain_lock(self): + def test_maintain_lock(self) -> None: """Test that we don't time out locks while they're still active""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) @@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase): self.get_success(lock.__aexit__(None, None, None)) - def test_timeout_lock(self): + def test_timeout_lock(self) -> None: """Test that we time out locks if they're not updated for ages""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) @@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase): self.assertFalse(self.get_success(lock.is_still_valid())) - def test_drop(self): + def test_drop(self) -> None: """Test that dropping the context manager means we stop renewing the lock""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) @@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase): lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) self.assertIsNotNone(lock2) - def test_shutdown(self): + def test_shutdown(self) -> None: """Test that shutting down Synapse releases the locks""" # Acquire two locks lock = self.get_success(self.store.try_acquire_lock("name", "key1")) diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index c4f12d81d7..68026e2830 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") @@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): table: str, receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]], expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]], - ): + ) -> None: """Test that the background update to uniqueify non-thread receipts in the given receipts table works properly. @@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): f"Background update did not remove all duplicate receipts from {table}", ) - def test_background_receipts_linearized_unique_index(self): + def test_background_receipts_linearized_unique_index(self) -> None: """Test that the background update to uniqueify non-thread receipts in `receipts_linearized` works properly. """ @@ -177,7 +177,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): }, ) - def test_background_receipts_graph_unique_index(self): + def test_background_receipts_graph_unique_index(self) -> None: """Test that the background update to uniqueify non-thread receipts in `receipts_graph` works properly. """ diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 1edb619630..7d961fac64 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -14,10 +14,14 @@ import json +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import RoomTypes from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.storage.databases.main.room import _BackgroundUpdates +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -30,7 +34,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") @@ -40,7 +44,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): return room_id - def test_background_populate_rooms_creator_column(self): + def test_background_populate_rooms_creator_column(self) -> None: """Test that the background update to populate the rooms creator column works properly. """ @@ -95,7 +99,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) self.assertEqual(room_creator_after, self.user_id) - def test_background_add_room_type_column(self): + def test_background_add_room_type_column(self) -> None: """Test that the background update to populate the `room_type` column in `room_stats_state` works properly. """ diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 09cb06d614..8bbf936ae9 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase): {(1, "user1", "hello"), (2, "user2", "bleb")}, ) - def test_simple_update_many(self): + def test_simple_update_many(self) -> None: """ simple_update_many performs many updates at once. """ diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 72bf5b3d31..1bfd11ceae 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -14,13 +14,17 @@ from typing import Iterable, Optional, Set +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import AccountDataTypes +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest class IgnoredUsersTestCase(unittest.HomeserverTestCase): - def prepare(self, hs, reactor, clock): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main self.user = "@user:test" @@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): expected_ignored_user_ids, ) - def test_ignoring_users(self): + def test_ignoring_users(self) -> None: """Basic adding/removing of users from the ignore list.""" self._update_ignore_list("@other:test", "@another:remote") self.assert_ignored(self.user, {"@other:test", "@another:remote"}) @@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): # Check the removed user. self.assert_ignorers("@another:remote", {self.user}) - def test_caching(self): + def test_caching(self) -> None: """Ensure that caching works properly between different users.""" # The first user ignores a user. self._update_ignore_list("@other:test") @@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", {"@second:test"}) - def test_invalid_data(self): + def test_invalid_data(self) -> None: """Invalid data ends up clearing out the ignored users list.""" # Add some data and ensure it is there. self._update_ignore_list("@other:test") diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1047ed09c8..5e1324a169 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError from synapse.events import EventBase from synapse.server import HomeServer -from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, @@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): - def setUp(self): + def setUp(self) -> None: super(ApplicationServiceStoreTestCase, self).setUp() self.as_yaml_files: List[str] = [] @@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): super(ApplicationServiceStoreTestCase, self).tearDown() - def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: + def _add_appservice( + self, as_token: str, id: str, url: str, hs_token: str, sender: str + ) -> None: as_yaml = { "url": url, "as_token": as_token, @@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): database, make_conn(db_config, self.engine, "test"), self.hs ) - def _add_service(self, url, as_token, id) -> None: + def _add_service(self, url: str, as_token: str, id: str) -> None: as_yaml = { "url": url, "as_token": as_token, @@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def _set_state(self, id: str, state: ApplicationServiceState): + def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred: return self.db_pool.runOperation( self.engine.convert_param_style( "INSERT INTO application_services_state(as_id, state) VALUES(?,?)" @@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (id, state.value), ) - def _insert_txn(self, as_id, txn_id, events): + def _insert_txn( + self, as_id: str, txn_id: int, events: List[Mock] + ) -> "defer.Deferred[None]": return self.db_pool.runOperation( self.engine.convert_param_style( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " @@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: DatabasePool, db_conn, hs) -> None: + def __init__( + self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer + ) -> None: super().__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): - def _write_config(self, suffix, **kwargs) -> str: + def _write_config(self, suffix: str, **kwargs: str) -> str: vals = { "id": "id" + suffix, "url": "url" + suffix, diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 40e58f8199..256d28e4c9 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - from collections import OrderedDict +from typing import Generator from unittest.mock import Mock from twisted.internet import defer @@ -30,7 +30,7 @@ from tests.utils import default_config class SQLBaseStoreTestCase(unittest.TestCase): """Test the "simple" SQL generating methods in SQLBaseStore.""" - def setUp(self): + def setUp(self) -> None: self.db_pool = Mock(spec=["runInteraction"]) self.mock_txn = Mock() self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"]) @@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_conn.rollback.return_value = None # Our fake runInteraction just runs synchronously inline - def runInteraction(func, *args, **kwargs): + def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def] return defer.succeed(func(self.mock_txn, *args, **kwargs)) self.db_pool.runInteraction = runInteraction - def runWithConnection(func, *args, **kwargs): + def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def] return defer.succeed(func(self.mock_conn, *args, **kwargs)) self.db_pool.runWithConnection = runWithConnection @@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type] @defer.inlineCallbacks - def test_insert_1col(self): + def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( @@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_insert_3cols(self): + def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( @@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_select_one_1col(self): + def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) @@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_select_one_3col(self): + def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) @@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_select_one_missing(self): + def test_select_one_missing( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None @@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.assertFalse(ret) @defer.inlineCallbacks - def test_select_list(self): + def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 3 self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) @@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_update_one_1col(self): + def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( @@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_update_one_4cols(self): + def test_update_one_4cols( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( @@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_delete_one(self): + def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 1 yield defer.ensureDeferred( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index b998ad42d9..d570684c99 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -15,11 +15,16 @@ import os.path from unittest.mock import Mock, patch +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.storage import prepare_database +from synapse.storage.types import Cursor from synapse.types import UserID, create_requester +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Test the background update to clean forward extremities table. """ - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() @@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] - def run_background_update(self): + def run_background_update(self) -> None: """Re run the background update to clean up the extremities.""" # Make sure we don't clash with in progress updates. self.assertTrue( @@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): "delete_forward_extremities.sql", ) - def run_delta_file(txn): + def run_delta_file(txn: Cursor) -> None: prepare_database.executescript(txn, schema_path) self.get_success( @@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): (room_id,) ) - def test_soft_failed_extremities_handled_correctly(self): + def test_soft_failed_extremities_handled_correctly(self) -> None: """Test that extremities are correctly calculated in the presence of soft failed events. @@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): self.assertEqual(latest_event_ids, [event_id_4]) - def test_basic_cleanup(self): + def test_basic_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of soft failed events. @@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): ) self.assertEqual(latest_event_ids, [event_id_b]) - def test_chain_of_fail_cleanup(self): + def test_chain_of_fail_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of soft failed events. @@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): ) self.assertEqual(latest_event_ids, [event_id_b]) - def test_forked_graph_cleanup(self): + def test_forked_graph_cleanup(self) -> None: r"""Test that extremities are correctly calculated in the presence of soft failed events. @@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["cleanup_extremities_with_dummy_events"] = True return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() self.event_creator_handler = homeserver.get_event_creation_handler() @@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.event_creator = homeserver.get_event_creation_handler() homeserver.config.consent.user_consent_version = self.CONSENT_VERSION - def test_send_dummy_event(self): + def test_send_dummy_event(self) -> None: self._create_extremity_rich_graph() # Pump the reactor repeatedly so that the background updates have a @@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) - def test_send_dummy_events_when_insufficient_power(self): + def test_send_dummy_events_when_insufficient_power(self) -> None: self._create_extremity_rich_graph() # Criple power levels self.helper.send_state( @@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) - def test_expiry_logic(self): + def test_expiry_logic(self) -> None: """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() expires old entries correctly. """ @@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): 0, ) - def _create_extremity_rich_graph(self): + def _create_extremity_rich_graph(self) -> None: """Helper method to create bushy graph on demand""" event_id_start = self.create_and_send_event(self.room_id, self.user) @@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): ) self.assertEqual(len(latest_event_ids), 50) - def _enable_consent_checking(self): + def _enable_consent_checking(self) -> None: """Helper method to enable consent checking""" self.event_creator._block_events_without_consent_error = "No consent from user" consent_uri_builder = Mock() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index a9af1babed..81e4e596e4 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -13,15 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict from unittest.mock import Mock from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.http.site import XForwardedForRequest from synapse.rest.client import login +from synapse.server import HomeServer from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.types import UserID +from synapse.util import Clock from tests import unittest from tests.server import make_request @@ -30,14 +35,10 @@ from tests.unittest import override_config class ClientIpStoreTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver() - return hs - - def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastores().main + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main - def test_insert_new_client_ip(self): + def test_insert_new_client_ip(self) -> None: self.reactor.advance(12345678) user_id = "@user:id" @@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): r, ) - def test_insert_new_client_ip_none_device_id(self): + def test_insert_new_client_ip_none_device_id(self) -> None: """ An insert with a device ID of NULL will not create a new entry, but update an existing entry in the user_ips table. @@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) @parameterized.expand([(False,), (True,)]) - def test_get_last_client_ip_by_device(self, after_persisting: bool): + def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None: """Test `get_last_client_ip_by_device` for persisted and unpersisted data""" self.reactor.advance(12345678) @@ -213,7 +214,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): }, ) - def test_get_last_client_ip_by_device_combined_data(self): + def test_get_last_client_ip_by_device_combined_data(self) -> None: """Test that `get_last_client_ip_by_device` combines persisted and unpersisted data together correctly """ @@ -312,7 +313,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) @parameterized.expand([(False,), (True,)]) - def test_get_user_ip_and_agents(self, after_persisting: bool): + def test_get_user_ip_and_agents(self, after_persisting: bool) -> None: """Test `get_user_ip_and_agents` for persisted and unpersisted data""" self.reactor.advance(12345678) @@ -352,7 +353,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) - def test_get_user_ip_and_agents_combined_data(self): + def test_get_user_ip_and_agents_combined_data(self) -> None: """Test that `get_user_ip_and_agents` combines persisted and unpersisted data together correctly """ @@ -429,7 +430,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) @override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) - def test_disabled_monthly_active_user(self): + def test_disabled_monthly_active_user(self) -> None: user_id = "@user:server" self.get_success( self.store.insert_client_ip( @@ -440,7 +441,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertFalse(active) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) - def test_adding_monthly_active_user_when_full(self): + def test_adding_monthly_active_user_when_full(self) -> None: lots_of_users = 100 user_id = "@user:server" @@ -456,7 +457,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertFalse(active) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) - def test_adding_monthly_active_user_when_space(self): + def test_adding_monthly_active_user_when_space(self) -> None: user_id = "@user:server" active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) @@ -473,7 +474,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertTrue(active) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) - def test_updating_monthly_active_user_when_space(self): + def test_updating_monthly_active_user_when_space(self) -> None: user_id = "@user:server" self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) @@ -491,7 +492,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) - def test_devices_last_seen_bg_update(self): + def test_devices_last_seen_bg_update(self) -> None: # First make sure we have completed all updates. self.wait_for_background_updates() @@ -576,7 +577,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): r, ) - def test_old_user_ips_pruned(self): + def test_old_user_ips_pruned(self) -> None: # First make sure we have completed all updates. self.wait_for_background_updates() @@ -639,11 +640,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertEqual(result, []) # But we should still get the correct values for the device - result = self.get_success( + result2 = self.get_success( self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, device_id)] + r = result2[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, @@ -663,15 +664,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver() - return hs - - def prepare(self, hs, reactor, clock): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main self.user_id = self.register_user("bob", "abc123", True) - def test_request_with_xforwarded(self): + def test_request_with_xforwarded(self) -> None: """ The IP in X-Forwarded-For is entered into the client IPs table. """ @@ -681,14 +678,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): {"request": XForwardedForRequest}, ) - def test_request_from_getPeer(self): + def test_request_from_getPeer(self) -> None: """ The IP returned by getPeer is entered into the client IPs table, if there's no X-Forwarded-For header. """ self._runtest({}, "127.0.0.1", {}) - def _runtest(self, headers, expected_ip, make_request_args): + def _runtest( + self, + headers: Dict[bytes, bytes], + expected_ip: str, + make_request_args: Dict[str, Any], + ) -> None: device_id = "bleb" access_token = self.login("bob", "abc123", device_id=device_id) diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index a40fc20ef9..543cce6b3e 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -31,7 +31,7 @@ from tests import unittest class TupleComparisonClauseTestCase(unittest.TestCase): - def test_native_tuple_comparison(self): + def test_native_tuple_comparison(self) -> None: clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 8e7db2c4ec..f03807c8f9 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -12,17 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Collection, List, Tuple + +from twisted.test.proto_helpers import MemoryReactor + import synapse.api.errors from synapse.api.constants import EduTypes +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase class DeviceStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def add_device_change(self, user_id, device_ids, host): + def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None: """Add a device list change for the given device to `device_lists_outbound_pokes` table. """ @@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase): ) ) - def test_store_new_device(self): + def test_store_new_device(self) -> None: self.get_success( self.store.store_device("user_id", "device_id", "display_name") ) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertDictContainsSubset( { "user_id": "user_id", @@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase): res, ) - def test_get_devices_by_user(self): + def test_get_devices_by_user(self) -> None: self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) @@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase): res["device2"], ) - def test_count_devices_by_users(self): + def test_count_devices_by_users(self) -> None: self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) @@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase): ) self.assertEqual(3, res) - def test_get_device_updates_by_remote(self): + def test_get_device_updates_by_remote(self) -> None: device_ids = ["device_id1", "device_id2"] # Add two device updates with sequential `stream_id`s @@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - def test_get_device_updates_by_remote_can_limit_properly(self): + def test_get_device_updates_by_remote_can_limit_properly(self) -> None: """ Tests that `get_device_updates_by_remote` returns an appropriate stream_id to resume fetching from (without skipping any results). @@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase): ) self.assertEqual(device_updates, []) - def _check_devices_in_updates(self, expected_device_ids, device_updates): + def _check_devices_in_updates( + self, + expected_device_ids: Collection[str], + device_updates: List[Tuple[str, JsonDict]], + ) -> None: """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) @@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase): } self.assertEqual(received_device_ids, set(expected_device_ids)) - def test_update_device(self): + def test_update_device(self) -> None: self.get_success( self.store.store_device("user_id", "device_id", "display_name 1") ) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 1", res["display_name"]) # do a no-op first self.get_success(self.store.update_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 1", res["display_name"]) # do the update @@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase): # check it worked res = self.get_success(self.store.get_device("user_id", "device_id")) + assert res is not None self.assertEqual("display_name 2", res["display_name"]) - def test_update_unknown_device(self): + def test_update_unknown_device(self) -> None: exc = self.get_failure( self.store.update_device( "user_id", "unknown_device_id", new_display_name="display_name 2" diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 20bf3ca17b..8bedc6bdf3 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -12,19 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer from synapse.types import RoomAlias, RoomID +from synapse.util import Clock from tests.unittest import HomeserverTestCase class DirectoryStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") - def test_room_to_alias(self): + def test_room_to_alias(self) -> None: self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] @@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), ) - def test_alias_to_room(self): + def test_alias_to_room(self) -> None: self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] @@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): (self.get_success(self.store.get_association_from_room_alias(self.alias))), ) - def test_delete_alias(self): + def test_delete_alias(self) -> None: self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index fb96ab3a2f..9cb326d90a 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer from synapse.storage.databases.main.e2e_room_keys import RoomKey +from synapse.util import Clock from tests import unittest @@ -26,12 +30,12 @@ room_key: RoomKey = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) self.store = hs.get_datastores().main return hs - def test_room_keys_version_delete(self): + def test_room_keys_version_delete(self) -> None: # test that deleting a room key backup deletes the keys version1 = self.get_success( self.store.create_e2e_room_keys_version( diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 0f04493ad0..5fde3b9c78 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.util import Clock + from tests.unittest import HomeserverTestCase class EndToEndKeyStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def test_key_without_device_name(self): + def test_key_without_device_name(self) -> None: now = 1470174257070 json = {"key": "value"} @@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): dev = res["user"]["device"] self.assertDictContainsSubset(json, dev) - def test_reupload_key(self): + def test_reupload_key(self) -> None: now = 1470174257070 json = {"key": "value"} @@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): ) self.assertFalse(changed) - def test_get_key_with_device_name(self): + def test_get_key_with_device_name(self) -> None: now = 1470174257070 json = {"key": "value"} @@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) - def test_multiple_devices(self): + def test_multiple_devices(self) -> None: now = 1470174257070 self.get_success(self.store.store_device("user1", "device1", None)) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index de9f4af2de..c070278db8 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -14,6 +14,7 @@ from typing import Dict, List, Set, Tuple +from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest from synapse.api.constants import EventTypes @@ -22,18 +23,22 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.events import _LinkMap +from synapse.storage.types import Cursor from synapse.types import create_requester +from synapse.util import Clock from tests.unittest import HomeserverTestCase class EventChainStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self._next_stream_ordering = 1 - def test_simple(self): + def test_simple(self) -> None: """Test that the example in `docs/auth_chain_difference_algorithm.md` works. """ @@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase): ), ) - def test_out_of_order_events(self): + def test_out_of_order_events(self) -> None: """Test that we handle persisting events that we don't have the full auth chain for yet (which should only happen for out of band memberships). """ @@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase): def persist( self, events: List[EventBase], - ): + ) -> None: """Persist the given events and check that the links generated match those given. """ @@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase): e.internal_metadata.stream_ordering = self._next_stream_ordering self._next_stream_ordering += 1 - def _persist(txn): + def _persist(txn: LoggingTransaction) -> None: # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( @@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase): class LinkMapTestCase(unittest.TestCase): - def test_simple(self): + def test_simple(self) -> None: """Basic tests for the LinkMap.""" link_map = _LinkMap() @@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") @@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Delete the chain cover info. - def _delete_tables(txn): + def _delete_tables(txn: Cursor) -> None: txn.execute("DELETE FROM event_auth_chains") txn.execute("DELETE FROM event_auth_chain_links") @@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): return room_id, [state1, state2] - def test_background_update_single_room(self): + def test_background_update_single_room(self) -> None: """Test that the background update to calculate auth chains for historic rooms works correctly. """ @@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) - def test_background_update_multiple_rooms(self): + def test_background_update_multiple_rooms(self) -> None: """Test that the background update to calculate auth chains for historic rooms works correctly. """ @@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) - def test_background_update_single_large_room(self): + def test_background_update_single_large_room(self) -> None: """Test that the background update to calculate auth chains for historic rooms works correctly. """ @@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) - def test_background_update_multiple_large_room(self): + def test_background_update_multiple_large_room(self) -> None: """Test that the background update to calculate auth chains for historic rooms works correctly. """ diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 853db930d6..7fd3e01364 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,7 +13,7 @@ # limitations under the License. import datetime -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union, cast import attr from parameterized import parameterized @@ -26,11 +26,12 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersion, ) -from synapse.events import _EventInternalMetadata +from synapse.events import EventBase, _EventInternalMetadata from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction +from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import Clock, json_encoder @@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def test_get_prev_events_for_room(self): + def test_get_prev_events_for_room(self) -> None: room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities - def insert_event(txn, i): + def insert_event(txn: Cursor, i: int) -> None: event_id = "$event_%i:local" % i txn.execute( @@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): for i in range(0, 10): self.assertEqual("$event_%i:local" % (19 - i), r[i]) - def test_get_rooms_with_many_extremities(self): + def test_get_rooms_with_many_extremities(self) -> None: room1 = "#room1" room2 = "#room2" room3 = "#room3" - def insert_event(txn, i, room_id): + def insert_event(txn: Cursor, i: int, room_id: str) -> None: event_id = "$event_%i:local" % i txn.execute( ( @@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # | | # K J - auth_graph = { + auth_graph: Dict[str, List[str]] = { "a": ["e"], "b": ["e"], "c": ["g", "i"], @@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # Mark the room as maybe having a cover index. - def store_room(txn): + def store_room(txn: LoggingTransaction) -> None: self.store.db_pool.simple_insert_txn( txn, "rooms", @@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. - def insert_event(txn): + def insert_event(txn: LoggingTransaction) -> None: stream_ordering = 0 for event_id in auth_graph: @@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, [ - FakeEvent(event_id, room_id, auth_graph[event_id]) + cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) for event_id in auth_graph ], ) @@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): return room_id @parameterized.expand([(True,), (False,)]) - def test_auth_chain_ids(self, use_chain_cover_index: bool): + def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None: room_id = self._setup_auth_chain(use_chain_cover_index) # a and b have the same auth chain. @@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.assertCountEqual(auth_chain_ids, ["i", "j"]) @parameterized.expand([(True,), (False,)]) - def test_auth_difference(self, use_chain_cover_index: bool): + def test_auth_difference(self, use_chain_cover_index: bool) -> None: room_id = self._setup_auth_chain(use_chain_cover_index) # Now actually test that various combinations give the right result: @@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) self.assertSetEqual(difference, set()) - def test_auth_difference_partial_cover(self): + def test_auth_difference_partial_cover(self) -> None: """Test that we correctly handle rooms where not all events have a chain cover calculated. This can happen in some obscure edge cases, including during the background update that calculates the chain cover for old @@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # | | # K J - auth_graph = { + auth_graph: Dict[str, List[str]] = { "a": ["e"], "b": ["e"], "c": ["g", "i"], @@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. - def insert_event(txn): + def insert_event(txn: LoggingTransaction) -> None: # First insert the room and mark it as having a chain cover. self.store.db_pool.simple_insert_txn( txn, @@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, [ - FakeEvent(event_id, room_id, auth_graph[event_id]) + cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) for event_id in auth_graph if event_id != "b" ], @@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, - [FakeEvent("b", room_id, auth_graph["b"])], + [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], ) self.store.db_pool.simple_update_txn( @@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): @parameterized.expand( [(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()] ) - def test_prune_inbound_federation_queue(self, room_version: RoomVersion): + def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None: """Test that pruning of inbound federation queues work""" room_id = "some_room_id" @@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): stream_ordering += 1 - def populate_db(txn: LoggingTransaction): + def populate_db(txn: LoggingTransaction) -> None: # Insert the room to satisfy the foreign key constraint of # `event_failed_pull_attempts` self.store.db_pool.simple_insert_txn( @@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) - def test_get_backfill_points_in_room(self): + def test_get_backfill_points_in_room(self) -> None: """ Test to make sure only backfill points that are older and come before the `current_depth` are returned. @@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_backfill_points_in_room_excludes_events_we_have_attempted( self, - ): + ) -> None: """ Test to make sure that events we have attempted to backfill (and within backoff timeout duration) do not show up as an event to backfill again. @@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration( self, - ): + ) -> None: """ Test to make sure after we fake attempt to backfill event "b3" many times, we can see retry and see the "b3" again after the backoff timeout duration @@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "5": 7, } - def populate_db(txn: LoggingTransaction): + def populate_db(txn: LoggingTransaction) -> None: # Insert the room to satisfy the foreign key constraint of # `event_failed_pull_attempts` self.store.db_pool.simple_insert_txn( @@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) - def test_get_insertion_event_backward_extremities_in_room(self): + def test_get_insertion_event_backward_extremities_in_room(self) -> None: """ Test to make sure only insertion event backward extremities that are older and come before the `current_depth` are returned. @@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( self, - ): + ) -> None: """ Test to make sure that insertion events we have attempted to backfill (and within backoff timeout duration) do not show up as an event to @@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( self, - ): + ) -> None: """ Test to make sure after we fake attempt to backfill event "insertion_eventA" many times, we can see retry and see the @@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertEqual(backfill_event_ids, ["insertion_eventA"]) - def test_get_event_ids_to_not_pull_from_backoff( - self, - ): + def test_get_event_ids_to_not_pull_from_backoff(self) -> None: """ Test to make sure only event IDs we should backoff from are returned. """ @@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( self, - ): + ) -> None: """ Test to make sure no event IDs are returned after the backoff duration has elapsed. @@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.assertEqual(event_ids_to_backoff, []) -@attr.s +@attr.s(auto_attribs=True) class FakeEvent: - event_id = attr.ib() - room_id = attr.ib() - auth_events = attr.ib() + event_id: str + room_id: str + auth_events: List[str] type = "foo" state_key = "foo" internal_metadata = _EventInternalMetadata({}) - def auth_event_ids(self): + def auth_event_ids(self) -> List[str]: return self.auth_events - def is_state(self): + def is_state(self) -> bool: return True diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 6f1135eef4..a91411168c 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class ExtremStatisticsTestCase(HomeserverTestCase): - def test_exposed_to_prometheus(self): + def test_exposed_to_prometheus(self) -> None: """ Forward extremity counts are exposed via Prometheus. """ diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 3ce4f35cb7..05661a537d 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import StateMap +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.state = self.hs.get_state_handler() self._persistence = self.hs.get_storage_controllers().persistence self._state_storage_controller = self.hs.get_storage_controllers().state @@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check that the current extremities is the remote event. self.assert_extremities([self.remote_event_1.event_id]) - def persist_event(self, event, state=None): + def persist_event( + self, event: EventBase, state: Optional[StateMap[str]] = None + ) -> None: """Persist the event, with optional state""" context = self.get_success( self.state.compute_event_context( @@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase): ) self.get_success(self._persistence.persist_event(event, context)) - def assert_extremities(self, expected_extremities): + def assert_extremities(self, expected_extremities: List[str]) -> None: """Assert the current extremities for the room""" extremities = self.get_success( self.store.get_prev_events_for_room(self.room_id) ) self.assertCountEqual(extremities, expected_extremities) - def test_prune_gap(self): + def test_prune_gap(self) -> None: """Test that we drop extremities after a gap when we see an event from the same domain. """ @@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) - def test_do_not_prune_gap_if_state_different(self): + def test_do_not_prune_gap_if_state_different(self) -> None: """Test that we don't prune extremities after a gap if the resolved state is different. """ @@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check that we haven't dropped the old extremity. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) - def test_prune_gap_if_old(self): + def test_prune_gap_if_old(self) -> None: """Test that we drop extremities after a gap when the previous extremity is "old" """ @@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) - def test_do_not_prune_gap_if_other_server(self): + def test_do_not_prune_gap_if_other_server(self) -> None: """Test that we do not drop extremities after a gap when we see an event from a different domain. """ @@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) - def test_prune_gap_if_dummy_remote(self): + def test_prune_gap_if_dummy_remote(self) -> None: """Test that we drop extremities after a gap when the previous extremity is a local dummy event and only points to remote events. """ @@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) - def test_prune_gap_if_dummy_local(self): + def test_prune_gap_if_dummy_local(self) -> None: """Test that we don't drop extremities after a gap when the previous extremity is a local dummy event and points to local events. """ @@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id, local_message_event_id]) - def test_do_not_prune_gap_if_not_dummy(self): + def test_do_not_prune_gap_if_not_dummy(self) -> None: """Test that we do not drop extremities after a gap when the previous extremity is not a dummy event. """ @@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.state = self.hs.get_state_handler() self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main - def test_remote_user_rooms_cache_invalidated(self): + def test_remote_user_rooms_cache_invalidated(self) -> None: """Test that if the server leaves a room the `get_rooms_for_user` cache is invalidated for remote users. """ @@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) self.assertEqual(set(rooms), set()) - def test_room_remote_user_cache_invalidated(self): + def test_room_remote_user_cache_invalidated(self) -> None: """Test that if the server leaves a room the `get_users_in_room` cache is invalidated for remote users. """ diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 9059095525..aa4b5bd3b1 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -13,6 +13,7 @@ # limitations under the License. import signedjson.key +import signedjson.types import unpaddedbase64 from twisted.internet.defer import Deferred @@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult import tests.unittest -def decode_verify_key_base64(key_id: str, key_base64: str): +def decode_verify_key_base64( + key_id: str, key_base64: str +) -> signedjson.types.VerifyKey: key_bytes = unpaddedbase64.decode_base64(key_base64) return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) @@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64( class KeyStoreTestCase(tests.unittest.HomeserverTestCase): - def test_get_server_verify_keys(self): + def test_get_server_verify_keys(self) -> None: store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" @@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) - def test_cache(self): + def test_cache(self) -> None: """Check that updates correctly invalidate the cache.""" store = self.hs.get_datastores().main diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index c55c4db970..2827738379 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.reactor.advance(FORTY_DAYS) @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) - def test_initialise_reserved_users(self): + def test_initialise_reserved_users(self) -> None: threepids = self.hs.config.server.mau_limits_reserved_threepids # register three users, of which two have reserved 3pids, and a third @@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): active_count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(active_count, 3) - def test_can_insert_and_count_mau(self): + def test_can_insert_and_count_mau(self) -> None: count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 0) @@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 1) - def test_appservice_user_not_counted_in_mau(self): + def test_appservice_user_not_counted_in_mau(self) -> None: self.get_success( self.store.register_user( user_id="@appservice_user:server", appservice_id="wibble" @@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 0) - def test_user_last_seen_monthly_active(self): + def test_user_last_seen_monthly_active(self) -> None: user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" @@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertIsNone(result) @override_config({"max_mau_value": 5}) - def test_reap_monthly_active_users(self): + def test_reap_monthly_active_users(self) -> None: initial_users = 10 for i in range(initial_users): self.get_success( @@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Note that below says mau_limit (no s), this is the name of the config # value, although it gets stored on the config object as mau_limits. @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) - def test_reap_monthly_active_users_reserved_users(self): + def test_reap_monthly_active_users_reserved_users(self) -> None: """Tests that reaping correctly handles reaping where reserved users are present""" threepids = self.hs.config.server.mau_limits_reserved_threepids @@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, self.hs.config.server.max_mau_value) - def test_populate_monthly_users_is_guest(self): + def test_populate_monthly_users_is_guest(self) -> None: # Test that guest users are not added to mau list user_id = "@user_id:host" @@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() - def test_populate_monthly_users_should_update(self): + def test_populate_monthly_users_should_update(self) -> None: self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] @@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() - def test_populate_monthly_users_should_not_update(self): + def test_populate_monthly_users_should_not_update(self) -> None: self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] @@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() - def test_get_reserved_real_user_account(self): + def test_get_reserved_real_user_account(self) -> None: # Test no reserved users, or reserved threepids users = self.get_success(self.store.get_registered_reserved_users()) self.assertEqual(len(users), 0) @@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): users = self.get_success(self.store.get_registered_reserved_users()) self.assertEqual(len(users), len(threepids)) - def test_support_user_not_add_to_mau_limits(self): + def test_support_user_not_add_to_mau_limits(self) -> None: support_user_id = "@support:test" count = self.get_success(self.store.get_monthly_active_count()) @@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config( {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} ) - def test_track_monthly_users_without_cap(self): + def test_track_monthly_users_without_cap(self) -> None: count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(0, count) @@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEqual(2, count) @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) - def test_no_users_when_not_tracking(self): + def test_no_users_when_not_tracking(self) -> None: self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.get_success(self.store.populate_monthly_active_users("@user:sever")) self.store.upsert_monthly_active_user.assert_not_called() - def test_get_monthly_active_count_by_service(self): + def test_get_monthly_active_count_by_service(self) -> None: appservice1_user1 = "@appservice1_user1:example.com" appservice1_user2 = "@appservice1_user2:example.com" @@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEqual(result[service2], 1) self.assertEqual(result[native], 1) - def test_get_monthly_active_users_by_service(self): + def test_get_monthly_active_users_by_service(self) -> None: # (No users, no filtering) -> empty result result = self.get_success(self.store.get_monthly_active_users_by_service()) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 9c1182ed16..010cc74c31 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import NotFoundError, SynapseError from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase): user_id = "@red:server" servlets = [room.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) self.store = hs.get_datastores().main self._storage_controllers = self.hs.get_storage_controllers() - def test_purge_history(self): + def test_purge_history(self) -> None: """ Purging a room history will delete everything before the topological point. """ @@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase): self.get_failure(self.store.get_event(third["event_id"]), NotFoundError) self.get_success(self.store.get_event(last["event_id"])) - def test_purge_history_wont_delete_extrems(self): + def test_purge_history_wont_delete_extrems(self) -> None: """ Purging a room history will delete everything before the topological point. """ @@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) + assert token.topological is not None event = f"t{token.topological + 1}-{token.stream + 1}" # Purge everything before this topological token @@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase): self.get_success(self.store.get_event(third["event_id"])) self.get_success(self.store.get_event(last["event_id"])) - def test_purge_room(self): + def test_purge_room(self) -> None: """ Purging a room will delete everything about it. """ diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 81253d0361..d8d84152dc 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -14,8 +14,12 @@ from typing import Collection, Optional +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import ReceiptTypes +from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util import Clock from tests.test_utils.event_injection import create_event from tests.unittest import HomeserverTestCase @@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test" class ReceiptTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, homeserver) -> None: + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: super().prepare(reactor, clock, homeserver) self.store = homeserver.get_datastores().main @@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase): ) self.assertEqual(res, {}) - res = self.get_last_unthreaded_receipt( + res2 = self.get_last_unthreaded_receipt( [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) - self.assertEqual(res, None) + self.assertIsNone(res2) def test_get_receipts_for_user(self) -> None: # Send some events into the first room diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 6c4e63b77c..df4740f9d9 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -11,27 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, cast from canonicaljson import json +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.types import RoomID, UserID +from synapse.events import EventBase, _EventInternalMetadata +from synapse.events.builder import EventBuilder +from synapse.server import HomeServer +from synapse.types import JsonDict, RoomID, UserID +from synapse.util import Clock from tests import unittest from tests.utils import create_room class RedactionTestCase(unittest.HomeserverTestCase): - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config["redaction_retention_period"] = "30d" return config - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self._storage = hs.get_storage_controllers() + storage = hs.get_storage_controllers() + assert storage.persistence is not None + self._persistence = storage.persistence self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.depth = 1 - def inject_room_member( + def inject_room_member( # type: ignore[override] self, - room, - user, - membership, - replaces_state=None, - extra_content: Optional[dict] = None, - ): + room: RoomID, + user: UserID, + membership: str, + extra_content: Optional[JsonDict] = None, + ) -> EventBase: content = {"membership": membership} content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( @@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self._storage.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) return event - def inject_message(self, room, user, body): + def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase: self.depth += 1 builder = self.event_builder_factory.for_room_version( @@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self._storage.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) return event - def inject_redaction(self, room, event_id, user, reason): + def inject_redaction( + self, room: RoomID, event_id: str, user: UserID, reason: str + ) -> EventBase: builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { @@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self._storage.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) return event - def test_redact(self): + def test_redact(self) -> None: self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) msg_event = self.inject_message(self.room1, self.u_alice, "t") @@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): event.unsigned["redacted_because"], ) - def test_redact_join(self): + def test_redact_join(self) -> None: self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) msg_event = self.inject_room_member( @@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): event.unsigned["redacted_because"], ) - def test_circular_redaction(self): + def test_circular_redaction(self) -> None: redaction_event_id1 = "$redaction1_id:test" redaction_event_id2 = "$redaction2_id:test" class EventIdManglingBuilder: - def __init__(self, base_builder, event_id): + def __init__(self, base_builder: EventBuilder, event_id: str): self._base_builder = base_builder self._event_id = event_id @@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase): prev_event_ids: List[str], auth_event_ids: Optional[List[str]], depth: Optional[int] = None, - ): + ) -> EventBase: built_event = await self._base_builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids ) - built_event._event_id = self._event_id + built_event._event_id = self._event_id # type: ignore[attr-defined] built_event._dict["event_id"] = self._event_id assert built_event.event_id == self._event_id return built_event @property - def room_id(self): + def room_id(self) -> str: return self._base_builder.room_id @property - def type(self): + def type(self) -> str: return self._base_builder.type @property - def internal_metadata(self): + def internal_metadata(self) -> _EventInternalMetadata: return self._base_builder.internal_metadata event_1, context_1 = self.get_success( self.event_creation_handler.create_new_client_event( - EventIdManglingBuilder( - self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Redaction, - "sender": self.u_alice.to_string(), - "room_id": self.room1.to_string(), - "content": {"reason": "test"}, - "redacts": redaction_event_id2, - }, + cast( + EventBuilder, + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id2, + }, + ), + redaction_event_id1, ), - redaction_event_id1, ) ) ) - self.get_success(self._storage.persistence.persist_event(event_1, context_1)) + self.get_success(self._persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( - EventIdManglingBuilder( - self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Redaction, - "sender": self.u_alice.to_string(), - "room_id": self.room1.to_string(), - "content": {"reason": "test"}, - "redacts": redaction_event_id1, - }, + cast( + EventBuilder, + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id1, + }, + ), + redaction_event_id2, ), - redaction_event_id2, ) ) ) - self.get_success(self._storage.persistence.persist_event(event_2, context_2)) + self.get_success(self._persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): fetched.unsigned["redacted_because"].event_id, redaction_event_id2 ) - def test_redact_censor(self): + def test_redact_censor(self) -> None: """Test that a redacted event gets censored in the DB after a month""" self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.assert_dict({"content": {}}, json.loads(event_json)) - def test_redact_redaction(self): + def test_redact_redaction(self) -> None: """Tests that we can redact a redaction and can fetch it again.""" self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.store.get_event(first_redact_event.event_id, allow_none=True) ) - def test_store_redacted_redaction(self): + def test_store_redacted_redaction(self) -> None: """Tests that we can store a redacted redaction.""" self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success( - self._storage.persistence.persist_event(redaction_event, context) - ) + self.get_success(self._persistence.persist_event(redaction_event, context)) # Now lets jump to the future where we have censored the redaction event # in the DB. diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index 0baa54312e..966aafea6f 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -14,10 +14,15 @@ from typing import List from unittest import mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.app.generic_worker import GenericWorkerServer +from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database from synapse.storage.schema import SCHEMA_VERSION +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]: class WorkerSchemaTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs - def default_config(self): + def default_config(self) -> JsonDict: conf = super().default_config() # Mark this as a worker app. @@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase): return conf - def test_rolling_back(self): + def test_rolling_back(self) -> None: """Test that workers can start if the DB is a newer schema version""" db_pool = self.hs.get_datastores().main.db_pool @@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase): prepare_database(db_conn, db_pool.engine, self.hs.config) - def test_not_upgraded_old_schema_version(self): + def test_not_upgraded_old_schema_version(self) -> None: """Test that workers don't start if the DB has an older schema version""" db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( @@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase): with self.assertRaises(PrepareDatabaseException): prepare_database(db_conn, db_pool.engine, self.hs.config) - def test_not_upgraded_current_schema_version_with_outstanding_deltas(self): + def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None: """ Test that workers don't start if the DB is on the current schema version, but there are still outstanding delta migrations to run. diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 3405efb6a8..71ec74eadc 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.room_versions import RoomVersions +from synapse.server import HomeServer from synapse.types import RoomAlias, RoomID, UserID +from synapse.util import Clock from tests.unittest import HomeserverTestCase class RoomStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table self.store = hs.get_datastores().main @@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase): ) ) - def test_get_room(self): + def test_get_room(self) -> None: + res = self.get_success(self.store.get_room(self.room.to_string())) + assert res is not None self.assertDictContainsSubset( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "is_public": True, }, - (self.get_success(self.store.get_room(self.room.to_string()))), + res, ) - def test_get_room_unknown_room(self): + def test_get_room_unknown_room(self) -> None: self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) - def test_get_room_with_stats(self): + def test_get_room_with_stats(self) -> None: + res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) + assert res is not None self.assertDictContainsSubset( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "public": True, }, - (self.get_success(self.store.get_room_with_stats(self.room.to_string()))), + res, ) - def test_get_room_with_stats_unknown_room(self): + def test_get_room_with_stats_unknown_room(self) -> None: self.assertIsNone( - (self.get_success(self.store.get_room_with_stats("!uknown:test"))), + self.get_success(self.store.get_room_with_stats("!uknown:test")) ) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index ef850daa73..14d872514d 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase): room.register_servlets, ] - def test_null_byte(self): + def test_null_byte(self) -> None: """ Postgres/SQLite don't like null bytes going into the search tables. Internally we replace those with a space. @@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase): if isinstance(store.database_engine, PostgresEngine): self.assertIn("alice", result.get("highlights")) - def test_non_string(self): + def test_non_string(self) -> None: """Test that non-string `value`s are not inserted into `event_search`. This is particularly important when using sqlite, since a sqlite column can hold @@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase): self.assertEqual(f.value.code, 404) @skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite") - def test_sqlite_non_string_deletion_background_update(self): + def test_sqlite_non_string_deletion_background_update(self) -> None: """Test the background update to delete bad rows from `event_search`.""" store = self.hs.get_datastores().main @@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase): "results array length should match count", ) - def test_postgres_web_search_for_phrase(self): + def test_postgres_web_search_for_phrase(self) -> None: """ Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. This test is skipped unless the postgres instance supports websearch_to_tsquery. @@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase): self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES) - def test_sqlite_search(self): + def test_sqlite_search(self) -> None: """ Test sqlite searching for phrases. """ diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 5564161750..d4e6d4236c 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -16,10 +16,15 @@ import logging from frozendict import frozendict +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.server import HomeServer from synapse.storage.state import StateFilter -from synapse.types import RoomID, UserID +from synapse.types import JsonDict, RoomID, StateMap, UserID +from synapse.util import Clock from tests.unittest import HomeserverTestCase, TestCase @@ -27,7 +32,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.storage = hs.get_storage_controllers() self.state_datastore = self.storage.state.stores.state @@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase): ) ) - def inject_state_event(self, room, sender, typ, state_key, content): + def inject_state_event( + self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict + ) -> EventBase: builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { @@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) + assert self.storage.persistence is not None self.get_success(self.storage.persistence.persist_event(event, context)) return event - def assertStateMapEqual(self, s1, s2): + def assertStateMapEqual( + self, s1: StateMap[EventBase], s2: StateMap[EventBase] + ) -> None: for t in s1: # just compare event IDs for simplicity self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(len(s1), len(s2)) - def test_get_state_groups_ids(self): + def test_get_state_groups_ids(self) -> None: e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = self.get_success( - self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) + self.storage.state.get_state_groups_ids( + self.room.to_string(), [e2.event_id] + ) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase): {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) - def test_get_state_groups(self): + def test_get_state_groups(self) -> None: e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = self.get_success( - self.storage.state.get_state_groups(self.room, [e2.event_id]) + self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) - def test_get_state_for_event(self): + def test_get_state_for_event(self) -> None: # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) @@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase): class StateFilterDifferenceTestCase(TestCase): def assert_difference( self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter - ): + ) -> None: self.assertEqual( minuend.approx_difference(subtrahend), expected, f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", ) - def test_state_filter_difference_no_include_other_minus_no_include_other(self): + def test_state_filter_difference_no_include_other_minus_no_include_other( + self, + ) -> None: """ Tests the StateFilter.approx_difference method where, in a.approx_difference(b), both a and b do not have the @@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase): ), ) - def test_state_filter_difference_include_other_minus_no_include_other(self): + def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: """ Tests the StateFilter.approx_difference method where, in a.approx_difference(b), only a has the include_others flag set. @@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase): ), ) - def test_state_filter_difference_include_other_minus_include_other(self): + def test_state_filter_difference_include_other_minus_include_other(self) -> None: """ Tests the StateFilter.approx_difference method where, in a.approx_difference(b), both a and b have the include_others @@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase): ), ) - def test_state_filter_difference_no_include_other_minus_include_other(self): + def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: """ Tests the StateFilter.approx_difference method where, in a.approx_difference(b), only b has the include_others flag set. @@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase): ), ) - def test_state_filter_difference_simple_cases(self): + def test_state_filter_difference_simple_cases(self) -> None: """ Tests some very simple cases of the StateFilter approx_difference, that are not explicitly tested by the more in-depth tests. @@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase): class StateFilterTestCase(TestCase): - def test_return_expanded(self): + def test_return_expanded(self) -> None: """ Tests the behaviour of the return_expanded() function that expands StateFilters to include more state types (for the sake of cache hit rate). diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 34fa810cf6..bc090ebce0 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -14,11 +14,15 @@ from typing import List +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, RelationTypes from synapse.api.filtering import Filter from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase): login.register_servlets, ] - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config["experimental_features"] = {"msc3874_enabled": True} return config - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) @@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase): return [ev.event_id for ev in events] - def test_filter_relation_senders(self): + def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) @@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase): chunk = self._filter_messages(filter) self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) - def test_filter_relation_type(self): + def test_filter_relation_type(self) -> None: # Messages which have annotations. filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) @@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase): chunk = self._filter_messages(filter) self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) - def test_filter_relation_senders_and_type(self): + def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { "related_by_senders": [self.second_user_id], @@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase): chunk = self._filter_messages(filter) self.assertEqual(chunk, [self.event_id_1]) - def test_duplicate_relation(self): + def test_duplicate_relation(self) -> None: """An event should only be returned once if there are multiple relations to it.""" self.helper.send_event( room_id=self.room_id, diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index e05daa285e..db9ee9955e 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -12,17 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer from synapse.storage.databases.main.transactions import DestinationRetryTimings +from synapse.util import Clock from synapse.util.retryutils import MAX_RETRY_INTERVAL from tests.unittest import HomeserverTestCase class TransactionStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main - def test_get_set_transactions(self): + def test_get_set_transactions(self) -> None: """Tests that we can successfully get a non-existent entry for destination retries, as well as testing tht we can set and get correctly. @@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase): r, ) - def test_initial_set_transactions(self): + def test_initial_set_transactions(self) -> None: """Tests that we can successfully set the destination retries (there was a bug around invalidating the cache that broke this) """ d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(d) - def test_large_destination_retry(self): + def test_large_destination_retry(self) -> None: d = self.store.set_destination_retry_timings( "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL ) self.get_success(d) - d = self.store.get_destination_retry_timings("example.com") - self.get_success(d) + d2 = self.store.get_destination_retry_timings("example.com") + self.get_success(d2) diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index ace82cbf42..15ea4770bd 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -12,21 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.types import Cursor +from synapse.util import Clock + from tests import unittest class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): """Test SQL transaction limit doesn't break transactions.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(db_txn_limit=1000) - def test_config(self): + def test_config(self) -> None: db_config = self.hs.config.database.get_single_database() self.assertEqual(db_config.config["txn_limit"], 1000) - def test_select(self): - def do_select(txn): + def test_select(self) -> None: + def do_select(txn: Cursor) -> None: txn.execute("SELECT 1") db_pool = self.hs.get_datastores().databases[0] diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index cae14151c0..0e3fc2a77f 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from typing import Collection, Dict from unittest import mock from twisted.internet.defer import CancelledError, ensureDeferred @@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase): # the results to be returned by the mocked get_partial_state_events self._events_dict: Dict[str, bool] = {} - async def get_partial_state_events(events): + async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]: return {e: self._events_dict[e] for e in events} self.mock_store = mock.Mock(spec_set=["get_partial_state_events"]) @@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker = PartialStateEventsTracker(self.mock_store) - def test_does_not_block_for_full_state_events(self): + def test_does_not_block_for_full_state_events(self) -> None: self._events_dict = {"event1": False, "event2": False} self.successResultOf( @@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase): ["event1", "event2"] ) - def test_blocks_for_partial_state_events(self): + def test_blocks_for_partial_state_events(self) -> None: self._events_dict = {"event1": True, "event2": False} d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) @@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("event1") self.successResultOf(d) - def test_un_partial_state_race(self): + def test_un_partial_state_race(self) -> None: # if the event is un-partial-stated between the initial check and the # registration of the listener, it should not block. self._events_dict = {"event1": True, "event2": False} - async def get_partial_state_events(events): + async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]: res = {e: self._events_dict[e] for e in events} # change the result for next time self._events_dict = {"event1": False, "event2": False} @@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase): ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) ) - def test_un_partial_state_during_get_partial_state_events(self): + def test_un_partial_state_during_get_partial_state_events(self) -> None: # we should correctly handle a call to notify_un_partial_stated during the # second call to get_partial_state_events. self._events_dict = {"event1": True, "event2": False} - async def get_partial_state_events1(events): + async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]: self.mock_store.get_partial_state_events.side_effect = ( get_partial_state_events2 ) return {e: self._events_dict[e] for e in events} - async def get_partial_state_events2(events): + async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]: self.tracker.notify_un_partial_stated("event1") self._events_dict["event1"] = False return {e: self._events_dict[e] for e in events} @@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase): ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) ) - def test_cancellation(self): + def test_cancellation(self) -> None: self._events_dict = {"event1": True, "event2": False} d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) @@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase): self.tracker = PartialCurrentStateTracker(self.mock_store) - def test_does_not_block_for_full_state_rooms(self): + def test_does_not_block_for_full_state_rooms(self) -> None: self.mock_store.is_partial_state_room.return_value = make_awaitable(False) self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) - def test_blocks_for_partial_room_state(self): + def test_blocks_for_partial_room_state(self) -> None: self.mock_store.is_partial_state_room.return_value = make_awaitable(True) d = ensureDeferred(self.tracker.await_full_state("room_id")) @@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("room_id") self.successResultOf(d) - def test_un_partial_state_race(self): + def test_un_partial_state_race(self) -> None: # We should correctly handle race between awaiting the state and us # un-partialling the state - async def is_partial_state_room(events): + async def is_partial_state_room(room_id: str) -> bool: self.tracker.notify_un_partial_stated("room_id") return True @@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase): self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) - def test_cancellation(self): + def test_cancellation(self) -> None: self.mock_store.is_partial_state_room.return_value = make_awaitable(True) d1 = ensureDeferred(self.tracker.await_full_state("room_id")) -- cgit 1.5.1 From 373c485d8c7f39206bac60c6ef313b4a1978bbc0 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 9 Dec 2022 23:02:11 +0000 Subject: Handle half-created indices in receipts index background update (#14650) When Synapse is terminated while running the background update to create the `receipts_graph` or `receipts_linearized` indexes, the indexes may be successfully created (or marked as invalid on postgres) while the background update remains unfinished. When Synapse next starts up, the background update will fail because the index already exists, or exists but is invalid on postgres. Use the existing code to create indices in background updates, since it handles these edge cases. Signed-off-by: Sean Quah --- changelog.d/14650.bugfix | 2 ++ synapse/storage/background_updates.py | 55 +++++++++++++++++++++++++----- synapse/storage/databases/main/receipts.py | 51 +++++++-------------------- 3 files changed, 60 insertions(+), 48 deletions(-) create mode 100644 changelog.d/14650.bugfix (limited to 'synapse') diff --git a/changelog.d/14650.bugfix b/changelog.d/14650.bugfix new file mode 100644 index 0000000000..5e18641bf7 --- /dev/null +++ b/changelog.d/14650.bugfix @@ -0,0 +1,2 @@ +Fix a bug introduced in Synapse 1.72.0 where the background updates to add non-thread unique indexes on receipts would fail if they were previously interrupted. + diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 2056ecb2c3..a99aea8926 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -544,6 +544,48 @@ class BackgroundUpdater: The named index will be dropped upon completion of the new index. """ + async def updater(progress: JsonDict, batch_size: int) -> int: + await self.create_index_in_background( + index_name=index_name, + table=table, + columns=columns, + where_clause=where_clause, + unique=unique, + psql_only=psql_only, + replaces_index=replaces_index, + ) + await self._end_background_update(update_name) + return 1 + + self._background_update_handlers[update_name] = _BackgroundUpdateHandler( + updater, oneshot=True + ) + + async def create_index_in_background( + self, + index_name: str, + table: str, + columns: Iterable[str], + where_clause: Optional[str] = None, + unique: bool = False, + psql_only: bool = False, + replaces_index: Optional[str] = None, + ) -> None: + """Add an index in the background. + + Args: + update_name: update_name to register for + index_name: name of index to add + table: table to add index to + columns: columns/expressions to include in index + where_clause: A WHERE clause to specify a partial unique index. + unique: true to make a UNIQUE index + psql_only: true to only create this index on psql databases (useful + for virtual sqlite tables) + replaces_index: The name of an index that this index replaces. + The named index will be dropped upon completion of the new index. + """ + def create_index_psql(conn: Connection) -> None: conn.rollback() # postgres insists on autocommit for the index @@ -618,16 +660,11 @@ class BackgroundUpdater: else: runner = create_index_sqlite - async def updater(progress: JsonDict, batch_size: int) -> int: - if runner is not None: - logger.info("Adding index %s to %s", index_name, table) - await self.db_pool.runWithConnection(runner) - await self._end_background_update(update_name) - return 1 + if runner is None: + return - self._background_update_handlers[update_name] = _BackgroundUpdateHandler( - updater, oneshot=True - ) + logger.info("Adding index %s to %s", index_name, table) + await self.db_pool.runWithConnection(runner) async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index a580e4bdda..e06725f69c 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -924,39 +924,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): return batch_size - async def _create_receipts_index(self, index_name: str, table: str) -> None: - """Adds a unique index on `(room_id, receipt_type, user_id)` to the given - receipts table, for non-thread receipts.""" - - def _create_index(conn: LoggingDatabaseConnection) -> None: - conn.rollback() - - # we have to set autocommit, because postgres refuses to - # CREATE INDEX CONCURRENTLY without it. - if isinstance(self.database_engine, PostgresEngine): - conn.set_session(autocommit=True) - - try: - c = conn.cursor() - - # Now that the duplicates are gone, we can create the index. - concurrently = ( - "CONCURRENTLY" - if isinstance(self.database_engine, PostgresEngine) - else "" - ) - sql = f""" - CREATE UNIQUE INDEX {concurrently} {index_name} - ON {table}(room_id, receipt_type, user_id) - WHERE thread_id IS NULL - """ - c.execute(sql) - finally: - if isinstance(self.database_engine, PostgresEngine): - conn.set_session(autocommit=False) - - await self.db_pool.runWithConnection(_create_index) - async def _background_receipts_linearized_unique_index( self, progress: dict, batch_size: int ) -> int: @@ -999,9 +966,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): _remote_duplicate_receipts_txn, ) - await self._create_receipts_index( - "receipts_linearized_unique_index", - "receipts_linearized", + await self.db_pool.updates.create_index_in_background( + index_name="receipts_linearized_unique_index", + table="receipts_linearized", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, ) await self.db_pool.updates._end_background_update( @@ -1050,9 +1020,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): _remote_duplicate_receipts_txn, ) - await self._create_receipts_index( - "receipts_graph_unique_index", - "receipts_graph", + await self.db_pool.updates.create_index_in_background( + index_name="receipts_graph_unique_index", + table="receipts_graph", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, ) await self.db_pool.updates._end_background_update( -- cgit 1.5.1 From 2a3cd59dd06411a79fb7500970db1b98f0d87695 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 12 Dec 2022 13:21:17 +0100 Subject: Add optional ICU support for user search (#14464) Fixes #13655 This change uses ICU (International Components for Unicode) to improve boundary detection in user search. This change also adds a new dependency on libicu-dev and pkg-config for the Debian packages, which are available in all supported distros. --- changelog.d/14464.feature | 1 + debian/changelog | 7 +++ debian/control | 2 + docker/Dockerfile | 2 + docker/Dockerfile-dhvirtualenv | 2 + poetry.lock | 16 +++++- pyproject.toml | 7 +++ stubs/icu.pyi | 25 +++++++++ synapse/storage/databases/main/user_directory.py | 67 ++++++++++++++++++++++-- tests/storage/test_user_directory.py | 43 +++++++++++++++ 10 files changed, 166 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14464.feature create mode 100644 stubs/icu.pyi (limited to 'synapse') diff --git a/changelog.d/14464.feature b/changelog.d/14464.feature new file mode 100644 index 0000000000..688ea32117 --- /dev/null +++ b/changelog.d/14464.feature @@ -0,0 +1 @@ +Improve user search for international display names. diff --git a/debian/changelog b/debian/changelog index 163b7210bf..5d3c4f7d6b 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,10 @@ +matrix-synapse-py3 (1.74.0~rc1) UNRELEASED; urgency=medium + + * New dependency on libicu-dev to provide improved results for user + search. + + -- Synapse Packaging team Tue, 06 Dec 2022 15:28:10 +0000 + matrix-synapse-py3 (1.73.0) stable; urgency=medium * New Synapse release 1.73.0. diff --git a/debian/control b/debian/control index 86f5a66d02..bc628cec08 100644 --- a/debian/control +++ b/debian/control @@ -8,6 +8,8 @@ Build-Depends: dh-virtualenv (>= 1.1), libsystemd-dev, libpq-dev, + libicu-dev, + pkg-config, lsb-release, python3-dev, python3, diff --git a/docker/Dockerfile b/docker/Dockerfile index 185d5bc3d4..7e5123210a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -97,6 +97,8 @@ RUN \ zlib1g-dev \ git \ curl \ + libicu-dev \ + pkg-config \ && rm -rf /var/lib/apt/lists/* diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv index 73165f6f85..f3b5b00ce6 100644 --- a/docker/Dockerfile-dhvirtualenv +++ b/docker/Dockerfile-dhvirtualenv @@ -84,6 +84,8 @@ RUN apt-get update -qq -o Acquire::Languages=none \ python3-venv \ sqlite3 \ libpq-dev \ + libicu-dev \ + pkg-config \ xmlsec1 # Install rust and ensure it's in the PATH diff --git a/poetry.lock b/poetry.lock index cac22e2ef0..ccda8a23fb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -837,6 +837,14 @@ category = "dev" optional = false python-versions = ">=3.5" +[[package]] +name = "pyicu" +version = "2.10.2" +description = "Python extension wrapping the ICU C++ API" +category = "main" +optional = true +python-versions = "*" + [[package]] name = "pyjwt" version = "2.4.0" @@ -1622,7 +1630,7 @@ docs = ["Sphinx", "repoze.sphinx.autointerface"] test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"] [extras] -all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler"] +all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler", "pyicu"] cache-memory = ["Pympler"] jwt = ["authlib"] matrix-synapse-ldap3 = ["matrix-synapse-ldap3"] @@ -1635,11 +1643,12 @@ sentry = ["sentry-sdk"] systemd = ["systemd-python"] test = ["parameterized", "idna"] url-preview = ["lxml"] +user-search = ["pyicu"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "8c44ceeb9df5c3ab43040400e0a6b895de49417e61293a1ba027640b34f03263" +content-hash = "f20007013f33bc35a01e412c48adc62a936030f3074e06286674c5ad7f44d300" [metadata.files] attrs = [ @@ -2427,6 +2436,9 @@ pygments = [ {file = "Pygments-2.11.2-py3-none-any.whl", hash = "sha256:44238f1b60a76d78fc8ca0528ee429702aae011c265fe6a8dd8b63049ae41c65"}, {file = "Pygments-2.11.2.tar.gz", hash = "sha256:4e426f72023d88d03b2fa258de560726ce890ff3b630f88c21cbb8b2503b8c6a"}, ] +pyicu = [ + {file = "PyICU-2.10.2.tar.gz", hash = "sha256:0c3309eea7fab6857507ace62403515b60fe096cbfb4f90d14f55ff75c5441c1"}, +] pyjwt = [ {file = "PyJWT-2.4.0-py3-none-any.whl", hash = "sha256:72d1d253f32dbd4f5c88eaf1fdc62f3a19f676ccbadb9dbc5d07e951b2b26daf"}, {file = "PyJWT-2.4.0.tar.gz", hash = "sha256:d42908208c699b3b973cbeb01a969ba6a96c821eefb1c5bfe4c390c01d67abba"}, diff --git a/pyproject.toml b/pyproject.toml index df59fa0562..bb383683cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -208,6 +208,7 @@ hiredis = { version = "*", optional = true } Pympler = { version = "*", optional = true } parameterized = { version = ">=0.7.4", optional = true } idna = { version = ">=2.5", optional = true } +pyicu = { version = ">=2.10.2", optional = true } [tool.poetry.extras] # NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified @@ -230,6 +231,10 @@ redis = ["txredisapi", "hiredis"] # Required to use experimental `caches.track_memory_usage` config option. cache-memory = ["pympler"] test = ["parameterized", "idna"] +# Allows for better search for international characters in the user directory. This +# requires libicu's development headers installed on the system (e.g. libicu-dev on +# Debian-based distributions). +user-search = ["pyicu"] # The duplication here is awful. I hate hate hate hate hate it. However, for now I want # to ensure you can still `pip install matrix-synapse[all]` like today. Two motivations: @@ -261,6 +266,8 @@ all = [ "txredisapi", "hiredis", # cache-memory "pympler", + # improved user search + "pyicu", # omitted: # - test: it's useful to have this separate from dev deps in the olddeps job # - systemd: this is a system-based requirement diff --git a/stubs/icu.pyi b/stubs/icu.pyi new file mode 100644 index 0000000000..efeda7938a --- /dev/null +++ b/stubs/icu.pyi @@ -0,0 +1,25 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stub for PyICU. + +class Locale: + @staticmethod + def getDefault() -> Locale: ... + +class BreakIterator: + @staticmethod + def createWordInstance(locale: Locale) -> BreakIterator: ... + def setText(self, text: str) -> None: ... + def nextBoundary(self) -> int: ... diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index af9952f513..14ef5b040d 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,6 +26,14 @@ from typing import ( cast, ) +try: + # Figure out if ICU support is available for searching users. + import icu + + USE_ICU = True +except ModuleNotFoundError: + USE_ICU = False + from typing_extensions import TypedDict from synapse.api.errors import StoreError @@ -900,7 +908,7 @@ def _parse_query_sqlite(search_term: str) -> str: """ # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + results = _parse_words(search_term) return " & ".join("(%s* OR %s)" % (result, result) for result in results) @@ -910,12 +918,63 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: We use this so that we can add prefix matching, which isn't something that is supported by default. """ - - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + results = _parse_words(search_term) both = " & ".join("(%s:* | %s)" % (result, result) for result in results) exact = " & ".join("%s" % (result,) for result in results) prefix = " & ".join("%s:*" % (result,) for result in results) return both, exact, prefix + + +def _parse_words(search_term: str) -> List[str]: + """Split the provided search string into a list of its words. + + If support for ICU (International Components for Unicode) is available, use it. + Otherwise, fall back to using a regex to detect word boundaries. This latter + solution works well enough for most latin-based languages, but doesn't work as well + with other languages. + + Args: + search_term: The search string. + + Returns: + A list of the words in the search string. + """ + if USE_ICU: + return _parse_words_with_icu(search_term) + + return re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + +def _parse_words_with_icu(search_term: str) -> List[str]: + """Break down the provided search string into its individual words using ICU + (International Components for Unicode). + + Args: + search_term: The search string. + + Returns: + A list of the words in the search string. + """ + results = [] + breaker = icu.BreakIterator.createWordInstance(icu.Locale.getDefault()) + breaker.setText(search_term) + i = 0 + while True: + j = breaker.nextBoundary() + if j < 0: + break + + result = search_term[i:j] + + # libicu considers spaces and punctuation between words as words, but we don't + # want to include those in results as they would result in syntax errors in SQL + # queries (e.g. "foo bar" would result in the search query including "foo & & + # bar"). + if len(re.findall(r"([\w\-]+)", result, re.UNICODE)): + results.append(result) + + i = j + + return results diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 88c7d5fec0..3ba896ecf3 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from typing import Any, Dict, Set, Tuple from unittest import mock from unittest.mock import Mock, patch @@ -30,6 +31,12 @@ from synapse.util import Clock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config +try: + import icu +except ImportError: + icu = None # type: ignore + + ALICE = "@alice:a" BOB = "@bob:b" BOBBY = "@bobby:a" @@ -467,3 +474,39 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): r["results"][0], {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, ) + + +class UserDirectoryICUTestCase(HomeserverTestCase): + if not icu: + skip = "Requires PyICU" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.user_dir_helper = GetUserDirectoryTables(self.store) + + def test_icu_word_boundary(self) -> None: + """Tests that we correctly detect word boundaries when ICU (International + Components for Unicode) support is available. + """ + + display_name = "Gáo" + + # This word is not broken down correctly by Python's regular expressions, + # likely because á is actually a lowercase a followed by a U+0301 combining + # acute accent. This is specifically something that ICU support fixes. + matches = re.findall(r"([\w\-]+)", display_name, re.UNICODE) + self.assertEqual(len(matches), 2) + + self.get_success( + self.store.update_profile_in_user_dir(ALICE, display_name, None) + ) + self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE,))) + + # Check that searching for this user yields the correct result. + r = self.get_success(self.store.search_user_dir(BOB, display_name, 10)) + self.assertFalse(r["limited"]) + self.assertEqual(len(r["results"]), 1) + self.assertDictEqual( + r["results"][0], + {"user_id": ALICE, "display_name": display_name, "avatar_url": None}, + ) -- cgit 1.5.1 From 74b89c27613a34ec9b291ad3066db7ce0adff1db Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 12 Dec 2022 13:55:23 +0000 Subject: Revert the deletion of stale devices due to performance issues. (#14662) --- changelog.d/14595.misc | 1 - changelog.d/14649.misc | 1 - changelog.d/14662.removal | 1 + synapse/handlers/device.py | 33 +----------- synapse/storage/databases/main/devices.py | 84 +------------------------------ tests/handlers/test_device.py | 33 +----------- tests/storage/test_client_ips.py | 4 +- 7 files changed, 5 insertions(+), 152 deletions(-) delete mode 100644 changelog.d/14595.misc delete mode 100644 changelog.d/14649.misc create mode 100644 changelog.d/14662.removal (limited to 'synapse') diff --git a/changelog.d/14595.misc b/changelog.d/14595.misc deleted file mode 100644 index f9bfc581ad..0000000000 --- a/changelog.d/14595.misc +++ /dev/null @@ -1 +0,0 @@ -Prune user's old devices on login if they have too many. diff --git a/changelog.d/14649.misc b/changelog.d/14649.misc deleted file mode 100644 index f9bfc581ad..0000000000 --- a/changelog.d/14649.misc +++ /dev/null @@ -1 +0,0 @@ -Prune user's old devices on login if they have too many. diff --git a/changelog.d/14662.removal b/changelog.d/14662.removal new file mode 100644 index 0000000000..19a387bbb4 --- /dev/null +++ b/changelog.d/14662.removal @@ -0,0 +1 @@ +(remove from changelog: unreleased) Revert the deletion of stale devices due to performance issues. \ No newline at end of file diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c935c7be90..d4750a32e6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -52,7 +52,6 @@ from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable -from synapse.util.iterutils import batch_iter from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -422,9 +421,6 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) - # Prune the user's device list if they already have a lot of devices. - await self._prune_too_many_devices(user_id) - if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -456,33 +452,6 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") - async def _prune_too_many_devices(self, user_id: str) -> None: - """Delete any excess old devices this user may have.""" - device_ids = await self.store.check_too_many_devices_for_user(user_id, 100) - if not device_ids: - return - - logger.info("Pruning %d old devices for user %s", len(device_ids), user_id) - - # We don't want to block and try and delete tonnes of devices at once, - # so we cap the number of devices we delete synchronously. - first_batch, remaining_device_ids = device_ids[:10], device_ids[10:] - await self.delete_devices(user_id, first_batch) - - if not remaining_device_ids: - return - - # Now spawn a background loop that deletes the rest. - async def _prune_too_many_devices_loop() -> None: - for batch in batch_iter(remaining_device_ids, 10): - await self.delete_devices(user_id, batch) - - await self.clock.sleep(1) - - run_as_background_process( - "_prune_too_many_devices_loop", _prune_too_many_devices_loop - ) - async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. @@ -512,7 +481,7 @@ class DeviceHandler(DeviceWorkerHandler): device_ids = [d for d in device_ids if d != except_device_id] await self.delete_devices(user_id, device_ids) - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Delete several devices Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 95d4c0622d..a5bb4d404e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1569,77 +1569,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def check_too_many_devices_for_user( - self, user_id: str, limit: int - ) -> List[str]: - """Check if the user has a lot of devices, and if so return the set of - devices we can prune. - - This does *not* return hidden devices or devices with E2E keys. - - Returns at most `limit` number of devices, ordered by last seen. - """ - - num_devices = await self.db_pool.simple_select_one_onecol( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcol="COALESCE(COUNT(*), 0)", - desc="count_devices", - ) - - # We let users have up to ten devices without pruning. - if num_devices <= 10: - return [] - - # We prune everything older than N days. - max_last_seen = self._clock.time_msec() - 14 * 24 * 60 * 60 * 1000 - - if num_devices > 50: - # If the user has more than 50 devices, then we chose a last seen - # that ensures we keep at most 50 devices. - sql = """ - SELECT last_seen FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen IS NOT NULL - AND key_json IS NULL - ORDER BY last_seen DESC - LIMIT 1 - OFFSET 50 - """ - - rows = await self.db_pool.execute( - "check_too_many_devices_for_user_last_seen", None, sql, (user_id,) - ) - if rows: - max_last_seen = max(rows[0][0], max_last_seen) - - # Now fetch the devices to delete. - sql = """ - SELECT device_id FROM devices - LEFT JOIN e2e_device_keys_json USING (user_id, device_id) - WHERE - user_id = ? - AND NOT hidden - AND last_seen < ? - AND key_json IS NULL - ORDER BY last_seen - LIMIT ? - """ - - def check_too_many_devices_for_user_txn( - txn: LoggingTransaction, - ) -> List[str]: - txn.execute(sql, (user_id, max_last_seen, limit)) - return [device_id for device_id, in txn] - - return await self.db_pool.runInteraction( - "check_too_many_devices_for_user", - check_too_many_devices_for_user_txn, - ) - class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Because we have write access, this will be a StreamIdGenerator @@ -1698,7 +1627,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values={}, insertion_values={ "display_name": initial_device_display_name, - "last_seen": self._clock.time_msec(), "hidden": False, }, desc="store_device", @@ -1744,15 +1672,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - @cached(max_entries=0) - async def delete_device(self, user_id: str, device_id: str) -> None: - raise NotImplementedError() - - # Note: sometimes deleting rows out of `device_inbox` can take a long time, - # so we use a cache so that we deduplicate in flight requests to delete - # devices. - @cachedList(cached_method_name="delete_device", list_name="device_ids") - async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> dict: + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: @@ -1789,8 +1709,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) - return {} - async def update_device( self, user_id: str, device_id: str, new_display_name: Optional[str] = None ) -> None: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index e51cac9b33..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -20,8 +20,6 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler -from synapse.rest import admin -from synapse.rest.client import account, login from synapse.server import HomeServer from synapse.util import Clock @@ -32,12 +30,6 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - admin.register_servlets, - account.register_servlets, - ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) handler = hs.get_device_handler() @@ -123,7 +115,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, - "last_seen_ts": 1000000, + "last_seen_ts": None, }, device_map["xyz"], ) @@ -237,29 +229,6 @@ class DeviceTestCase(unittest.HomeserverTestCase): NotFoundError, ) - def test_login_delete_old_devices(self) -> None: - """Delete old devices if the user already has too many.""" - - user_id = self.register_user("user", "pass") - - # Create a bunch of devices - for _ in range(50): - self.login("user", "pass") - self.reactor.advance(1) - - # Advance the clock for ages (as we only delete old devices) - self.reactor.advance(60 * 60 * 24 * 300) - - # Log in again to start the pruning - self.login("user", "pass") - - # Give the background job time to do its thing - self.reactor.pump([1.0] * 100) - - # We should now only have the most recent device. - devices = self.get_success(self.handler.get_devices_by_user(user_id)) - self.assertEqual(len(devices), 1) - def _record_users(self) -> None: # check this works for both devices which have a recorded client_ip, # and those which don't. diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 81e4e596e4..7f7f4ef892 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -170,8 +170,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) ) - last_seen = self.clock.time_msec() - if after_persisting: # Trigger the storage loop self.reactor.advance(10) @@ -192,7 +190,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "device_id": device_id, "ip": None, "user_agent": None, - "last_seen": last_seen, + "last_seen": None, }, ], ) -- cgit 1.5.1 From b5b5f6608462a988b05502a3b70b6a57ca3846d2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 16:19:30 +0000 Subject: Move `StateFilter` to `synapse.types` (#14668) * Move `StateFilter` to `synapse.types` * Changelog --- changelog.d/14668.misc | 1 + synapse/events/builder.py | 2 +- synapse/events/snapshot.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/module_api/__init__.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/mailer.py | 2 +- synapse/rest/admin/rooms.py | 2 +- synapse/rest/client/room.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/controllers/persist_events.py | 2 +- synapse/storage/controllers/state.py | 2 +- synapse/storage/databases/main/state.py | 2 +- synapse/storage/databases/state/bg_updates.py | 2 +- synapse/storage/databases/state/store.py | 2 +- synapse/storage/state.py | 567 ---------------- synapse/types.py | 928 -------------------------- synapse/types/__init__.py | 928 ++++++++++++++++++++++++++ synapse/types/state.py | 567 ++++++++++++++++ synapse/visibility.py | 2 +- tests/storage/test_state.py | 2 +- 29 files changed, 1520 insertions(+), 1519 deletions(-) create mode 100644 changelog.d/14668.misc delete mode 100644 synapse/storage/state.py delete mode 100644 synapse/types.py create mode 100644 synapse/types/__init__.py create mode 100644 synapse/types/state.py (limited to 'synapse') diff --git a/changelog.d/14668.misc b/changelog.d/14668.misc new file mode 100644 index 0000000000..5269d8a97d --- /dev/null +++ b/changelog.d/14668.misc @@ -0,0 +1 @@ +Move `StateFilter` to `synapse.types`. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index d62906043f..94dd1298e1 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -28,8 +28,8 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.storage.state import StateFilter from synapse.types import EventID, JsonDict +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.stringutils import random_string diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 1c0e96bec7..6eaef8b57a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,7 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore - from synapse.storage.state import StateFilter + from synapse.types.state import StateFilter @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3398fcaf7d..b2784d7333 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -70,8 +70,8 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f7223b03c3..d2facdab60 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -75,7 +75,6 @@ from synapse.replication.http.federation import ( from synapse.state import StateResolutionStore from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, @@ -83,6 +82,7 @@ from synapse.types import ( UserID, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5cbe89f4fd..d6e90ef259 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,7 +59,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( MutableStateMap, PersistedEventPosition, @@ -70,6 +69,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index c572508a02..8c8ff18a1a 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -27,9 +27,9 @@ from synapse.handlers.room import ShutdownRoomResponse from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6307fa9c5d..c611efb760 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -46,8 +46,8 @@ from synapse.replication.http.register import ( ReplicationRegisterServlet, ) from synapse.spam_checker_api import RegistrationBehaviour -from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester +from synapse.types.state import StateFilter if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6dcfd86fdf..f81241c2b3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -62,7 +62,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( JsonDict, @@ -77,6 +76,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import stringutils from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6ad2b38b8f..0c39e852a1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -34,7 +34,6 @@ from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.logging import opentracing from synapse.module_api import NOT_SPAM -from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, Requester, @@ -45,6 +44,7 @@ from synapse.types import ( create_requester, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index bcab98c6d5..33115ce488 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -23,8 +23,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.storage.state import StateFilter from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client if TYPE_CHECKING: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dace9b606f..7d6a653747 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -49,7 +49,6 @@ from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary -from synapse.storage.state import StateFilter from synapse.types import ( DeviceListUpdates, JsonDict, @@ -61,6 +60,7 @@ from synapse.types import ( StreamToken, UserID, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 96a661177a..0092a03c59 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -111,7 +111,6 @@ from synapse.storage.background_updates import ( ) from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo -from synapse.storage.state import StateFilter from synapse.types import ( DomainSpecificString, JsonDict, @@ -124,6 +123,7 @@ from synapse.types import ( UserProfile, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable from synapse.util.caches.descriptors import CachedFunction, cached diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9ed35d8461..36e5b327ef 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -35,8 +35,8 @@ from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership -from synapse.storage.state import StateFilter from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator +from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index c2575ba3d9..93b255ced5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -37,8 +37,8 @@ from synapse.push.push_types import ( TemplateVars, ) from synapse.storage.databases.main.event_push_actions import EmailPushAction -from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 747e6fda83..e957aa28ca 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,9 +34,9 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, RoomID, UserID, create_requester +from synapse.types.state import StateFilter from synapse.util import json_decoder if TYPE_CHECKING: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 514eb6afc8..790614d721 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -55,9 +55,9 @@ from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types.state import StateFilter from synapse.util import json_decoder from synapse.util.cancellation import cancellable from synapse.util.stringutils import parse_and_validate_server_name, random_string diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 833ffec3de..ee5469d5a8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -44,8 +44,8 @@ from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import StateMap +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 33ffef521b..f1d2c71c91 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -58,13 +58,13 @@ from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, StateMap, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 2b31ce54bb..26d79c6e62 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -31,12 +31,12 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging.opentracing import tag_args, trace from synapse.storage.roommember import ProfileInfo -from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, PartialStateEventsTracker, ) from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util.cancellation import cancellable if TYPE_CHECKING: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index af7bebee80..c801a93b5b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -33,8 +33,8 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 4a4ad0f492..d743282f13 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -22,8 +22,8 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.storage.state import StateFilter from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string if TYPE_CHECKING: diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f8cfcaca83..1a7232b276 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -25,10 +25,10 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap +from synapse.types.state import StateFilter from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.cancellation import cancellable diff --git a/synapse/storage/state.py b/synapse/storage/state.py deleted file mode 100644 index 0004d955b4..0000000000 --- a/synapse/storage/state.py +++ /dev/null @@ -1,567 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2022 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import ( - TYPE_CHECKING, - Callable, - Collection, - Dict, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - TypeVar, -) - -import attr -from frozendict import frozendict - -from synapse.api.constants import EventTypes -from synapse.types import MutableStateMap, StateKey, StateMap - -if TYPE_CHECKING: - from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad - - -logger = logging.getLogger(__name__) - -# Used for generic functions below -T = TypeVar("T") - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class StateFilter: - """A filter used when querying for state. - - Attributes: - types: Map from type to set of state keys (or None). This specifies - which state_keys for the given type to fetch from the DB. If None - then all events with that type are fetched. If the set is empty - then no events with that type are fetched. - include_others: Whether to fetch events with types that do not - appear in `types`. - """ - - types: "frozendict[str, Optional[FrozenSet[str]]]" - include_others: bool = False - - def __attrs_post_init__(self) -> None: - # If `include_others` is set we canonicalise the filter by removing - # wildcards from the types dictionary - if self.include_others: - # this is needed to work around the fact that StateFilter is frozen - object.__setattr__( - self, - "types", - frozendict({k: v for k, v in self.types.items() if v is not None}), - ) - - @staticmethod - def all() -> "StateFilter": - """Returns a filter that fetches everything. - - Returns: - The state filter. - """ - return _ALL_STATE_FILTER - - @staticmethod - def none() -> "StateFilter": - """Returns a filter that fetches nothing. - - Returns: - The new state filter. - """ - return _NONE_STATE_FILTER - - @staticmethod - def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": - """Creates a filter that only fetches the given types - - Args: - types: A list of type and state keys to fetch. A state_key of None - fetches everything for that type - - Returns: - The new state filter. - """ - type_dict: Dict[str, Optional[Set[str]]] = {} - for typ, s in types: - if typ in type_dict: - if type_dict[typ] is None: - continue - - if s is None: - type_dict[typ] = None - continue - - type_dict.setdefault(typ, set()).add(s) # type: ignore - - return StateFilter( - types=frozendict( - (k, frozenset(v) if v is not None else None) - for k, v in type_dict.items() - ) - ) - - @staticmethod - def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": - """Creates a filter that returns all non-member events, plus the member - events for the given users - - Args: - members: Set of user IDs - - Returns: - The new state filter - """ - return StateFilter( - types=frozendict({EventTypes.Member: frozenset(members)}), - include_others=True, - ) - - @staticmethod - def freeze( - types: Mapping[str, Optional[Collection[str]]], include_others: bool - ) -> "StateFilter": - """ - Returns a (frozen) StateFilter with the same contents as the parameters - specified here, which can be made of mutable types. - """ - types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} - for state_types, state_keys in types.items(): - if state_keys is not None: - types_with_frozen_values[state_types] = frozenset(state_keys) - else: - types_with_frozen_values[state_types] = None - - return StateFilter( - frozendict(types_with_frozen_values), include_others=include_others - ) - - def return_expanded(self) -> "StateFilter": - """Creates a new StateFilter where type wild cards have been removed - (except for memberships). The returned filter is a superset of the - current one, i.e. anything that passes the current filter will pass - the returned filter. - - This helps the caching as the DictionaryCache knows if it has *all* the - state, but does not know if it has all of the keys of a particular type, - which makes wildcard lookups expensive unless we have a complete cache. - Hence, if we are doing a wildcard lookup, populate the cache fully so - that we can do an efficient lookup next time. - - Note that since we have two caches, one for membership events and one for - other events, we can be a bit more clever than simply returning - `StateFilter.all()` if `has_wildcards()` is True. - - We return a StateFilter where: - 1. the list of membership events to return is the same - 2. if there is a wildcard that matches non-member events we - return all non-member events - - Returns: - The new state filter. - """ - - if self.is_full(): - # If we're going to return everything then there's nothing to do - return self - - if not self.has_wildcards(): - # If there are no wild cards, there's nothing to do - return self - - if EventTypes.Member in self.types: - get_all_members = self.types[EventTypes.Member] is None - else: - get_all_members = self.include_others - - has_non_member_wildcard = self.include_others or any( - state_keys is None - for t, state_keys in self.types.items() - if t != EventTypes.Member - ) - - if not has_non_member_wildcard: - # If there are no non-member wild cards we can just return ourselves - return self - - if get_all_members: - # We want to return everything. - return StateFilter.all() - elif EventTypes.Member in self.types: - # We want to return all non-members, but only particular - # memberships - return StateFilter( - types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), - include_others=True, - ) - else: - # We want to return all non-members - return _ALL_NON_MEMBER_STATE_FILTER - - def make_sql_filter_clause(self) -> Tuple[str, List[str]]: - """Converts the filter to an SQL clause. - - For example: - - f = StateFilter.from_types([("m.room.create", "")]) - clause, args = f.make_sql_filter_clause() - clause == "(type = ? AND state_key = ?)" - args == ['m.room.create', ''] - - - Returns: - The SQL string (may be empty) and arguments. An empty SQL string is - returned when the filter matches everything (i.e. is "full"). - """ - - where_clause = "" - where_args: List[str] = [] - - if self.is_full(): - return where_clause, where_args - - if not self.include_others and not self.types: - # i.e. this is an empty filter, so we need to return a clause that - # will match nothing - return "1 = 2", [] - - # First we build up a lost of clauses for each type/state_key combo - clauses = [] - for etype, state_keys in self.types.items(): - if state_keys is None: - clauses.append("(type = ?)") - where_args.append(etype) - continue - - for state_key in state_keys: - clauses.append("(type = ? AND state_key = ?)") - where_args.extend((etype, state_key)) - - # This will match anything that appears in `self.types` - where_clause = " OR ".join(clauses) - - # If we want to include stuff that's not in the types dict then we add - # a `OR type NOT IN (...)` clause to the end. - if self.include_others: - if where_clause: - where_clause += " OR " - - where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) - where_args.extend(self.types) - - return where_clause, where_args - - def max_entries_returned(self) -> Optional[int]: - """Returns the maximum number of entries this filter will return if - known, otherwise returns None. - - For example a simple state filter asking for `("m.room.create", "")` - will return 1, whereas the default state filter will return None. - - This is used to bail out early if the right number of entries have been - fetched. - """ - if self.has_wildcards(): - return None - - return len(self.concrete_types()) - - def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: - """Returns the state filtered with by this StateFilter. - - Args: - state: The state map to filter - - Returns: - The filtered state map. - This is a copy, so it's safe to mutate. - """ - if self.is_full(): - return dict(state_dict) - - filtered_state = {} - for k, v in state_dict.items(): - typ, state_key = k - if typ in self.types: - state_keys = self.types[typ] - if state_keys is None or state_key in state_keys: - filtered_state[k] = v - elif self.include_others: - filtered_state[k] = v - - return filtered_state - - def is_full(self) -> bool: - """Whether this filter fetches everything or not - - Returns: - True if the filter fetches everything. - """ - return self.include_others and not self.types - - def has_wildcards(self) -> bool: - """Whether the filter includes wildcards or is attempting to fetch - specific state. - - Returns: - True if the filter includes wildcards. - """ - - return self.include_others or any( - state_keys is None for state_keys in self.types.values() - ) - - def concrete_types(self) -> List[Tuple[str, str]]: - """Returns a list of concrete type/state_keys (i.e. not None) that - will be fetched. This will be a complete list if `has_wildcards` - returns False, but otherwise will be a subset (or even empty). - - Returns: - A list of type/state_keys tuples. - """ - return [ - (t, s) - for t, state_keys in self.types.items() - if state_keys is not None - for s in state_keys - ] - - def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: - """Return the filter split into two: one which assumes it's exclusively - matching against member state, and one which assumes it's matching - against non member state. - - This is useful due to the returned filters giving correct results for - `is_full()`, `has_wildcards()`, etc, when operating against maps that - either exclusively contain member events or only contain non-member - events. (Which is the case when dealing with the member vs non-member - state caches). - - Returns: - The member and non member filters - """ - - if EventTypes.Member in self.types: - state_keys = self.types[EventTypes.Member] - if state_keys is None: - member_filter = StateFilter.all() - else: - member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) - elif self.include_others: - member_filter = StateFilter.all() - else: - member_filter = StateFilter.none() - - non_member_filter = StateFilter( - types=frozendict( - {k: v for k, v in self.types.items() if k != EventTypes.Member} - ), - include_others=self.include_others, - ) - - return member_filter, non_member_filter - - def _decompose_into_four_parts( - self, - ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: - """ - Decomposes this state filter into 4 constituent parts, which can be - thought of as this: - all? - minus_wildcards + plus_wildcards + plus_state_keys - - where - * all represents ALL state - * minus_wildcards represents entire state types to remove - * plus_wildcards represents entire state types to add - * plus_state_keys represents individual state keys to add - - See `recompose_from_four_parts` for the other direction of this - correspondence. - """ - is_all = self.include_others - excluded_types: Set[str] = {t for t in self.types if is_all} - wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} - concrete_keys: Set[StateKey] = set(self.concrete_types()) - - return (is_all, excluded_types), (wildcard_types, concrete_keys) - - @staticmethod - def _recompose_from_four_parts( - all_part: bool, - minus_wildcards: Set[str], - plus_wildcards: Set[str], - plus_state_keys: Set[StateKey], - ) -> "StateFilter": - """ - Recomposes a state filter from 4 parts. - - See `decompose_into_four_parts` (the other direction of this - correspondence) for descriptions on each of the parts. - """ - - # {state type -> set of state keys OR None for wildcard} - # (The same structure as that of a StateFilter.) - new_types: Dict[str, Optional[Set[str]]] = {} - - # if we start with all, insert the excluded statetypes as empty sets - # to prevent them from being included - if all_part: - new_types.update({state_type: set() for state_type in minus_wildcards}) - - # insert the plus wildcards - new_types.update({state_type: None for state_type in plus_wildcards}) - - # insert the specific state keys - for state_type, state_key in plus_state_keys: - if state_type in new_types: - entry = new_types[state_type] - if entry is not None: - entry.add(state_key) - elif not all_part: - # don't insert if the entire type is already included by - # include_others as this would actually shrink the state allowed - # by this filter. - new_types[state_type] = {state_key} - - return StateFilter.freeze(new_types, include_others=all_part) - - def approx_difference(self, other: "StateFilter") -> "StateFilter": - """ - Returns a state filter which represents `self - other`. - - This is useful for determining what state remains to be pulled out of the - database if we want the state included by `self` but already have the state - included by `other`. - - The returned state filter - - MUST include all state events that are included by this filter (`self`) - unless they are included by `other`; - - MUST NOT include state events not included by this filter (`self`); and - - MAY be an over-approximation: the returned state filter - MAY additionally include some state events from `other`. - - This implementation attempts to return the narrowest such state filter. - In the case that `self` contains wildcards for state types where - `other` contains specific state keys, an approximation must be made: - the returned state filter keeps the wildcard, as state filters are not - able to express 'all state keys except some given examples'. - e.g. - StateFilter(m.room.member -> None (wildcard)) - minus - StateFilter(m.room.member -> {'@wombat:example.org'}) - is approximated as - StateFilter(m.room.member -> None (wildcard)) - """ - - # We first transform self and other into an alternative representation: - # - whether or not they include all events to begin with ('all') - # - if so, which event types are excluded? ('excludes') - # - which entire event types to include ('wildcards') - # - which concrete state keys to include ('concrete state keys') - (self_all, self_excludes), ( - self_wildcards, - self_concrete_keys, - ) = self._decompose_into_four_parts() - (other_all, other_excludes), ( - other_wildcards, - other_concrete_keys, - ) = other._decompose_into_four_parts() - - # Start with an estimate of the difference based on self - new_all = self_all - # Wildcards from the other can be added to the exclusion filter - new_excludes = self_excludes | other_wildcards - # We remove wildcards that appeared as wildcards in the other - new_wildcards = self_wildcards - other_wildcards - # We filter out the concrete state keys that appear in the other - # as wildcards or concrete state keys. - new_concrete_keys = { - (state_type, state_key) - for (state_type, state_key) in self_concrete_keys - if state_type not in other_wildcards - } - other_concrete_keys - - if other_all: - if self_all: - # If self starts with all, then we add as wildcards any - # types which appear in the other's exclusion filter (but - # aren't in the self exclusion filter). This is as the other - # filter will return everything BUT the types in its exclusion, so - # we need to add those excluded types that also match the self - # filter as wildcard types in the new filter. - new_wildcards |= other_excludes.difference(self_excludes) - - # If other is an `include_others` then the difference isn't. - new_all = False - # (We have no need for excludes when we don't start with all, as there - # is nothing to exclude.) - new_excludes = set() - - # We also filter out all state types that aren't in the exclusion - # list of the other. - new_wildcards &= other_excludes - new_concrete_keys = { - (state_type, state_key) - for (state_type, state_key) in new_concrete_keys - if state_type in other_excludes - } - - # Transform our newly-constructed state filter from the alternative - # representation back into the normal StateFilter representation. - return StateFilter._recompose_from_four_parts( - new_all, new_excludes, new_wildcards, new_concrete_keys - ) - - def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: - """Check if we need to wait for full state to complete to calculate this state - - If we have a state filter which is completely satisfied even with partial - state, then we don't need to await_full_state before we can return it. - - Args: - is_mine_id: a callable which confirms if a given state_key matches a mxid - of a local user - """ - # if we haven't requested membership events, then it depends on the value of - # 'include_others' - if EventTypes.Member not in self.types: - return self.include_others - - # if we're looking for *all* membership events, then we have to wait - member_state_keys = self.types[EventTypes.Member] - if member_state_keys is None: - return True - - # otherwise, consider whose membership we are looking for. If it's entirely - # local users, then we don't need to wait. - for state_key in member_state_keys: - if not is_mine_id(state_key): - # remote user - return True - - # local users only - return False - - -_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) -_ALL_NON_MEMBER_STATE_FILTER = StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True -) -_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/types.py b/synapse/types.py deleted file mode 100644 index f2d436ddc3..0000000000 --- a/synapse/types.py +++ /dev/null @@ -1,928 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import abc -import re -import string -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - List, - Mapping, - Match, - MutableMapping, - NoReturn, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) - -import attr -from frozendict import frozendict -from signedjson.key import decode_verify_key_bytes -from signedjson.types import VerifyKey -from typing_extensions import Final, TypedDict -from unpaddedbase64 import decode_base64 -from zope.interface import Interface - -from twisted.internet.defer import CancelledError -from twisted.internet.interfaces import ( - IReactorCore, - IReactorPluggableNameResolver, - IReactorSSL, - IReactorTCP, - IReactorThreads, - IReactorTime, -) - -from synapse.api.errors import Codes, SynapseError -from synapse.util.cancellation import cancellable -from synapse.util.stringutils import parse_and_validate_server_name - -if TYPE_CHECKING: - from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore, PurgeEventsStore - from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore - -# Define a state map type from type/state_key to T (usually an event ID or -# event) -T = TypeVar("T") -StateKey = Tuple[str, str] -StateMap = Mapping[StateKey, T] -MutableStateMap = MutableMapping[StateKey, T] - -# JSON types. These could be made stronger, but will do for now. -# A JSON-serialisable dict. -JsonDict = Dict[str, Any] -# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. -# Useful when you have a TypedDict which isn't going to be mutated and you don't want -# to cast to JsonDict everywhere. -JsonMapping = Mapping[str, Any] -# A JSON-serialisable object. -JsonSerializable = object - - -# Note that this seems to require inheriting *directly* from Interface in order -# for mypy-zope to realize it is an interface. -class ISynapseReactor( - IReactorTCP, - IReactorSSL, - IReactorPluggableNameResolver, - IReactorTime, - IReactorCore, - IReactorThreads, - Interface, -): - """The interfaces necessary for Synapse to function.""" - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class Requester: - """ - Represents the user making a request - - Attributes: - user: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest: True if the user making this request is a guest user - shadow_banned: True if the user making this request has been shadow-banned. - device_id: device_id which was set at authentication time - app_service: the AS requesting on behalf of the user - authenticated_entity: The entity that authenticated when making the request. - This is different to the user_id when an admin user or the server is - "puppeting" the user. - """ - - user: "UserID" - access_token_id: Optional[int] - is_guest: bool - shadow_banned: bool - device_id: Optional[str] - app_service: Optional["ApplicationService"] - authenticated_entity: str - - def serialize(self) -> Dict[str, Any]: - """Converts self to a type that can be serialized as JSON, and then - deserialized by `deserialize` - - Returns: - dict - """ - return { - "user_id": self.user.to_string(), - "access_token_id": self.access_token_id, - "is_guest": self.is_guest, - "shadow_banned": self.shadow_banned, - "device_id": self.device_id, - "app_server_id": self.app_service.id if self.app_service else None, - "authenticated_entity": self.authenticated_entity, - } - - @staticmethod - def deserialize( - store: "ApplicationServiceWorkerStore", input: Dict[str, Any] - ) -> "Requester": - """Converts a dict that was produced by `serialize` back into a - Requester. - - Args: - store: Used to convert AS ID to AS object - input: A dict produced by `serialize` - - Returns: - Requester - """ - appservice = None - if input["app_server_id"]: - appservice = store.get_app_service_by_id(input["app_server_id"]) - - return Requester( - user=UserID.from_string(input["user_id"]), - access_token_id=input["access_token_id"], - is_guest=input["is_guest"], - shadow_banned=input["shadow_banned"], - device_id=input["device_id"], - app_service=appservice, - authenticated_entity=input["authenticated_entity"], - ) - - -def create_requester( - user_id: Union[str, "UserID"], - access_token_id: Optional[int] = None, - is_guest: bool = False, - shadow_banned: bool = False, - device_id: Optional[str] = None, - app_service: Optional["ApplicationService"] = None, - authenticated_entity: Optional[str] = None, -) -> Requester: - """ - Create a new ``Requester`` object - - Args: - user_id: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest: True if the user making this request is a guest user - shadow_banned: True if the user making this request is shadow-banned. - device_id: device_id which was set at authentication time - app_service: the AS requesting on behalf of the user - authenticated_entity: The entity that authenticated when making the request. - This is different to the user_id when an admin user or the server is - "puppeting" the user. - - Returns: - Requester - """ - if not isinstance(user_id, UserID): - user_id = UserID.from_string(user_id) - - if authenticated_entity is None: - authenticated_entity = user_id.to_string() - - return Requester( - user_id, - access_token_id, - is_guest, - shadow_banned, - device_id, - app_service, - authenticated_entity, - ) - - -def get_domain_from_id(string: str) -> str: - idx = string.find(":") - if idx == -1: - raise SynapseError(400, "Invalid ID: %r" % (string,)) - return string[idx + 1 :] - - -def get_localpart_from_id(string: str) -> str: - idx = string.find(":") - if idx == -1: - raise SynapseError(400, "Invalid ID: %r" % (string,)) - return string[1:idx] - - -DS = TypeVar("DS", bound="DomainSpecificString") - - -@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) -class DomainSpecificString(metaclass=abc.ABCMeta): - """Common base class among ID/name strings that have a local part and a - domain name, prefixed with a sigil. - - Has the fields: - - 'localpart' : The local part of the name (without the leading sigil) - 'domain' : The domain part of the name - """ - - SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore - - localpart: str - domain: str - - # Because this is a frozen class, it is deeply immutable. - def __copy__(self: DS) -> DS: - return self - - def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: - return self - - @classmethod - def from_string(cls: Type[DS], s: str) -> DS: - """Parse the string given by 's' into a structure object.""" - if len(s) < 1 or s[0:1] != cls.SIGIL: - raise SynapseError( - 400, - "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), - Codes.INVALID_PARAM, - ) - - parts = s[1:].split(":", 1) - if len(parts) != 2: - raise SynapseError( - 400, - "Expected %s of the form '%slocalname:domain'" - % (cls.__name__, cls.SIGIL), - Codes.INVALID_PARAM, - ) - - domain = parts[1] - # This code will need changing if we want to support multiple domain - # names on one HS - return cls(localpart=parts[0], domain=domain) - - def to_string(self) -> str: - """Return a string encoding the fields of the structure object.""" - return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) - - @classmethod - def is_valid(cls: Type[DS], s: str) -> bool: - """Parses the input string and attempts to ensure it is valid.""" - # TODO: this does not reject an empty localpart or an overly-long string. - # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar - try: - obj = cls.from_string(s) - # Apply additional validation to the domain. This is only done - # during is_valid (and not part of from_string) since it is - # possible for invalid data to exist in room-state, etc. - parse_and_validate_server_name(obj.domain) - return True - except Exception: - return False - - __repr__ = to_string - - -@attr.s(slots=True, frozen=True, repr=False) -class UserID(DomainSpecificString): - """Structure representing a user ID.""" - - SIGIL = "@" - - -@attr.s(slots=True, frozen=True, repr=False) -class RoomAlias(DomainSpecificString): - """Structure representing a room name.""" - - SIGIL = "#" - - -@attr.s(slots=True, frozen=True, repr=False) -class RoomID(DomainSpecificString): - """Structure representing a room id.""" - - SIGIL = "!" - - -@attr.s(slots=True, frozen=True, repr=False) -class EventID(DomainSpecificString): - """Structure representing an event id.""" - - SIGIL = "$" - - -mxid_localpart_allowed_characters = set( - "_-./=" + string.ascii_lowercase + string.digits -) - - -def contains_invalid_mxid_characters(localpart: str) -> bool: - """Check for characters not allowed in an mxid or groupid localpart - - Args: - localpart: the localpart to be checked - - Returns: - True if there are any naughty characters - """ - return any(c not in mxid_localpart_allowed_characters for c in localpart) - - -UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") - -# the following is a pattern which matches '=', and bytes which are not allowed in a mxid -# localpart. -# -# It works by: -# * building a string containing the allowed characters (excluding '=') -# * escaping every special character with a backslash (to stop '-' being interpreted as a -# range operator) -# * wrapping it in a '[^...]' regex -# * converting the whole lot to a 'bytes' sequence, so that we can use it to match -# bytes rather than strings -# -NON_MXID_CHARACTER_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( - "ascii" - ) -) - - -def map_username_to_mxid_localpart( - username: Union[str, bytes], case_sensitive: bool = False -) -> str: - """Map a username onto a string suitable for a MXID - - This follows the algorithm laid out at - https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. - - Args: - username: username to be mapped - case_sensitive: true if TEST and test should be mapped - onto different mxids - - Returns: - string suitable for a mxid localpart - """ - if not isinstance(username, bytes): - username = username.encode("utf-8") - - # first we sort out upper-case characters - if case_sensitive: - - def f1(m: Match[bytes]) -> bytes: - return b"_" + m.group().lower() - - username = UPPER_CASE_PATTERN.sub(f1, username) - else: - username = username.lower() - - # then we sort out non-ascii characters by converting to the hex equivalent. - def f2(m: Match[bytes]) -> bytes: - return b"=%02x" % (m.group()[0],) - - username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) - - # we also do the =-escaping to mxids starting with an underscore. - username = re.sub(b"^_", b"=5f", username) - - # we should now only have ascii bytes left, so can decode back to a string. - return username.decode("ascii") - - -@attr.s(frozen=True, slots=True, order=False) -class RoomStreamToken: - """Tokens are positions between events. The token "s1" comes after event 1. - - s0 s1 - | | - [0] ▼ [1] ▼ [2] - - Tokens can either be a point in the live event stream or a cursor going - through historic events. - - When traversing the live event stream, events are ordered by - `stream_ordering` (when they arrived at the homeserver). - - When traversing historic events, events are first ordered by their `depth` - (`topological_ordering` in the event graph) and tie-broken by - `stream_ordering` (when the event arrived at the homeserver). - - If you're looking for more info about what a token with all of the - underscores means, ex. - `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring - for `StreamToken` below. - - --- - - Live tokens start with an "s" followed by the `stream_ordering` of the event - that comes before the position of the token. Said another way: - `stream_ordering` uniquely identifies a persisted event. The live token - means "the position just after the event identified by `stream_ordering`". - An example token is: - - s2633508 - - --- - - Historic tokens start with a "t" followed by the `depth` - (`topological_ordering` in the event graph) of the event that comes before - the position of the token, followed by "-", followed by the - `stream_ordering` of the event that comes before the position of the token. - An example token is: - - t426-2633508 - - --- - - There is also a third mode for live tokens where the token starts with "m", - which is sometimes used when using sharded event persisters. In this case - the events stream is considered to be a set of streams (one for each writer) - and the token encodes the vector clock of positions of each writer in their - respective streams. - - The format of the token in such case is an initial integer min position, - followed by the mapping of instance ID to position separated by '.' and '~': - - m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... - - The `min_pos` corresponds to the minimum position all writers have persisted - up to, and then only writers that are ahead of that position need to be - encoded. An example token is: - - m56~2.58~3.59 - - Which corresponds to a set of three (or more writers) where instances 2 and - 3 (these are instance IDs that can be looked up in the DB to fetch the more - commonly used instance names) are at positions 58 and 59 respectively, and - all other instances are at position 56. - - Note: The `RoomStreamToken` cannot have both a topological part and an - instance map. - - --- - - For caching purposes, `RoomStreamToken`s and by extension, all their - attributes, must be hashable. - """ - - topological: Optional[int] = attr.ib( - validator=attr.validators.optional(attr.validators.instance_of(int)), - ) - stream: int = attr.ib(validator=attr.validators.instance_of(int)) - - instance_map: "frozendict[str, int]" = attr.ib( - factory=frozendict, - validator=attr.validators.deep_mapping( - key_validator=attr.validators.instance_of(str), - value_validator=attr.validators.instance_of(int), - mapping_validator=attr.validators.instance_of(frozendict), - ), - ) - - def __attrs_post_init__(self) -> None: - """Validates that both `topological` and `instance_map` aren't set.""" - - if self.instance_map and self.topological: - raise ValueError( - "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." - ) - - @classmethod - async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - if string[0] == "t": - parts = string[1:].split("-", 1) - return cls(topological=int(parts[0]), stream=int(parts[1])) - if string[0] == "m": - parts = string[1:].split("~") - stream = int(parts[0]) - - instance_map = {} - for part in parts[1:]: - key, value = part.split(".") - instance_id = int(key) - pos = int(value) - - instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] - instance_map[instance_name] = pos - - return cls( - topological=None, - stream=stream, - instance_map=frozendict(instance_map), - ) - except CancelledError: - raise - except Exception: - pass - raise SynapseError(400, "Invalid room stream token %r" % (string,)) - - @classmethod - def parse_stream_token(cls, string: str) -> "RoomStreamToken": - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - except Exception: - pass - raise SynapseError(400, "Invalid room stream token %r" % (string,)) - - def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": - """Return a new token such that if an event is after both this token and - the other token, then its after the returned token too. - """ - - if self.topological or other.topological: - raise Exception("Can't advance topological tokens") - - max_stream = max(self.stream, other.stream) - - instance_map = { - instance: max( - self.instance_map.get(instance, self.stream), - other.instance_map.get(instance, other.stream), - ) - for instance in set(self.instance_map).union(other.instance_map) - } - - return RoomStreamToken(None, max_stream, frozendict(instance_map)) - - def as_historical_tuple(self) -> Tuple[int, int]: - """Returns a tuple of `(topological, stream)` for historical tokens. - - Raises if not an historical token (i.e. doesn't have a topological part). - """ - if self.topological is None: - raise Exception( - "Cannot call `RoomStreamToken.as_historical_tuple` on live token" - ) - - return self.topological, self.stream - - def get_stream_pos_for_instance(self, instance_name: str) -> int: - """Get the stream position that the given writer was at at this token. - - This only makes sense for "live" tokens that may have a vector clock - component, and so asserts that this is a "live" token. - """ - assert self.topological is None - - # If we don't have an entry for the instance we can assume that it was - # at `self.stream`. - return self.instance_map.get(instance_name, self.stream) - - def get_max_stream_pos(self) -> int: - """Get the maximum stream position referenced in this token. - - The corresponding "min" position is, by definition just `self.stream`. - - This is used to handle tokens that have non-empty `instance_map`, and so - reference stream positions after the `self.stream` position. - """ - return max(self.instance_map.values(), default=self.stream) - - async def to_string(self, store: "DataStore") -> str: - if self.topological is not None: - return "t%d-%d" % (self.topological, self.stream) - elif self.instance_map: - entries = [] - for name, pos in self.instance_map.items(): - instance_id = await store.get_id_for_instance(name) - entries.append(f"{instance_id}.{pos}") - - encoded_map = "~".join(entries) - return f"m{self.stream}~{encoded_map}" - else: - return "s%d" % (self.stream,) - - -class StreamKeyType: - """Known stream types. - - A stream is a list of entities ordered by an incrementing "stream token". - """ - - ROOM: Final = "room_key" - PRESENCE: Final = "presence_key" - TYPING: Final = "typing_key" - RECEIPT: Final = "receipt_key" - ACCOUNT_DATA: Final = "account_data_key" - PUSH_RULES: Final = "push_rules_key" - TO_DEVICE: Final = "to_device_key" - DEVICE_LIST: Final = "device_list_key" - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class StreamToken: - """A collection of keys joined together by underscores in the following - order and which represent the position in their respective streams. - - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` - 1. `room_key`: `s2633508` which is a `RoomStreamToken` - - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - - See the docstring for `RoomStreamToken` for more details. - 2. `presence_key`: `17` - 3. `typing_key`: `338` - 4. `receipt_key`: `6732159` - 5. `account_data_key`: `1082514` - 6. `push_rules_key`: `541479` - 7. `to_device_key`: `274711` - 8. `device_list_key`: `265584` - 9. `groups_key`: `1` (note that this key is now unused) - - You can see how many of these keys correspond to the various - fields in a "/sync" response: - ```json - { - "next_batch": "s12_4_0_1_1_1_1_4_1", - "presence": { - "events": [] - }, - "device_lists": { - "changed": [] - }, - "rooms": { - "join": { - "!QrZlfIDQLNLdZHqTnt:hs1": { - "timeline": { - "events": [], - "prev_batch": "s10_4_0_1_1_1_1_4_1", - "limited": false - }, - "state": { - "events": [] - }, - "account_data": { - "events": [] - }, - "ephemeral": { - "events": [] - } - } - } - } - } - ``` - - --- - - For caching purposes, `StreamToken`s and by extension, all their attributes, - must be hashable. - """ - - room_key: RoomStreamToken = attr.ib( - validator=attr.validators.instance_of(RoomStreamToken) - ) - presence_key: int - typing_key: int - receipt_key: int - account_data_key: int - push_rules_key: int - to_device_key: int - device_list_key: int - # Note that the groups key is no longer used and may have bogus values. - groups_key: int - - _SEPARATOR = "_" - START: ClassVar["StreamToken"] - - @classmethod - @cancellable - async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": - """ - Creates a RoomStreamToken from its textual representation. - """ - try: - keys = string.split(cls._SEPARATOR) - while len(keys) < len(attr.fields(cls)): - # i.e. old token from before receipt_key - keys.append("0") - return cls( - await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) - ) - except CancelledError: - raise - except Exception: - raise SynapseError(400, "Invalid stream token") - - async def to_string(self, store: "DataStore") -> str: - return self._SEPARATOR.join( - [ - await self.room_key.to_string(store), - str(self.presence_key), - str(self.typing_key), - str(self.receipt_key), - str(self.account_data_key), - str(self.push_rules_key), - str(self.to_device_key), - str(self.device_list_key), - # Note that the groups key is no longer used, but it is still - # serialized so that there will not be confusion in the future - # if additional tokens are added. - str(self.groups_key), - ] - ) - - @property - def room_stream_id(self) -> int: - return self.room_key.stream - - def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": - """Advance the given key in the token to a new value if and only if the - new value is after the old value. - - :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. - """ - if key == StreamKeyType.ROOM: - new_token = self.copy_and_replace( - StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) - ) - return new_token - - new_token = self.copy_and_replace(key, new_value) - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) - - if old_id < new_id: - return new_token - else: - return self - - def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": - return attr.evolve(self, **{key: new_value}) - - -StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class PersistedEventPosition: - """Position of a newly persisted event with instance that persisted it. - - This can be used to test whether the event is persisted before or after a - RoomStreamToken. - """ - - instance_name: str - stream: int - - def persisted_after(self, token: RoomStreamToken) -> bool: - return token.get_stream_pos_for_instance(self.instance_name) < self.stream - - def to_room_stream_token(self) -> RoomStreamToken: - """Converts the position to a room stream token such that events - persisted in the same room after this position will be after the - returned `RoomStreamToken`. - - Note: no guarantees are made about ordering w.r.t. events in other - rooms. - """ - # Doing the naive thing satisfies the desired properties described in - # the docstring. - return RoomStreamToken(None, self.stream) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ThirdPartyInstanceID: - appservice_id: Optional[str] - network_id: Optional[str] - - # Deny iteration because it will bite you if you try to create a singleton - # set by: - # users = set(user) - def __iter__(self) -> NoReturn: - raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) - - # Because this class is a frozen class, it is deeply immutable. - def __copy__(self) -> "ThirdPartyInstanceID": - return self - - def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": - return self - - @classmethod - def from_string(cls, s: str) -> "ThirdPartyInstanceID": - bits = s.split("|", 2) - if len(bits) != 2: - raise SynapseError(400, "Invalid ID %r" % (s,)) - - return cls(appservice_id=bits[0], network_id=bits[1]) - - def to_string(self) -> str: - return "%s|%s" % (self.appservice_id, self.network_id) - - __str__ = to_string - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ReadReceipt: - """Information about a read-receipt""" - - room_id: str - receipt_type: str - user_id: str - event_ids: List[str] - thread_id: Optional[str] - data: JsonDict - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceListUpdates: - """ - An object containing a diff of information regarding other users' device lists, intended for - a recipient to carry out device list tracking. - - Attributes: - changed: A set of users whose device lists have changed recently. - left: A set of users who the recipient no longer needs to track the device lists of. - Typically when those users no longer share any end-to-end encryption enabled rooms. - """ - - # We need to use a factory here, otherwise `set` is not evaluated at - # object instantiation, but instead at class definition instantiation. - # The latter happening only once, thus always giving you the same sets - # across multiple DeviceListUpdates instances. - # Also see: don't define mutable default arguments. - changed: Set[str] = attr.ib(factory=set) - left: Set[str] = attr.ib(factory=set) - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - -def get_verify_key_from_cross_signing_key( - key_info: Mapping[str, Any] -) -> Tuple[str, VerifyKey]: - """Get the key ID and signedjson verify key from a cross-signing key dict - - Args: - key_info: a cross-signing key dict, which must have a "keys" - property that has exactly one item in it - - Returns: - the key ID and verify key for the cross-signing key - """ - # make sure that a `keys` field is provided - if "keys" not in key_info: - raise ValueError("Invalid key") - keys = key_info["keys"] - # and that it contains exactly one key - if len(keys) == 1: - key_id, key_data = next(iter(keys.items())) - return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) - else: - raise ValueError("Invalid key") - - -@attr.s(auto_attribs=True, frozen=True, slots=True) -class UserInfo: - """Holds information about a user. Result of get_userinfo_by_id. - - Attributes: - user_id: ID of the user. - appservice_id: Application service ID that created this user. - consent_server_notice_sent: Version of policy documents the user has been sent. - consent_version: Version of policy documents the user has consented to. - creation_ts: Creation timestamp of the user. - is_admin: True if the user is an admin. - is_deactivated: True if the user has been deactivated. - is_guest: True if the user is a guest user. - is_shadow_banned: True if the user has been shadow-banned. - user_type: User type (None for normal user, 'support' and 'bot' other options). - """ - - user_id: UserID - appservice_id: Optional[int] - consent_server_notice_sent: Optional[str] - consent_version: Optional[str] - user_type: Optional[str] - creation_ts: int - is_admin: bool - is_deactivated: bool - is_guest: bool - is_shadow_banned: bool - - -class UserProfile(TypedDict): - user_id: str - display_name: Optional[str] - avatar_url: Optional[str] - - -@attr.s(auto_attribs=True, frozen=True, slots=True) -class RetentionPolicy: - min_lifetime: Optional[int] = None - max_lifetime: Optional[int] = None diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py new file mode 100644 index 0000000000..f2d436ddc3 --- /dev/null +++ b/synapse/types/__init__.py @@ -0,0 +1,928 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import re +import string +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Mapping, + Match, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +import attr +from frozendict import frozendict +from signedjson.key import decode_verify_key_bytes +from signedjson.types import VerifyKey +from typing_extensions import Final, TypedDict +from unpaddedbase64 import decode_base64 +from zope.interface import Interface + +from twisted.internet.defer import CancelledError +from twisted.internet.interfaces import ( + IReactorCore, + IReactorPluggableNameResolver, + IReactorSSL, + IReactorTCP, + IReactorThreads, + IReactorTime, +) + +from synapse.api.errors import Codes, SynapseError +from synapse.util.cancellation import cancellable +from synapse.util.stringutils import parse_and_validate_server_name + +if TYPE_CHECKING: + from synapse.appservice.api import ApplicationService + from synapse.storage.databases.main import DataStore, PurgeEventsStore + from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore + +# Define a state map type from type/state_key to T (usually an event ID or +# event) +T = TypeVar("T") +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] + +# JSON types. These could be made stronger, but will do for now. +# A JSON-serialisable dict. +JsonDict = Dict[str, Any] +# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. +# Useful when you have a TypedDict which isn't going to be mutated and you don't want +# to cast to JsonDict everywhere. +JsonMapping = Mapping[str, Any] +# A JSON-serialisable object. +JsonSerializable = object + + +# Note that this seems to require inheriting *directly* from Interface in order +# for mypy-zope to realize it is an interface. +class ISynapseReactor( + IReactorTCP, + IReactorSSL, + IReactorPluggableNameResolver, + IReactorTime, + IReactorCore, + IReactorThreads, + Interface, +): + """The interfaces necessary for Synapse to function.""" + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class Requester: + """ + Represents the user making a request + + Attributes: + user: id of the user making the request + access_token_id: *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request has been shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. + """ + + user: "UserID" + access_token_id: Optional[int] + is_guest: bool + shadow_banned: bool + device_id: Optional[str] + app_service: Optional["ApplicationService"] + authenticated_entity: str + + def serialize(self) -> Dict[str, Any]: + """Converts self to a type that can be serialized as JSON, and then + deserialized by `deserialize` + + Returns: + dict + """ + return { + "user_id": self.user.to_string(), + "access_token_id": self.access_token_id, + "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, + "device_id": self.device_id, + "app_server_id": self.app_service.id if self.app_service else None, + "authenticated_entity": self.authenticated_entity, + } + + @staticmethod + def deserialize( + store: "ApplicationServiceWorkerStore", input: Dict[str, Any] + ) -> "Requester": + """Converts a dict that was produced by `serialize` back into a + Requester. + + Args: + store: Used to convert AS ID to AS object + input: A dict produced by `serialize` + + Returns: + Requester + """ + appservice = None + if input["app_server_id"]: + appservice = store.get_app_service_by_id(input["app_server_id"]) + + return Requester( + user=UserID.from_string(input["user_id"]), + access_token_id=input["access_token_id"], + is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], + device_id=input["device_id"], + app_service=appservice, + authenticated_entity=input["authenticated_entity"], + ) + + +def create_requester( + user_id: Union[str, "UserID"], + access_token_id: Optional[int] = None, + is_guest: bool = False, + shadow_banned: bool = False, + device_id: Optional[str] = None, + app_service: Optional["ApplicationService"] = None, + authenticated_entity: Optional[str] = None, +) -> Requester: + """ + Create a new ``Requester`` object + + Args: + user_id: id of the user making the request + access_token_id: *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request is shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. + + Returns: + Requester + """ + if not isinstance(user_id, UserID): + user_id = UserID.from_string(user_id) + + if authenticated_entity is None: + authenticated_entity = user_id.to_string() + + return Requester( + user_id, + access_token_id, + is_guest, + shadow_banned, + device_id, + app_service, + authenticated_entity, + ) + + +def get_domain_from_id(string: str) -> str: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[idx + 1 :] + + +def get_localpart_from_id(string: str) -> str: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[1:idx] + + +DS = TypeVar("DS", bound="DomainSpecificString") + + +@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) +class DomainSpecificString(metaclass=abc.ABCMeta): + """Common base class among ID/name strings that have a local part and a + domain name, prefixed with a sigil. + + Has the fields: + + 'localpart' : The local part of the name (without the leading sigil) + 'domain' : The domain part of the name + """ + + SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore + + localpart: str + domain: str + + # Because this is a frozen class, it is deeply immutable. + def __copy__(self: DS) -> DS: + return self + + def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: + return self + + @classmethod + def from_string(cls: Type[DS], s: str) -> DS: + """Parse the string given by 's' into a structure object.""" + if len(s) < 1 or s[0:1] != cls.SIGIL: + raise SynapseError( + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + parts = s[1:].split(":", 1) + if len(parts) != 2: + raise SynapseError( + 400, + "Expected %s of the form '%slocalname:domain'" + % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + domain = parts[1] + # This code will need changing if we want to support multiple domain + # names on one HS + return cls(localpart=parts[0], domain=domain) + + def to_string(self) -> str: + """Return a string encoding the fields of the structure object.""" + return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) + + @classmethod + def is_valid(cls: Type[DS], s: str) -> bool: + """Parses the input string and attempts to ensure it is valid.""" + # TODO: this does not reject an empty localpart or an overly-long string. + # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar + try: + obj = cls.from_string(s) + # Apply additional validation to the domain. This is only done + # during is_valid (and not part of from_string) since it is + # possible for invalid data to exist in room-state, etc. + parse_and_validate_server_name(obj.domain) + return True + except Exception: + return False + + __repr__ = to_string + + +@attr.s(slots=True, frozen=True, repr=False) +class UserID(DomainSpecificString): + """Structure representing a user ID.""" + + SIGIL = "@" + + +@attr.s(slots=True, frozen=True, repr=False) +class RoomAlias(DomainSpecificString): + """Structure representing a room name.""" + + SIGIL = "#" + + +@attr.s(slots=True, frozen=True, repr=False) +class RoomID(DomainSpecificString): + """Structure representing a room id.""" + + SIGIL = "!" + + +@attr.s(slots=True, frozen=True, repr=False) +class EventID(DomainSpecificString): + """Structure representing an event id.""" + + SIGIL = "$" + + +mxid_localpart_allowed_characters = set( + "_-./=" + string.ascii_lowercase + string.digits +) + + +def contains_invalid_mxid_characters(localpart: str) -> bool: + """Check for characters not allowed in an mxid or groupid localpart + + Args: + localpart: the localpart to be checked + + Returns: + True if there are any naughty characters + """ + return any(c not in mxid_localpart_allowed_characters for c in localpart) + + +UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") + +# the following is a pattern which matches '=', and bytes which are not allowed in a mxid +# localpart. +# +# It works by: +# * building a string containing the allowed characters (excluding '=') +# * escaping every special character with a backslash (to stop '-' being interpreted as a +# range operator) +# * wrapping it in a '[^...]' regex +# * converting the whole lot to a 'bytes' sequence, so that we can use it to match +# bytes rather than strings +# +NON_MXID_CHARACTER_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( + "ascii" + ) +) + + +def map_username_to_mxid_localpart( + username: Union[str, bytes], case_sensitive: bool = False +) -> str: + """Map a username onto a string suitable for a MXID + + This follows the algorithm laid out at + https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. + + Args: + username: username to be mapped + case_sensitive: true if TEST and test should be mapped + onto different mxids + + Returns: + string suitable for a mxid localpart + """ + if not isinstance(username, bytes): + username = username.encode("utf-8") + + # first we sort out upper-case characters + if case_sensitive: + + def f1(m: Match[bytes]) -> bytes: + return b"_" + m.group().lower() + + username = UPPER_CASE_PATTERN.sub(f1, username) + else: + username = username.lower() + + # then we sort out non-ascii characters by converting to the hex equivalent. + def f2(m: Match[bytes]) -> bytes: + return b"=%02x" % (m.group()[0],) + + username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) + + # we also do the =-escaping to mxids starting with an underscore. + username = re.sub(b"^_", b"=5f", username) + + # we should now only have ascii bytes left, so can decode back to a string. + return username.decode("ascii") + + +@attr.s(frozen=True, slots=True, order=False) +class RoomStreamToken: + """Tokens are positions between events. The token "s1" comes after event 1. + + s0 s1 + | | + [0] ▼ [1] ▼ [2] + + Tokens can either be a point in the live event stream or a cursor going + through historic events. + + When traversing the live event stream, events are ordered by + `stream_ordering` (when they arrived at the homeserver). + + When traversing historic events, events are first ordered by their `depth` + (`topological_ordering` in the event graph) and tie-broken by + `stream_ordering` (when the event arrived at the homeserver). + + If you're looking for more info about what a token with all of the + underscores means, ex. + `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring + for `StreamToken` below. + + --- + + Live tokens start with an "s" followed by the `stream_ordering` of the event + that comes before the position of the token. Said another way: + `stream_ordering` uniquely identifies a persisted event. The live token + means "the position just after the event identified by `stream_ordering`". + An example token is: + + s2633508 + + --- + + Historic tokens start with a "t" followed by the `depth` + (`topological_ordering` in the event graph) of the event that comes before + the position of the token, followed by "-", followed by the + `stream_ordering` of the event that comes before the position of the token. + An example token is: + + t426-2633508 + + --- + + There is also a third mode for live tokens where the token starts with "m", + which is sometimes used when using sharded event persisters. In this case + the events stream is considered to be a set of streams (one for each writer) + and the token encodes the vector clock of positions of each writer in their + respective streams. + + The format of the token in such case is an initial integer min position, + followed by the mapping of instance ID to position separated by '.' and '~': + + m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... + + The `min_pos` corresponds to the minimum position all writers have persisted + up to, and then only writers that are ahead of that position need to be + encoded. An example token is: + + m56~2.58~3.59 + + Which corresponds to a set of three (or more writers) where instances 2 and + 3 (these are instance IDs that can be looked up in the DB to fetch the more + commonly used instance names) are at positions 58 and 59 respectively, and + all other instances are at position 56. + + Note: The `RoomStreamToken` cannot have both a topological part and an + instance map. + + --- + + For caching purposes, `RoomStreamToken`s and by extension, all their + attributes, must be hashable. + """ + + topological: Optional[int] = attr.ib( + validator=attr.validators.optional(attr.validators.instance_of(int)), + ) + stream: int = attr.ib(validator=attr.validators.instance_of(int)) + + instance_map: "frozendict[str, int]" = attr.ib( + factory=frozendict, + validator=attr.validators.deep_mapping( + key_validator=attr.validators.instance_of(str), + value_validator=attr.validators.instance_of(int), + mapping_validator=attr.validators.instance_of(frozendict), + ), + ) + + def __attrs_post_init__(self) -> None: + """Validates that both `topological` and `instance_map` aren't set.""" + + if self.instance_map and self.topological: + raise ValueError( + "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." + ) + + @classmethod + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + if string[0] == "t": + parts = string[1:].split("-", 1) + return cls(topological=int(parts[0]), stream=int(parts[1])) + if string[0] == "m": + parts = string[1:].split("~") + stream = int(parts[0]) + + instance_map = {} + for part in parts[1:]: + key, value = part.split(".") + instance_id = int(key) + pos = int(value) + + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] + instance_map[instance_name] = pos + + return cls( + topological=None, + stream=stream, + instance_map=frozendict(instance_map), + ) + except CancelledError: + raise + except Exception: + pass + raise SynapseError(400, "Invalid room stream token %r" % (string,)) + + @classmethod + def parse_stream_token(cls, string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + except Exception: + pass + raise SynapseError(400, "Invalid room stream token %r" % (string,)) + + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + instance_map = { + instance: max( + self.instance_map.get(instance, self.stream), + other.instance_map.get(instance, other.stream), + ) + for instance in set(self.instance_map).union(other.instance_map) + } + + return RoomStreamToken(None, max_stream, frozendict(instance_map)) + + def as_historical_tuple(self) -> Tuple[int, int]: + """Returns a tuple of `(topological, stream)` for historical tokens. + + Raises if not an historical token (i.e. doesn't have a topological part). + """ + if self.topological is None: + raise Exception( + "Cannot call `RoomStreamToken.as_historical_tuple` on live token" + ) + + return self.topological, self.stream + + def get_stream_pos_for_instance(self, instance_name: str) -> int: + """Get the stream position that the given writer was at at this token. + + This only makes sense for "live" tokens that may have a vector clock + component, and so asserts that this is a "live" token. + """ + assert self.topological is None + + # If we don't have an entry for the instance we can assume that it was + # at `self.stream`. + return self.instance_map.get(instance_name, self.stream) + + def get_max_stream_pos(self) -> int: + """Get the maximum stream position referenced in this token. + + The corresponding "min" position is, by definition just `self.stream`. + + This is used to handle tokens that have non-empty `instance_map`, and so + reference stream positions after the `self.stream` position. + """ + return max(self.instance_map.values(), default=self.stream) + + async def to_string(self, store: "DataStore") -> str: + if self.topological is not None: + return "t%d-%d" % (self.topological, self.stream) + elif self.instance_map: + entries = [] + for name, pos in self.instance_map.items(): + instance_id = await store.get_id_for_instance(name) + entries.append(f"{instance_id}.{pos}") + + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + else: + return "s%d" % (self.stream,) + + +class StreamKeyType: + """Known stream types. + + A stream is a list of entities ordered by an incrementing "stream token". + """ + + ROOM: Final = "room_key" + PRESENCE: Final = "presence_key" + TYPING: Final = "typing_key" + RECEIPT: Final = "receipt_key" + ACCOUNT_DATA: Final = "account_data_key" + PUSH_RULES: Final = "push_rules_key" + TO_DEVICE: Final = "to_device_key" + DEVICE_LIST: Final = "device_list_key" + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StreamToken: + """A collection of keys joined together by underscores in the following + order and which represent the position in their respective streams. + + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` + 1. `room_key`: `s2633508` which is a `RoomStreamToken` + - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` + - See the docstring for `RoomStreamToken` for more details. + 2. `presence_key`: `17` + 3. `typing_key`: `338` + 4. `receipt_key`: `6732159` + 5. `account_data_key`: `1082514` + 6. `push_rules_key`: `541479` + 7. `to_device_key`: `274711` + 8. `device_list_key`: `265584` + 9. `groups_key`: `1` (note that this key is now unused) + + You can see how many of these keys correspond to the various + fields in a "/sync" response: + ```json + { + "next_batch": "s12_4_0_1_1_1_1_4_1", + "presence": { + "events": [] + }, + "device_lists": { + "changed": [] + }, + "rooms": { + "join": { + "!QrZlfIDQLNLdZHqTnt:hs1": { + "timeline": { + "events": [], + "prev_batch": "s10_4_0_1_1_1_1_4_1", + "limited": false + }, + "state": { + "events": [] + }, + "account_data": { + "events": [] + }, + "ephemeral": { + "events": [] + } + } + } + } + } + ``` + + --- + + For caching purposes, `StreamToken`s and by extension, all their attributes, + must be hashable. + """ + + room_key: RoomStreamToken = attr.ib( + validator=attr.validators.instance_of(RoomStreamToken) + ) + presence_key: int + typing_key: int + receipt_key: int + account_data_key: int + push_rules_key: int + to_device_key: int + device_list_key: int + # Note that the groups key is no longer used and may have bogus values. + groups_key: int + + _SEPARATOR = "_" + START: ClassVar["StreamToken"] + + @classmethod + @cancellable + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": + """ + Creates a RoomStreamToken from its textual representation. + """ + try: + keys = string.split(cls._SEPARATOR) + while len(keys) < len(attr.fields(cls)): + # i.e. old token from before receipt_key + keys.append("0") + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) + except CancelledError: + raise + except Exception: + raise SynapseError(400, "Invalid stream token") + + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + # Note that the groups key is no longer used, but it is still + # serialized so that there will not be confusion in the future + # if additional tokens are added. + str(self.groups_key), + ] + ) + + @property + def room_stream_id(self) -> int: + return self.room_key.stream + + def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": + """Advance the given key in the token to a new value if and only if the + new value is after the old value. + + :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. + """ + if key == StreamKeyType.ROOM: + new_token = self.copy_and_replace( + StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + + if old_id < new_id: + return new_token + else: + return self + + def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": + return attr.evolve(self, **{key: new_value}) + + +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name: str + stream: int + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.get_stream_pos_for_instance(self.instance_name) < self.stream + + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarantees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThirdPartyInstanceID: + appservice_id: Optional[str] + network_id: Optional[str] + + # Deny iteration because it will bite you if you try to create a singleton + # set by: + # users = set(user) + def __iter__(self) -> NoReturn: + raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) + + # Because this class is a frozen class, it is deeply immutable. + def __copy__(self) -> "ThirdPartyInstanceID": + return self + + def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": + return self + + @classmethod + def from_string(cls, s: str) -> "ThirdPartyInstanceID": + bits = s.split("|", 2) + if len(bits) != 2: + raise SynapseError(400, "Invalid ID %r" % (s,)) + + return cls(appservice_id=bits[0], network_id=bits[1]) + + def to_string(self) -> str: + return "%s|%s" % (self.appservice_id, self.network_id) + + __str__ = to_string + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ReadReceipt: + """Information about a read-receipt""" + + room_id: str + receipt_type: str + user_id: str + event_ids: List[str] + thread_id: Optional[str] + data: JsonDict + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceListUpdates: + """ + An object containing a diff of information regarding other users' device lists, intended for + a recipient to carry out device list tracking. + + Attributes: + changed: A set of users whose device lists have changed recently. + left: A set of users who the recipient no longer needs to track the device lists of. + Typically when those users no longer share any end-to-end encryption enabled rooms. + """ + + # We need to use a factory here, otherwise `set` is not evaluated at + # object instantiation, but instead at class definition instantiation. + # The latter happening only once, thus always giving you the same sets + # across multiple DeviceListUpdates instances. + # Also see: don't define mutable default arguments. + changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + +def get_verify_key_from_cross_signing_key( + key_info: Mapping[str, Any] +) -> Tuple[str, VerifyKey]: + """Get the key ID and signedjson verify key from a cross-signing key dict + + Args: + key_info: a cross-signing key dict, which must have a "keys" + property that has exactly one item in it + + Returns: + the key ID and verify key for the cross-signing key + """ + # make sure that a `keys` field is provided + if "keys" not in key_info: + raise ValueError("Invalid key") + keys = key_info["keys"] + # and that it contains exactly one key + if len(keys) == 1: + key_id, key_data = next(iter(keys.items())) + return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) + else: + raise ValueError("Invalid key") + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool + + +class UserProfile(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class RetentionPolicy: + min_lifetime: Optional[int] = None + max_lifetime: Optional[int] = None diff --git a/synapse/types/state.py b/synapse/types/state.py new file mode 100644 index 0000000000..0004d955b4 --- /dev/null +++ b/synapse/types/state.py @@ -0,0 +1,567 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import ( + TYPE_CHECKING, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + TypeVar, +) + +import attr +from frozendict import frozendict + +from synapse.api.constants import EventTypes +from synapse.types import MutableStateMap, StateKey, StateMap + +if TYPE_CHECKING: + from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad + + +logger = logging.getLogger(__name__) + +# Used for generic functions below +T = TypeVar("T") + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StateFilter: + """A filter used when querying for state. + + Attributes: + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not + appear in `types`. + """ + + types: "frozendict[str, Optional[FrozenSet[str]]]" + include_others: bool = False + + def __attrs_post_init__(self) -> None: + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + # this is needed to work around the fact that StateFilter is frozen + object.__setattr__( + self, + "types", + frozendict({k: v for k, v in self.types.items() if v is not None}), + ) + + @staticmethod + def all() -> "StateFilter": + """Returns a filter that fetches everything. + + Returns: + The state filter. + """ + return _ALL_STATE_FILTER + + @staticmethod + def none() -> "StateFilter": + """Returns a filter that fetches nothing. + + Returns: + The new state filter. + """ + return _NONE_STATE_FILTER + + @staticmethod + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": + """Creates a filter that only fetches the given types + + Args: + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type + + Returns: + The new state filter. + """ + type_dict: Dict[str, Optional[Set[str]]] = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) # type: ignore + + return StateFilter( + types=frozendict( + (k, frozenset(v) if v is not None else None) + for k, v in type_dict.items() + ) + ) + + @staticmethod + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members: Set of user IDs + + Returns: + The new state filter + """ + return StateFilter( + types=frozendict({EventTypes.Member: frozenset(members)}), + include_others=True, + ) + + @staticmethod + def freeze( + types: Mapping[str, Optional[Collection[str]]], include_others: bool + ) -> "StateFilter": + """ + Returns a (frozen) StateFilter with the same contents as the parameters + specified here, which can be made of mutable types. + """ + types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} + for state_types, state_keys in types.items(): + if state_keys is not None: + types_with_frozen_values[state_types] = frozenset(state_keys) + else: + types_with_frozen_values[state_types] = None + + return StateFilter( + frozendict(types_with_frozen_values), include_others=include_others + ) + + def return_expanded(self) -> "StateFilter": + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + The new state filter. + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in self.types.items() + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + elif EventTypes.Member in self.types: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), + include_others=True, + ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER + + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). + """ + + where_clause = "" + where_args: List[str] = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in self.types.items(): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self) -> Optional[int]: + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: + """Returns the state filtered with by this StateFilter. + + Args: + state: The state map to filter + + Returns: + The filtered state map. + This is a copy, so it's safe to mutate. + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in state_dict.items(): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self) -> bool: + """Whether this filter fetches everything or not + + Returns: + True if the filter fetches everything. + """ + return self.include_others and not self.types + + def has_wildcards(self) -> bool: + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + True if the filter includes wildcards. + """ + + return self.include_others or any( + state_keys is None for state_keys in self.types.values() + ) + + def concrete_types(self) -> List[Tuple[str, str]]: + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + A list of type/state_keys tuples. + """ + return [ + (t, s) + for t, state_keys in self.types.items() + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types=frozendict( + {k: v for k, v in self.types.items() if k != EventTypes.Member} + ), + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + def _decompose_into_four_parts( + self, + ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: + """ + Decomposes this state filter into 4 constituent parts, which can be + thought of as this: + all? - minus_wildcards + plus_wildcards + plus_state_keys + + where + * all represents ALL state + * minus_wildcards represents entire state types to remove + * plus_wildcards represents entire state types to add + * plus_state_keys represents individual state keys to add + + See `recompose_from_four_parts` for the other direction of this + correspondence. + """ + is_all = self.include_others + excluded_types: Set[str] = {t for t in self.types if is_all} + wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} + concrete_keys: Set[StateKey] = set(self.concrete_types()) + + return (is_all, excluded_types), (wildcard_types, concrete_keys) + + @staticmethod + def _recompose_from_four_parts( + all_part: bool, + minus_wildcards: Set[str], + plus_wildcards: Set[str], + plus_state_keys: Set[StateKey], + ) -> "StateFilter": + """ + Recomposes a state filter from 4 parts. + + See `decompose_into_four_parts` (the other direction of this + correspondence) for descriptions on each of the parts. + """ + + # {state type -> set of state keys OR None for wildcard} + # (The same structure as that of a StateFilter.) + new_types: Dict[str, Optional[Set[str]]] = {} + + # if we start with all, insert the excluded statetypes as empty sets + # to prevent them from being included + if all_part: + new_types.update({state_type: set() for state_type in minus_wildcards}) + + # insert the plus wildcards + new_types.update({state_type: None for state_type in plus_wildcards}) + + # insert the specific state keys + for state_type, state_key in plus_state_keys: + if state_type in new_types: + entry = new_types[state_type] + if entry is not None: + entry.add(state_key) + elif not all_part: + # don't insert if the entire type is already included by + # include_others as this would actually shrink the state allowed + # by this filter. + new_types[state_type] = {state_key} + + return StateFilter.freeze(new_types, include_others=all_part) + + def approx_difference(self, other: "StateFilter") -> "StateFilter": + """ + Returns a state filter which represents `self - other`. + + This is useful for determining what state remains to be pulled out of the + database if we want the state included by `self` but already have the state + included by `other`. + + The returned state filter + - MUST include all state events that are included by this filter (`self`) + unless they are included by `other`; + - MUST NOT include state events not included by this filter (`self`); and + - MAY be an over-approximation: the returned state filter + MAY additionally include some state events from `other`. + + This implementation attempts to return the narrowest such state filter. + In the case that `self` contains wildcards for state types where + `other` contains specific state keys, an approximation must be made: + the returned state filter keeps the wildcard, as state filters are not + able to express 'all state keys except some given examples'. + e.g. + StateFilter(m.room.member -> None (wildcard)) + minus + StateFilter(m.room.member -> {'@wombat:example.org'}) + is approximated as + StateFilter(m.room.member -> None (wildcard)) + """ + + # We first transform self and other into an alternative representation: + # - whether or not they include all events to begin with ('all') + # - if so, which event types are excluded? ('excludes') + # - which entire event types to include ('wildcards') + # - which concrete state keys to include ('concrete state keys') + (self_all, self_excludes), ( + self_wildcards, + self_concrete_keys, + ) = self._decompose_into_four_parts() + (other_all, other_excludes), ( + other_wildcards, + other_concrete_keys, + ) = other._decompose_into_four_parts() + + # Start with an estimate of the difference based on self + new_all = self_all + # Wildcards from the other can be added to the exclusion filter + new_excludes = self_excludes | other_wildcards + # We remove wildcards that appeared as wildcards in the other + new_wildcards = self_wildcards - other_wildcards + # We filter out the concrete state keys that appear in the other + # as wildcards or concrete state keys. + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in self_concrete_keys + if state_type not in other_wildcards + } - other_concrete_keys + + if other_all: + if self_all: + # If self starts with all, then we add as wildcards any + # types which appear in the other's exclusion filter (but + # aren't in the self exclusion filter). This is as the other + # filter will return everything BUT the types in its exclusion, so + # we need to add those excluded types that also match the self + # filter as wildcard types in the new filter. + new_wildcards |= other_excludes.difference(self_excludes) + + # If other is an `include_others` then the difference isn't. + new_all = False + # (We have no need for excludes when we don't start with all, as there + # is nothing to exclude.) + new_excludes = set() + + # We also filter out all state types that aren't in the exclusion + # list of the other. + new_wildcards &= other_excludes + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in new_concrete_keys + if state_type in other_excludes + } + + # Transform our newly-constructed state filter from the alternative + # representation back into the normal StateFilter representation. + return StateFilter._recompose_from_four_parts( + new_all, new_excludes, new_wildcards, new_concrete_keys + ) + + def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: + """Check if we need to wait for full state to complete to calculate this state + + If we have a state filter which is completely satisfied even with partial + state, then we don't need to await_full_state before we can return it. + + Args: + is_mine_id: a callable which confirms if a given state_key matches a mxid + of a local user + """ + # if we haven't requested membership events, then it depends on the value of + # 'include_others' + if EventTypes.Member not in self.types: + return self.include_others + + # if we're looking for *all* membership events, then we have to wait + member_state_keys = self.types[EventTypes.Member] + if member_state_keys is None: + return True + + # otherwise, consider whose membership we are looking for. If it's entirely + # local users, then we don't need to wait. + for state_key in member_state_keys: + if not is_mine_id(state_key): + # remote user + return True + + # local users only + return False + + +_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) +_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/visibility.py b/synapse/visibility.py index b443857571..e442de3173 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -26,8 +26,8 @@ from synapse.events.utils import prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.storage.state import StateFilter from synapse.types import RetentionPolicy, StateMap, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import Clock logger = logging.getLogger(__name__) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d4e6d4236c..a433e70870 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -22,8 +22,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.server import HomeServer -from synapse.storage.state import StateFilter from synapse.types import JsonDict, RoomID, StateMap, UserID +from synapse.types.state import StateFilter from synapse.util import Clock from tests.unittest import HomeserverTestCase, TestCase -- cgit 1.5.1 From 7982891794e26cabe18448f4e0ec0d301f13d186 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 12 Dec 2022 18:13:43 +0000 Subject: Fix missing cache invalidation in application service code (#14670) #11915 introduced the `@cached` `is_interested_in_room` method in Synapse 1.55.0, which depends upon `get_aliases_for_room`. Add a missing cache invalidation callback so that the `is_interested_in_room` cache is invalidated when `get_aliases_for_room` is invalidated. #13787 made `get_rooms_for_user` `@cached`. Add a missing cache invalidation callback so that the `is_interested_in_presence` cache is invalidated when `get_rooms_for_user` is invalidated. Signed-off-by: Sean Quah --- changelog.d/14670.bugfix | 1 + synapse/appservice/__init__.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14670.bugfix (limited to 'synapse') diff --git a/changelog.d/14670.bugfix b/changelog.d/14670.bugfix new file mode 100644 index 0000000000..98398d76cc --- /dev/null +++ b/changelog.d/14670.bugfix @@ -0,0 +1 @@ +Fix bugs introduced in 1.55.0 and 1.69.0 where application services would not be notified of events in the correct rooms, due to stale caches. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index bf4e6c629b..65615f50b8 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -245,7 +245,9 @@ class ApplicationService: return True # likewise with the room's aliases (if it has any) - alias_list = await store.get_aliases_for_room(room_id) + alias_list = await store.get_aliases_for_room( + room_id, on_invalidate=cache_context.invalidate + ) for alias in alias_list: if self.is_room_alias_in_namespace(alias): return True @@ -311,7 +313,9 @@ class ApplicationService: # Find all the rooms the sender is in if self.is_interested_in_user(user_id.to_string()): return True - room_ids = await store.get_rooms_for_user(user_id.to_string()) + room_ids = await store.get_rooms_for_user( + user_id.to_string(), on_invalidate=cache_context.invalidate + ) # Then find out if the appservice is interested in any of those rooms for room_id in room_ids: -- cgit 1.5.1 From 3d87847ecc943c689c4587c5327d744e4a8f92c2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 21:25:07 +0000 Subject: Enable `--warn-redundant-casts` option in mypy (#14671) * Enable `--warn-redundant-casts` option in mypy Doesn't do much but helps me sleep better at night. * Changelog * Fix name of the ignore * Fix one more missed cast Not sure why I didn't see this one locally, maybe I needed a poetry update * Remove old comment Co-authored-by: Patrick Cloke Co-authored-by: Patrick Cloke --- changelog.d/14671.misc | 1 + mypy.ini | 1 + scripts-dev/release.py | 6 ++---- synapse/storage/database.py | 3 ++- synapse/storage/engines/postgres.py | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14671.misc (limited to 'synapse') diff --git a/changelog.d/14671.misc b/changelog.d/14671.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/14671.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index a4a1e4511a..727536df50 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,6 +12,7 @@ local_partial_types = True no_implicit_optional = True disallow_untyped_defs = True strict_equality = True +warn_redundant_casts = True files = docker/, diff --git a/scripts-dev/release.py b/scripts-dev/release.py index bf47b6c713..6974fd7895 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -27,7 +27,7 @@ import time import urllib.request from os import path from tempfile import TemporaryDirectory -from typing import Any, List, Optional, cast +from typing import Any, List, Optional import attr import click @@ -174,9 +174,7 @@ def _prepare() -> None: click.get_current_context().abort() # Switch to the release branch. - # Cast safety: parse() won't return a version.LegacyVersion from our - # version string format. - parsed_new_version = cast(version.Version, version.parse(new_version)) + parsed_new_version = version.parse(new_version) # We assume for debian changelogs that we only do RCs or full releases. assert not parsed_new_version.is_devrelease diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 55bcb90001..0b29e67b94 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -667,7 +667,8 @@ class DatabasePool: ) # also check variables referenced in func's closure if inspect.isfunction(func): - f = cast(types.FunctionType, func) + # Keep the cast for now---it helps PyCharm to understand what `func` is. + f = cast(types.FunctionType, func) # type: ignore[redundant-cast] if f.__closure__: for i, cell in enumerate(f.__closure__): if inspect.isgenerator(cell.cell_contents): diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 719a517336..f9f562ea45 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -77,7 +77,7 @@ class PostgresEngine( # docs: The number is formed by converting the major, minor, and # revision numbers into two-decimal-digit numbers and appending them # together. For example, version 8.1.5 will be returned as 80105 - self._version = cast(int, db_conn.server_version) + self._version = db_conn.server_version allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) # Are we on a supported PostgreSQL version? -- cgit 1.5.1 From e2a1adbf5d11288f2134ced1f84c6ffdd91a9357 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 13 Dec 2022 00:54:46 +0000 Subject: Allow selecting "prejoin" events by state keys (#14642) * Declare new config * Parse new config * Read new config * Don't use trial/our TestCase where it's not needed Before: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m2.277s user 0m2.186s sys 0m0.083s ``` After: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m0.566s user 0m0.508s sys 0m0.056s ``` * Helper to upsert to event fields without exceeding size limits. * Use helper when adding invite/knock state Now that we allow admins to include events in prejoin room state with arbitrary state keys, be a good Matrix citizen and ensure they don't accidentally create an oversized event. * Changelog * Move StateFilter tests should have done this in #14668 * Add extra methods to StateFilter * Use StateFilter * Ensure test file enforces typed defs; alphabetise * Workaround surprising get_current_state_ids * Whoops, fix mypy --- changelog.d/14642.feature | 1 + docs/usage/configuration/config_documentation.md | 57 ++- mypy.ini | 12 +- synapse/config/_util.py | 3 + synapse/config/api.py | 63 ++- synapse/events/utils.py | 32 +- synapse/handlers/message.py | 29 +- synapse/storage/databases/main/events_worker.py | 33 +- synapse/types/state.py | 18 + tests/config/test_api.py | 145 ++++++ tests/events/test_utils.py | 35 +- tests/storage/test_state.py | 623 +--------------------- tests/types/__init__.py | 0 tests/types/test_state.py | 627 +++++++++++++++++++++++ 14 files changed, 983 insertions(+), 695 deletions(-) create mode 100644 changelog.d/14642.feature create mode 100644 tests/config/test_api.py create mode 100644 tests/types/__init__.py create mode 100644 tests/types/test_state.py (limited to 'synapse') diff --git a/changelog.d/14642.feature b/changelog.d/14642.feature new file mode 100644 index 0000000000..cbc9db10c3 --- /dev/null +++ b/changelog.d/14642.feature @@ -0,0 +1 @@ +Allow selecting "prejoin" events by state keys in addition to event types. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index dc5e5ac597..4d32902fea 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2501,32 +2501,53 @@ Config settings related to the client/server API --- ### `room_prejoin_state` -Controls for the state that is shared with users who receive an invite -to a room. By default, the following state event types are shared with users who -receive invites to the room: -- m.room.join_rules -- m.room.canonical_alias -- m.room.avatar -- m.room.encryption -- m.room.name -- m.room.create -- m.room.topic +This setting controls the state that is shared with users upon receiving an +invite to a room, or in reply to a knock on a room. By default, the following +state events are shared with users: + +- `m.room.join_rules` +- `m.room.canonical_alias` +- `m.room.avatar` +- `m.room.encryption` +- `m.room.name` +- `m.room.create` +- `m.room.topic` To change the default behavior, use the following sub-options: -* `disable_default_event_types`: set to true to disable the above defaults. If this - is enabled, only the event types listed in `additional_event_types` are shared. - Defaults to false. -* `additional_event_types`: Additional state event types to share with users when they are invited - to a room. By default, this list is empty (so only the default event types are shared). +* `disable_default_event_types`: boolean. Set to `true` to disable the above + defaults. If this is enabled, only the event types listed in + `additional_event_types` are shared. Defaults to `false`. +* `additional_event_types`: A list of additional state events to include in the + events to be shared. By default, this list is empty (so only the default event + types are shared). + + Each entry in this list should be either a single string or a list of two + strings. + * A standalone string `t` represents all events with type `t` (i.e. + with no restrictions on state keys). + * A pair of strings `[t, s]` represents a single event with type `t` and + state key `s`. The same type can appear in two entries with different state + keys: in this situation, both state keys are included in prejoin state. Example configuration: ```yaml room_prejoin_state: - disable_default_event_types: true + disable_default_event_types: false additional_event_types: - - org.example.custom.event.type - - m.room.join_rules + # Share all events of type `org.example.custom.event.typeA` + - org.example.custom.event.typeA + # Share only events of type `org.example.custom.event.typeB` whose + # state_key is "foo" + - ["org.example.custom.event.typeB", "foo"] + # Share only events of type `org.example.custom.event.typeC` whose + # state_key is "bar" or "baz" + - ["org.example.custom.event.typeC", "bar"] + - ["org.example.custom.event.typeC", "baz"] ``` + +*Changed in Synapse 1.74:* admins can filter the events in prejoin state based +on their state key. + --- ### `track_puppeted_user_ips` diff --git a/mypy.ini b/mypy.ini index 727536df50..37acf589c9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -89,6 +89,12 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False +[mypy-tests.config.test_api] +disallow_untyped_defs = True + +[mypy-tests.federation.transport.test_client] +disallow_untyped_defs = True + [mypy-tests.handlers.test_sso] disallow_untyped_defs = True @@ -101,7 +107,7 @@ disallow_untyped_defs = True [mypy-tests.push.test_bulk_push_rule_evaluator] disallow_untyped_defs = True -[mypy-tests.test_server] +[mypy-tests.rest.*] disallow_untyped_defs = True [mypy-tests.state.test_profile] @@ -110,10 +116,10 @@ disallow_untyped_defs = True [mypy-tests.storage.*] disallow_untyped_defs = True -[mypy-tests.rest.*] +[mypy-tests.test_server] disallow_untyped_defs = True -[mypy-tests.federation.transport.test_client] +[mypy-tests.types.*] disallow_untyped_defs = True [mypy-tests.util.caches.*] diff --git a/synapse/config/_util.py b/synapse/config/_util.py index 3edb4b7106..d3a4b484ab 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -33,6 +33,9 @@ def validate_config( config: the configuration value to be validated config_path: the path within the config file. This will be used as a basis for the error message. + + Raises: + ConfigError, if validation fails. """ try: jsonschema.validate(config, json_schema) diff --git a/synapse/config/api.py b/synapse/config/api.py index e46728e73f..27d50d118f 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -13,12 +13,13 @@ # limitations under the License. import logging -from typing import Any, Iterable +from typing import Any, Iterable, Optional, Tuple from synapse.api.constants import EventTypes from synapse.config._base import Config, ConfigError from synapse.config._util import validate_config from synapse.types import JsonDict +from synapse.types.state import StateFilter logger = logging.getLogger(__name__) @@ -26,16 +27,20 @@ logger = logging.getLogger(__name__) class ApiConfig(Config): section = "api" + room_prejoin_state: StateFilter + track_puppetted_users_ips: bool + def read_config(self, config: JsonDict, **kwargs: Any) -> None: validate_config(_MAIN_SCHEMA, config, ()) - self.room_prejoin_state = list(self._get_prejoin_state_types(config)) + self.room_prejoin_state = StateFilter.from_types( + self._get_prejoin_state_entries(config) + ) self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) - def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: - """Get the event types to include in the prejoin state - - Parses the config and returns an iterable of the event types to be included. - """ + def _get_prejoin_state_entries( + self, config: JsonDict + ) -> Iterable[Tuple[str, Optional[str]]]: + """Get the event types and state keys to include in the prejoin state.""" room_prejoin_state_config = config.get("room_prejoin_state") or {} # backwards-compatibility support for room_invite_state_types @@ -50,33 +55,39 @@ class ApiConfig(Config): logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING) - yield from config["room_invite_state_types"] + for event_type in config["room_invite_state_types"]: + yield event_type, None return if not room_prejoin_state_config.get("disable_default_event_types"): - yield from _DEFAULT_PREJOIN_STATE_TYPES + yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS - yield from room_prejoin_state_config.get("additional_event_types", []) + for entry in room_prejoin_state_config.get("additional_event_types", []): + if isinstance(entry, str): + yield entry, None + else: + yield entry _ROOM_INVITE_STATE_TYPES_WARNING = """\ WARNING: The 'room_invite_state_types' configuration setting is now deprecated, and replaced with 'room_prejoin_state'. New features may not work correctly -unless 'room_invite_state_types' is removed. See the sample configuration file for -details of 'room_prejoin_state'. +unless 'room_invite_state_types' is removed. See the config documentation at + https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state +for details of 'room_prejoin_state'. -------------------------------------------------------------------------------- """ -_DEFAULT_PREJOIN_STATE_TYPES = [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, +_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [ + (EventTypes.JoinRules, ""), + (EventTypes.CanonicalAlias, ""), + (EventTypes.RoomAvatar, ""), + (EventTypes.RoomEncryption, ""), + (EventTypes.Name, ""), # Per MSC1772. - EventTypes.Create, + (EventTypes.Create, ""), # Per MSC3173. - EventTypes.Topic, + (EventTypes.Topic, ""), ] @@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = { "disable_default_event_types": {"type": "boolean"}, "additional_event_types": { "type": "array", - "items": {"type": "string"}, + "items": { + "oneOf": [ + {"type": "string"}, + { + "type": "array", + "items": {"type": "string"}, + "minItems": 2, + "maxItems": 2, + }, + ], + }, }, }, }, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 71853caad8..13fa93afb8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -28,8 +28,14 @@ from typing import ( ) import attr +from canonicaljson import encode_canonical_json -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + MAX_PDU_SIZE, + EventContentFields, + EventTypes, + RelationTypes, +) from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict @@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None: elif not isinstance(value, (bool, str)) and value is not None: # Other potential JSON values (bool, None, str) are safe. raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON) + + +def maybe_upsert_event_field( + event: EventBase, container: JsonDict, key: str, value: object +) -> bool: + """Upsert an event field, but only if this doesn't make the event too large. + + Returns true iff the upsert took place. + """ + if key in container: + old_value: object = container[key] + container[key] = value + # NB: here and below, we assume that passing a non-None `time_now` argument to + # get_pdu_json doesn't increase the size of the encoded result. + upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE + if not upsert_okay: + container[key] = old_value + else: + container[key] = value + upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE + if not upsert_okay: + del container[key] + + return upsert_okay diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6e90ef259..845f683358 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext +from synapse.events.utils import maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler from synapse.logging import opentracing @@ -1739,12 +1740,15 @@ class EventCreationHandler: if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: - event.unsigned[ - "invite_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, - membership_user_id=event.sender, + maybe_upsert_event_field( + event, + event.unsigned, + "invite_room_state", + await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + membership_user_id=event.sender, + ), ) invitee = UserID.from_string(event.state_key) @@ -1762,11 +1766,14 @@ class EventCreationHandler: event.signatures.update(returned_invite.signatures) if event.content["membership"] == Membership.KNOCK: - event.unsigned[ - "knock_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, + maybe_upsert_event_field( + event, + event.unsigned, + "knock_room_state", + await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + ), ) if event.type == EventTypes.Redaction: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 01e935edef..318fd7dc71 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -16,11 +16,11 @@ import logging import threading import weakref from enum import Enum, auto +from itertools import chain from typing import ( TYPE_CHECKING, Any, Collection, - Container, Dict, Iterable, List, @@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import ( ) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList @@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_types_to_include: Container[str], + state_keys_to_include: StateFilter, membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore): Args: context: The event context to retrieve state of the room from. - state_types_to_include: The type of state events to include. + state_keys_to_include: The state events to include, for each event type. membership_user_id: An optional user ID to include the stripped membership state events of. This is useful when generating the stripped state of a room for invites. We want to send membership events of the inviter, so that the @@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore): Returns: A list of dictionaries, each representing a stripped state event from the room. """ - current_state_ids = await context.get_current_state_ids() + if membership_user_id: + types = chain( + state_keys_to_include.to_types(), + [(EventTypes.Member, membership_user_id)], + ) + filter = StateFilter.from_types(types) + else: + filter = state_keys_to_include + selected_state_ids = await context.get_current_state_ids(filter) # We know this event is not an outlier, so this must be # non-None. - assert current_state_ids is not None - - # The state to include - state_to_include_ids = [ - e_id - for k, e_id in current_state_ids.items() - if k[0] in state_types_to_include - or (membership_user_id and k == (EventTypes.Member, membership_user_id)) - ] + assert selected_state_ids is not None + + # Confusingly, get_current_state_events may return events that are discarded by + # the filter, if they're in context._state_delta_due_to_event. Strip these away. + selected_state_ids = filter.filter_state(selected_state_ids) - state_to_include = await self.get_events(state_to_include_ids) + state_to_include = await self.get_events(selected_state_ids.values()) return [ { diff --git a/synapse/types/state.py b/synapse/types/state.py index 0004d955b4..743a4f9217 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -118,6 +118,15 @@ class StateFilter: ) ) + def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: + """The inverse to `from_types`.""" + for (event_type, state_keys) in self.types.items(): + if state_keys is None: + yield event_type, None + else: + for state_key in state_keys: + yield event_type, state_key + @staticmethod def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member @@ -343,6 +352,15 @@ class StateFilter: for s in state_keys ] + def wildcard_types(self) -> List[str]: + """Returns a list of event types which require us to fetch all state keys. + This will be empty unless `has_wildcards` returns True. + + Returns: + A list of event types. + """ + return [t for t, state_keys in self.types.items() if state_keys is None] + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching diff --git a/tests/config/test_api.py b/tests/config/test_api.py new file mode 100644 index 0000000000..6773c9a277 --- /dev/null +++ b/tests/config/test_api.py @@ -0,0 +1,145 @@ +from unittest import TestCase as StdlibTestCase + +import yaml + +from synapse.config import ConfigError +from synapse.config.api import ApiConfig +from synapse.types.state import StateFilter + +DEFAULT_PREJOIN_STATE_PAIRS = { + ("m.room.join_rules", ""), + ("m.room.canonical_alias", ""), + ("m.room.avatar", ""), + ("m.room.encryption", ""), + ("m.room.name", ""), + ("m.room.create", ""), + ("m.room.topic", ""), +} + + +class TestRoomPrejoinState(StdlibTestCase): + def read_config(self, source: str) -> ApiConfig: + config = ApiConfig() + config.read_config(yaml.safe_load(source)) + return config + + def test_no_prejoin_state(self) -> None: + config = self.read_config("foo: bar") + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS + ) + + def test_disable_default_event_types(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + """ + ) + self.assertEqual(config.room_prejoin_state, StateFilter.none()) + + def test_event_without_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - foo + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + def test_event_with_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + """ + ) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar")}, + ) + + def test_repeated_event_with_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + - [foo, baz] + """ + ) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar"), ("foo", "baz")}, + ) + + def test_no_specific_state_key_overrides_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + - foo + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - foo + - [foo, bar] + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + def test_bad_event_type_entry_raises(self) -> None: + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [a] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [a, b, c] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [true, 1.23] + """ + ) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index b1c47efac7..a79256846f 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest as stdlib_unittest + from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( SerializeEventConfig, copy_and_fixup_power_levels_contents, + maybe_upsert_event_field, prune_event, serialize_event, ) from synapse.util.frozenutils import freeze -from tests import unittest - def MockEvent(**kwargs): if "event_id" not in kwargs: @@ -34,7 +35,31 @@ def MockEvent(**kwargs): return make_event_from_dict(kwargs) -class PruneEventTestCase(unittest.TestCase): +class TestMaybeUpsertEventField(stdlib_unittest.TestCase): + def test_update_okay(self) -> None: + event = make_event_from_dict({"event_id": "$1234"}) + success = maybe_upsert_event_field(event, event.unsigned, "key", "value") + self.assertTrue(success) + self.assertEqual(event.unsigned["key"], "value") + + def test_update_not_okay(self) -> None: + event = make_event_from_dict({"event_id": "$1234"}) + LARGE_STRING = "a" * 100_000 + success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + self.assertFalse(success) + self.assertNotIn("key", event.unsigned) + + def test_update_not_okay_leaves_original_value(self) -> None: + event = make_event_from_dict( + {"event_id": "$1234", "unsigned": {"key": "value"}} + ) + LARGE_STRING = "a" * 100_000 + success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + self.assertFalse(success) + self.assertEqual(event.unsigned["key"], "value") + + +class PruneEventTestCase(stdlib_unittest.TestCase): def run_test(self, evdict, matchdict, **kwargs): """ Asserts that a new event constructed with `evdict` will look like @@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase): ) -class SerializeEventTestCase(unittest.TestCase): +class SerializeEventTestCase(stdlib_unittest.TestCase): def serialize(self, ev, fields): return serialize_event( ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) @@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase): ) -class CopyPowerLevelsContentTestCase(unittest.TestCase): +class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def setUp(self) -> None: self.test_content = { "ban": 50, diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index a433e70870..bad7f0bc60 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID from synapse.types.state import StateFilter from synapse.util import Clock -from tests.unittest import HomeserverTestCase, TestCase +from tests.unittest import HomeserverTestCase logger = logging.getLogger(__name__) @@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - - -class StateFilterDifferenceTestCase(TestCase): - def assert_difference( - self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter - ) -> None: - self.assertEqual( - minuend.approx_difference(subtrahend), - expected, - f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", - ) - - def test_state_filter_difference_no_include_other_minus_no_include_other( - self, - ) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), both a and b do not have the - include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - StateFilter.freeze({EventTypes.Create: None}, include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:spqr"}}, - include_others=False, - ), - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.CanonicalAlias: {""}}, - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), only a has the include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Create: None, - EventTypes.Member: set(), - EventTypes.CanonicalAlias: set(), - }, - include_others=True, - ), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - # This also shows that the resultant state filter is normalised. - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=True), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.Create: {""}, - }, - include_others=False, - ), - StateFilter(types=frozendict(), include_others=True), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter( - types=frozendict(), - include_others=True, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.CanonicalAlias: {""}, - EventTypes.Member: set(), - }, - include_others=True, - ), - ) - - # (specific state keys) - (specific state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - ) - - def test_state_filter_difference_include_other_minus_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), both a and b have the include_others - flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=True, - ), - StateFilter(types=frozendict(), include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=True), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter( - types=frozendict(), - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - EventTypes.Create: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.Create: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.Create: {""}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), only b has the include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=True, - ), - StateFilter(types=frozendict(), include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:spqr"}}, - include_others=True, - ), - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter( - types=frozendict(), - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_simple_cases(self) -> None: - """ - Tests some very simple cases of the StateFilter approx_difference, - that are not explicitly tested by the more in-depth tests. - """ - - self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) - - self.assert_difference( - StateFilter.all(), - StateFilter.none(), - StateFilter.all(), - ) - - -class StateFilterTestCase(TestCase): - def test_return_expanded(self) -> None: - """ - Tests the behaviour of the return_expanded() function that expands - StateFilters to include more state types (for the sake of cache hit rate). - """ - - self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) - - self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) - - # Concrete-only state filters stay the same - # (Case: mixed filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": {""}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": {""}, - }, - include_others=False, - ), - ) - - # Concrete-only state filters stay the same - # (Case: non-member-only filter) - self.assertEqual( - StateFilter.freeze( - {"some.other.state.type": {""}}, include_others=False - ).return_expanded(), - StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), - ) - - # Concrete-only state filters stay the same - # (Case: member-only filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - }, - include_others=False, - ), - ) - - # Wildcard member-only state filters stay the same - self.assertEqual( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # If there is a wildcard in the non-member portion of the filter, - # it's expanded to include ALL non-member events. - # (Case: mixed filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": None, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, - include_others=True, - ), - ) - - # If there is a wildcard in the non-member portion of the filter, - # it's expanded to include ALL non-member events. - # (Case: non-member-only filter) - self.assertEqual( - StateFilter.freeze( - { - "some.other.state.type": None, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze({EventTypes.Member: set()}, include_others=True), - ) - self.assertEqual( - StateFilter.freeze( - { - "some.other.state.type": None, - "yet.another.state.type": {"wombat"}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze({EventTypes.Member: set()}, include_others=True), - ) diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/types/test_state.py b/tests/types/test_state.py new file mode 100644 index 0000000000..eb809f9fb7 --- /dev/null +++ b/tests/types/test_state.py @@ -0,0 +1,627 @@ +from frozendict import frozendict + +from synapse.api.constants import EventTypes +from synapse.types.state import StateFilter + +from tests.unittest import TestCase + + +class StateFilterDifferenceTestCase(TestCase): + def assert_difference( + self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter + ) -> None: + self.assertEqual( + minuend.approx_difference(subtrahend), + expected, + f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", + ) + + def test_state_filter_difference_no_include_other_minus_no_include_other( + self, + ) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b do not have the + include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Create: None}, include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.CanonicalAlias: {""}}, + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only a has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Create: None, + EventTypes.Member: set(), + EventTypes.CanonicalAlias: set(), + }, + include_others=True, + ), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + # This also shows that the resultant state filter is normalised. + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + StateFilter(types=frozendict(), include_others=True), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter( + types=frozendict(), + include_others=True, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.CanonicalAlias: {""}, + EventTypes.Member: set(), + }, + include_others=True, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + def test_state_filter_difference_include_other_minus_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b have the include_others + flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + EventTypes.Create: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only b has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=True, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_simple_cases(self) -> None: + """ + Tests some very simple cases of the StateFilter approx_difference, + that are not explicitly tested by the more in-depth tests. + """ + + self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) + + self.assert_difference( + StateFilter.all(), + StateFilter.none(), + StateFilter.all(), + ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self) -> None: + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) -- cgit 1.5.1 From 62ed877433e23ba055cbc69a089c09d03c67681d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 13 Dec 2022 13:19:19 +0000 Subject: Improve validation of field size limits in events. (#14664) --- changelog.d/14664.bugfix | 1 + stubs/synapse/synapse_rust/push.pyi | 2 +- synapse/api/constants.py | 1 + synapse/api/errors.py | 11 ++++- synapse/api/room_versions.py | 32 +++++++------- synapse/event_auth.py | 76 +++++++++++++++++++++++++++++--- synapse/handlers/federation_event.py | 20 +++++++++ synapse/push/bulk_push_rule_evaluator.py | 6 +-- 8 files changed, 119 insertions(+), 30 deletions(-) create mode 100644 changelog.d/14664.bugfix (limited to 'synapse') diff --git a/changelog.d/14664.bugfix b/changelog.d/14664.bugfix new file mode 100644 index 0000000000..a15df9a89d --- /dev/null +++ b/changelog.d/14664.bugfix @@ -0,0 +1 @@ +Improve validation of field size limits in events. \ No newline at end of file diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index a6a586a0b5..dab5d4aff7 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -45,7 +45,7 @@ class PushRuleEvaluator: notification_power_levels: Mapping[str, int], related_events_flattened: Mapping[str, Mapping[str, str]], related_event_match_enabled: bool, - room_version_feature_flags: list[str], + room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, ): ... def run( diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 89723d24fa..6a5e7171da 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -152,6 +152,7 @@ class EduTypes: class RejectedReason: AUTH_ERROR: Final = "auth_error" + OVERSIZED_EVENT: Final = "oversized_event" class RoomCreationPreset: diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 76ef12ed3a..c2c177fd71 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -424,8 +424,17 @@ class ResourceLimitError(SynapseError): class EventSizeError(SynapseError): """An error raised when an event is too big.""" - def __init__(self, msg: str): + def __init__(self, msg: str, unpersistable: bool): + """ + unpersistable: + if True, the PDU must not be persisted, not even as a rejected PDU + when received over federation. + This is notably true when the entire PDU exceeds the size limit for a PDU, + (as opposed to an individual key's size limit being exceeded). + """ + super().__init__(413, msg, Codes.TOO_LARGE) + self.unpersistable = unpersistable class LoginError(SynapseError): diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index ac62011c9f..c397920fe5 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, Optional, Tuple import attr @@ -103,7 +103,7 @@ class RoomVersion: # is not enough to mark it "supported": the push rule evaluator also needs to # support the flag. Unknown flags are ignored by the evaluator, making conditions # fail if used. - msc3931_push_features: List[str] # values from PushRuleRoomFlag + msc3931_push_features: Tuple[str, ...] # values from PushRuleRoomFlag class RoomVersions: @@ -124,7 +124,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V2 = RoomVersion( "2", @@ -143,7 +143,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V3 = RoomVersion( "3", @@ -162,7 +162,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V4 = RoomVersion( "4", @@ -181,7 +181,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V5 = RoomVersion( "5", @@ -200,7 +200,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V6 = RoomVersion( "6", @@ -219,7 +219,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -238,7 +238,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V7 = RoomVersion( "7", @@ -257,7 +257,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V8 = RoomVersion( "8", @@ -276,7 +276,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V9 = RoomVersion( "9", @@ -295,7 +295,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) MSC3787 = RoomVersion( "org.matrix.msc3787", @@ -314,7 +314,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) V10 = RoomVersion( "10", @@ -333,7 +333,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=True, - msc3931_push_features=[], + msc3931_push_features=(), ) MSC2716v4 = RoomVersion( "org.matrix.msc2716v4", @@ -352,7 +352,7 @@ class RoomVersions: msc2716_redactions=True, msc3787_knock_restricted_join_rule=False, msc3667_int_only_power_levels=False, - msc3931_push_features=[], + msc3931_push_features=(), ) MSC1767v10 = RoomVersion( # MSC1767 (Extensible Events) based on room version "10" @@ -372,7 +372,7 @@ class RoomVersions: msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, msc3667_int_only_power_levels=True, - msc3931_push_features=[PushRuleRoomFlag.EXTENSIBLE_EVENTS], + msc3931_push_features=(PushRuleRoomFlag.EXTENSIBLE_EVENTS,), ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index bab31e33c5..d437b7e5d1 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -52,6 +52,7 @@ from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersion, + RoomVersions, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id @@ -341,19 +342,80 @@ def check_state_dependent_auth_rules( logger.debug("Allowing! %s", event) +# Set of room versions where Synapse did not apply event key size limits +# in bytes, but rather in codepoints. +# In these room versions, we are more lenient with event size validation. +LENIENT_EVENT_BYTE_LIMITS_ROOM_VERSIONS = { + RoomVersions.V1, + RoomVersions.V2, + RoomVersions.V3, + RoomVersions.V4, + RoomVersions.V5, + RoomVersions.V6, + RoomVersions.MSC2176, + RoomVersions.V7, + RoomVersions.V8, + RoomVersions.V9, + RoomVersions.MSC3787, + RoomVersions.V10, + RoomVersions.MSC2716v4, + RoomVersions.MSC1767v10, +} + + def _check_size_limits(event: "EventBase") -> None: + """ + Checks the size limits in a PDU. + + The entire size limit of the PDU is checked first. + Then the size of fields is checked, first in codepoints and then in bytes. + + The codepoint size limits are only for Synapse compatibility. + + Raises: + EventSizeError: + when a size limit has been violated. + + unpersistable=True if Synapse never would have accepted the event and + the PDU must NOT be persisted. + + unpersistable=False if a prior version of Synapse would have accepted the + event and so the PDU must be persisted as rejected to avoid + breaking the room. + """ + + # Whole PDU check + if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: + raise EventSizeError("event too large", unpersistable=True) + + # Codepoint size check: Synapse always enforced these limits, so apply + # them strictly. if len(event.user_id) > 255: - raise EventSizeError("'user_id' too large") + raise EventSizeError("'user_id' too large", unpersistable=True) if len(event.room_id) > 255: - raise EventSizeError("'room_id' too large") + raise EventSizeError("'room_id' too large", unpersistable=True) if event.is_state() and len(event.state_key) > 255: - raise EventSizeError("'state_key' too large") + raise EventSizeError("'state_key' too large", unpersistable=True) if len(event.type) > 255: - raise EventSizeError("'type' too large") + raise EventSizeError("'type' too large", unpersistable=True) if len(event.event_id) > 255: - raise EventSizeError("'event_id' too large") - if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: - raise EventSizeError("event too large") + raise EventSizeError("'event_id' too large", unpersistable=True) + + strict_byte_limits = ( + event.room_version not in LENIENT_EVENT_BYTE_LIMITS_ROOM_VERSIONS + ) + + # Byte size check: if these fail, then be lenient to avoid breaking rooms. + if len(event.user_id.encode("utf-8")) > 255: + raise EventSizeError("'user_id' too large", unpersistable=strict_byte_limits) + if len(event.room_id.encode("utf-8")) > 255: + raise EventSizeError("'room_id' too large", unpersistable=strict_byte_limits) + if event.is_state() and len(event.state_key.encode("utf-8")) > 255: + raise EventSizeError("'state_key' too large", unpersistable=strict_byte_limits) + if len(event.type.encode("utf-8")) > 255: + raise EventSizeError("'type' too large", unpersistable=strict_byte_limits) + if len(event.event_id.encode("utf-8")) > 255: + raise EventSizeError("'event_id' too large", unpersistable=strict_byte_limits) def _check_create(event: "EventBase") -> None: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index d2facdab60..66aca2f864 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -43,6 +43,7 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, + EventSizeError, FederationError, FederationPullAttemptBackoffError, HttpResponseException, @@ -1736,6 +1737,15 @@ class FederationEventHandler: except AuthError as e: logger.warning("Rejecting %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR + except EventSizeError as e: + if e.unpersistable: + # This event is completely unpersistable. + raise e + # Otherwise, we are somewhat lenient and just persist the event + # as rejected, for moderate compatibility with older Synapse + # versions. + logger.warning("While validating received event %r: %s", event, e) + context.rejected = RejectedReason.OVERSIZED_EVENT events_and_contexts_to_persist.append((event, context)) @@ -1781,6 +1791,16 @@ class FederationEventHandler: # TODO: use a different rejected reason here? context.rejected = RejectedReason.AUTH_ERROR return + except EventSizeError as e: + if e.unpersistable: + # This event is completely unpersistable. + raise e + # Otherwise, we are somewhat lenient and just persist the event + # as rejected, for moderate compatibility with older Synapse + # versions. + logger.warning("While validating received event %r: %s", event, e) + context.rejected = RejectedReason.OVERSIZED_EVENT + return # next, check that we have all of the event's auth events. # diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 36e5b327ef..f27ba64d53 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -342,10 +342,6 @@ class BulkPushRuleEvaluator: for user_id, level in notification_levels.items(): notification_levels[user_id] = int(level) - room_version_features = event.room_version.msc3931_push_features - if not room_version_features: - room_version_features = [] - evaluator = PushRuleEvaluator( _flatten_dict(event, room_version=event.room_version), room_member_count, @@ -353,7 +349,7 @@ class BulkPushRuleEvaluator: notification_levels, related_events, self._related_event_match_enabled, - room_version_features, + event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag ) -- cgit 1.5.1 From 2920e540bfd263e33fa25a6f6d642a9f2b965c2f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 13 Dec 2022 08:43:53 -0500 Subject: Use the room type from stats in hierarchy response. (#14263) This avoids pulling additional state information (and events) from the database for each item returned in the hierarchy response. The room type might be out of date until a background update finishes running, the worst impact of this would be spaces being treated as rooms in the hierarchy response. This should self-heal once the background update finishes. --- changelog.d/14263.misc | 1 + synapse/handlers/room_summary.py | 14 +++++--------- 2 files changed, 6 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14263.misc (limited to 'synapse') diff --git a/changelog.d/14263.misc b/changelog.d/14263.misc new file mode 100644 index 0000000000..11d9446a4b --- /dev/null +++ b/changelog.d/14263.misc @@ -0,0 +1 @@ +Improve performance of the `/hierarchy` endpoint. diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 8d08625237..c6b869c6f4 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, import attr from synapse.api.constants import ( - EventContentFields, EventTypes, HistoryVisibility, JoinRules, @@ -701,13 +700,6 @@ class RoomSummaryHandler: # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - current_state_ids = await self._storage_controllers.state.get_current_state_ids( - room_id - ) - create_event = await self._store.get_event( - current_state_ids[(EventTypes.Create, "")] - ) - entry = { "room_id": stats["room_id"], "name": stats["name"], @@ -720,7 +712,7 @@ class RoomSummaryHandler: stats["history_visibility"] == HistoryVisibility.WORLD_READABLE ), "guest_can_join": stats["guest_access"] == "can_join", - "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), + "room_type": stats["room_type"], } if self._msc3266_enabled: @@ -730,7 +722,11 @@ class RoomSummaryHandler: # Federation requests need to provide additional information so the # requested server is able to filter the response appropriately. if for_federation: + current_state_ids = ( + await self._storage_controllers.state.get_current_state_ids(room_id) + ) room_version = await self._store.get_room_version(room_id) + if await self._event_auth_handler.has_restricted_join_rules( current_state_ids, room_version ): -- cgit 1.5.1 From e512b25cd1618941d165b37f0518ec5765a3b23d Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Wed, 14 Dec 2022 07:02:28 -0500 Subject: Fix #11308 : Remove dependency on jquery on reCAPTCHA page (#14672) --- changelog.d/14672.misc | 1 + synapse/res/templates/recaptcha.html | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14672.misc (limited to 'synapse') diff --git a/changelog.d/14672.misc b/changelog.d/14672.misc new file mode 100644 index 0000000000..b94ebed971 --- /dev/null +++ b/changelog.d/14672.misc @@ -0,0 +1 @@ +Remove dependency on jQuery on reCAPTCHA page. diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html index 8204928cdf..f00992a24b 100644 --- a/synapse/res/templates/recaptcha.html +++ b/synapse/res/templates/recaptcha.html @@ -3,11 +3,10 @@ {% block header %} - {% endblock %} -- cgit 1.5.1 From 24a97b3e7144720545df69c321e320c9d35166a6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Dec 2022 09:25:33 -0500 Subject: Delete event_push_summary_unique_index again. (#14669) if a Synapse deployment upgraded (from < 1.62.0 to >= 1.70.0) then it is possible for schema deltas to run before background updates causing drift in the database schema due to: 1. A delta registered a background update to create an index. 2. A delta dropped the above index if it exists (but it yet exist won't since the background job hasn't run). 3. The code assumed the index was dropped. To fix this we: 1. Cancel the background update which could create the index. 2. Drop the index again. 3. Drop a related index which is dropped by the background update. --- changelog.d/14669.bugfix | 1 + .../storage/databases/main/event_push_actions.py | 9 ------ .../schema/main/delta/73/23_fix_thread_index.sql | 33 ++++++++++++++++++++++ 3 files changed, 34 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14669.bugfix create mode 100644 synapse/storage/schema/main/delta/73/23_fix_thread_index.sql (limited to 'synapse') diff --git a/changelog.d/14669.bugfix b/changelog.d/14669.bugfix new file mode 100644 index 0000000000..bea316b065 --- /dev/null +++ b/changelog.d/14669.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 which could cause spurious `UNIQUE constraint failed` errors in the `rotate_notifs` background job. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7ebe34f773..3a0c370fde 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -274,15 +274,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self._clear_old_push_actions_staging, 30 * 60 * 1000 ) - self.db_pool.updates.register_background_index_update( - "event_push_summary_unique_index", - index_name="event_push_summary_unique_index", - table="event_push_summary", - columns=["user_id", "room_id"], - unique=True, - replaces_index="event_push_summary_user_rm", - ) - self.db_pool.updates.register_background_index_update( "event_push_summary_unique_index2", index_name="event_push_summary_unique_index2", diff --git a/synapse/storage/schema/main/delta/73/23_fix_thread_index.sql b/synapse/storage/schema/main/delta/73/23_fix_thread_index.sql new file mode 100644 index 0000000000..ec519ceebf --- /dev/null +++ b/synapse/storage/schema/main/delta/73/23_fix_thread_index.sql @@ -0,0 +1,33 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- If a Synapse deployment made a large jump in versions (from < 1.62.0 to >= 1.70.0) +-- in a single upgrade then it might be possible for the event_push_summary_unique_index +-- to be created in the background from delta 71/02event_push_summary_unique.sql after +-- delta 73/06thread_notifications_thread_id_idx.sql is executed, causing it to +-- not drop the event_push_summary_unique_index index. +-- +-- See https://github.com/matrix-org/synapse/issues/14641 + +-- Stop the index from being scheduled for creation in the background. +DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index'; + +-- The above background job also replaces another index, so ensure that side-effect +-- is applied. +DROP INDEX IF EXISTS event_push_summary_user_rm; + +-- Fix deployments which ran the 73/06thread_notifications_thread_id_idx.sql delta +-- before the event_push_summary_unique_index background job was run. +DROP INDEX IF EXISTS event_push_summary_unique_index; -- cgit 1.5.1 From fb60cb16fe3cf26fbd947eec926cb4b24b8e9fc7 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 14 Dec 2022 14:47:11 +0000 Subject: Faster remote room joins: stream the un-partial-stating of events over replication. [rei:frrj/streams/unpsr] (#14545) --- changelog.d/14545.misc | 1 + synapse/handlers/federation_event.py | 2 + synapse/replication/tcp/streams/__init__.py | 7 +- synapse/replication/tcp/streams/partial_state.py | 28 +++++++ synapse/storage/databases/main/events_worker.py | 88 ++++++++++++++++++++++ synapse/storage/databases/main/state.py | 34 ++++++--- .../delta/73/22_un_partial_stated_event_stream.sql | 34 +++++++++ ..._un_partial_stated_room_stream_seq.sql.postgres | 20 +++++ 8 files changed, 204 insertions(+), 10 deletions(-) create mode 100644 changelog.d/14545.misc create mode 100644 synapse/storage/schema/main/delta/73/22_un_partial_stated_event_stream.sql create mode 100644 synapse/storage/schema/main/delta/73/23_un_partial_stated_room_stream_seq.sql.postgres (limited to 'synapse') diff --git a/changelog.d/14545.misc b/changelog.d/14545.misc new file mode 100644 index 0000000000..60b6761a51 --- /dev/null +++ b/changelog.d/14545.misc @@ -0,0 +1 @@ +Faster remote room joins: stream the un-partial-stating of events over replication. \ No newline at end of file diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 66aca2f864..31df7f55cc 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -610,6 +610,8 @@ class FederationEventHandler: self._state_storage_controller.notify_event_un_partial_stated( event.event_id ) + # Notify that there's a new row in the un_partial_stated_events stream. + self._notifier.notify_replication() @trace async def backfill( diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 8575666d9c..110f10aab9 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -42,7 +42,10 @@ from synapse.replication.tcp.streams._base import ( ) from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.federation import FederationStream -from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream +from synapse.replication.tcp.streams.partial_state import ( + UnPartialStatedEventStream, + UnPartialStatedRoomStream, +) STREAMS_MAP = { stream.NAME: stream @@ -63,6 +66,7 @@ STREAMS_MAP = { AccountDataStream, UserSignatureStream, UnPartialStatedRoomStream, + UnPartialStatedEventStream, ) } @@ -83,4 +87,5 @@ __all__ = [ "AccountDataStream", "UserSignatureStream", "UnPartialStatedRoomStream", + "UnPartialStatedEventStream", ] diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py index 18f087ffa2..b5a2ae74b6 100644 --- a/synapse/replication/tcp/streams/partial_state.py +++ b/synapse/replication/tcp/streams/partial_state.py @@ -46,3 +46,31 @@ class UnPartialStatedRoomStream(Stream): current_token_without_instance(store.get_un_partial_stated_rooms_token), store.get_un_partial_stated_rooms_from_stream, ) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UnPartialStatedEventStreamRow: + # ID of the event that has been un-partial-stated. + event_id: str + + # True iff the rejection status of the event changed as a result of being + # un-partial-stated. + rejection_status_changed: bool + + +class UnPartialStatedEventStream(Stream): + """ + Stream to notify about events becoming un-partial-stated. + """ + + NAME = "un_partial_stated_event" + ROW_TYPE = UnPartialStatedEventStreamRow + + def __init__(self, hs: "HomeServer"): + store = hs.get_datastores().main + super().__init__( + hs.get_instance_name(), + # TODO(faster_joins, multiple writers): we need to account for instance names + current_token_without_instance(store.get_un_partial_stated_events_token), + store.get_un_partial_stated_events_from_stream, + ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 318fd7dc71..e19b16064b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -70,6 +70,7 @@ from synapse.storage.database import ( from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, AbstractStreamIdTracker, MultiWriterIdGenerator, StreamIdGenerator, @@ -292,6 +293,93 @@ class EventsWorkerStore(SQLBaseStore): id_column="chain_id", ) + self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator + + if isinstance(database.engine, PostgresEngine): + self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="un_partial_stated_event_stream", + instance_name=hs.get_instance_name(), + tables=[ + ("un_partial_stated_event_stream", "instance_name", "stream_id") + ], + sequence_name="un_partial_stated_event_stream_sequence", + # TODO(faster_joins, multiple writers) Support multiple writers. + writers=["master"], + ) + else: + self._un_partial_stated_events_stream_id_gen = StreamIdGenerator( + db_conn, "un_partial_stated_event_stream", "stream_id" + ) + + def get_un_partial_stated_events_token(self) -> int: + # TODO(faster_joins, multiple writers): This is inappropriate if there are multiple + # writers because workers that don't write often will hold all + # readers up. + return self._un_partial_stated_events_stream_id_gen.get_current_token() + + async def get_un_partial_stated_events_from_stream( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, Tuple[str, bool]]], int, bool]: + """Get updates for the un-partial-stated events replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_un_partial_stated_events_from_stream_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, Tuple[str, bool]]], int, bool]: + sql = """ + SELECT stream_id, event_id, rejection_status_changed + FROM un_partial_stated_event_stream + WHERE ? < stream_id AND stream_id <= ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, instance_name, limit)) + updates = [ + ( + row[0], + ( + row[1], + bool(row[2]), + ), + ) + for row in txn + ] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_un_partial_stated_events_from_stream", + get_un_partial_stated_events_from_stream_txn, + ) + def process_replication_rows( self, stream_name: str, diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index c801a93b5b..f855903c39 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -80,6 +80,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) + self._instance_name: str = hs.get_instance_name() async def get_room_version(self, room_id: str) -> RoomVersion: """Get the room_version of a given room @@ -404,18 +405,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): context: EventContext, ) -> None: """Update the state group for a partial state event""" - await self.db_pool.runInteraction( - "update_state_for_partial_state_event", - self._update_state_for_partial_state_event_txn, - event, - context, - ) + async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id: + await self.db_pool.runInteraction( + "update_state_for_partial_state_event", + self._update_state_for_partial_state_event_txn, + event, + context, + un_partial_state_event_stream_id, + ) def _update_state_for_partial_state_event_txn( self, txn: LoggingTransaction, event: EventBase, context: EventContext, + un_partial_state_event_stream_id: int, ) -> None: # we shouldn't have any outliers here assert not event.internal_metadata.is_outlier() @@ -436,7 +440,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # the event may now be rejected where it was not before, or vice versa, # in which case we need to update the rejected flags. - if bool(context.rejected) != (event.rejected_reason is not None): + rejection_status_changed = bool(context.rejected) != ( + event.rejected_reason is not None + ) + if rejection_status_changed: self.mark_event_rejected_txn(txn, event.event_id, context.rejected) self.db_pool.simple_delete_one_txn( @@ -445,8 +452,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): keyvalues={"event_id": event.event_id}, ) - # TODO(faster_joins): need to do something about workers here - # https://github.com/matrix-org/synapse/issues/12994 txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,)) txn.call_after( self._get_state_group_for_event.prefill, @@ -454,6 +459,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): state_group, ) + self.db_pool.simple_insert_txn( + txn, + "un_partial_stated_event_stream", + { + "stream_id": un_partial_state_event_stream_id, + "instance_name": self._instance_name, + "event_id": event.event_id, + "rejection_status_changed": rejection_status_changed, + }, + ) + class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): diff --git a/synapse/storage/schema/main/delta/73/22_un_partial_stated_event_stream.sql b/synapse/storage/schema/main/delta/73/22_un_partial_stated_event_stream.sql new file mode 100644 index 0000000000..0e571f78c3 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/22_un_partial_stated_event_stream.sql @@ -0,0 +1,34 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stream for notifying that an event has become un-partial-stated. +CREATE TABLE un_partial_stated_event_stream( + -- Position in the stream + stream_id BIGINT PRIMARY KEY NOT NULL, + + -- Which instance wrote this entry. + instance_name TEXT NOT NULL, + + -- Which event has been un-partial-stated. + event_id TEXT NOT NULL REFERENCES events(event_id) ON DELETE CASCADE, + + -- true iff the `rejected` status of the event changed when it became + -- un-partial-stated. + rejection_status_changed BOOLEAN NOT NULL +); + +-- We want an index here because of the foreign key constraint: +-- upon deleting an event, the database needs to be able to check here. +CREATE UNIQUE INDEX un_partial_stated_event_stream_room_id ON un_partial_stated_event_stream (event_id); diff --git a/synapse/storage/schema/main/delta/73/23_un_partial_stated_room_stream_seq.sql.postgres b/synapse/storage/schema/main/delta/73/23_un_partial_stated_room_stream_seq.sql.postgres new file mode 100644 index 0000000000..1ec24702f3 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/23_un_partial_stated_room_stream_seq.sql.postgres @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS un_partial_stated_event_stream_sequence; + +SELECT setval('un_partial_stated_event_stream_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM un_partial_stated_event_stream +)); -- cgit 1.5.1 From 4f4d69042345134c040de137a8e1aa108ff71acb Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 14 Dec 2022 14:52:35 +0000 Subject: Allow `compute_state_after_events` to use partial state (#14676) * Allow `compute_state_after_events` to use partial state if fetching a subset of state that is trusted during a partial join. * Changelog --- changelog.d/14676.misc | 1 + synapse/state/__init__.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14676.misc (limited to 'synapse') diff --git a/changelog.d/14676.misc b/changelog.d/14676.misc new file mode 100644 index 0000000000..8a41df9c64 --- /dev/null +++ b/changelog.d/14676.misc @@ -0,0 +1 @@ +Faster joins: make `computer_state_after_events` consistent with other state-fetching functions that take a `StateFilter`. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index ee5469d5a8..fdfb46ab82 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -202,14 +202,20 @@ class StateHandler: room_id: the room_id containing the given events. event_ids: the events whose state should be fetched and resolved. await_full_state: if `True`, will block if we do not yet have complete state - at the given `event_id`s, regardless of whether `state_filter` is - satisfied by partial state. + at these events and `state_filter` is not satisfied by partial state. + Defaults to `True`. Returns: the state dict (a mapping from (event_type, state_key) -> event_id) which holds the resolution of the states after the given event IDs. """ logger.debug("calling resolve_state_groups from compute_state_after_events") + if ( + await_full_state + and state_filter + and not state_filter.must_await_full_state(self.hs.is_mine_id) + ): + await_full_state = False ret = await self.resolve_state_groups_for_events( room_id, event_ids, await_full_state ) -- cgit 1.5.1 From 54c012c5a8722725cf104fa6205f253b5b9b0192 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 15 Dec 2022 17:04:23 +0100 Subject: Make `handle_new_client_event` throws `PartialStateConflictError` (#14665) Then adapts calling code to retry when needed so it doesn't 500 to clients. Signed-off-by: Mathieu Velten Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14665.misc | 1 + synapse/handlers/federation.py | 117 +++++++++++++------- synapse/handlers/message.py | 202 ++++++++++++++++++---------------- synapse/handlers/room.py | 95 +++++++++------- synapse/handlers/room_batch.py | 2 + synapse/handlers/room_member.py | 168 +++++++++++++++++----------- synapse/util/caches/response_cache.py | 14 ++- 7 files changed, 360 insertions(+), 239 deletions(-) create mode 100644 changelog.d/14665.misc (limited to 'synapse') diff --git a/changelog.d/14665.misc b/changelog.d/14665.misc new file mode 100644 index 0000000000..2b7c96143d --- /dev/null +++ b/changelog.d/14665.misc @@ -0,0 +1 @@ +Change `handle_new_client_event` signature so that a 429 does not reach clients on `PartialStateConflictError`, and internally retry when needed instead. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b2784d7333..eca75f1108 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1343,32 +1343,53 @@ class FederationHandler: ) EventValidator().validate_builder(builder) - event, context = await self.event_creation_handler.create_new_client_event( - builder=builder - ) - event, context = await self.add_display_name_to_third_party_invite( - room_version_obj, event_dict, event, context - ) + # Try several times, it could fail with PartialStateConflictError + # in send_membership_event, cf comment in except block. + max_retries = 5 + for i in range(max_retries): + try: + ( + event, + context, + ) = await self.event_creation_handler.create_new_client_event( + builder=builder + ) - EventValidator().validate_new(event, self.config) + event, context = await self.add_display_name_to_third_party_invite( + room_version_obj, event_dict, event, context + ) - # We need to tell the transaction queue to send this out, even - # though the sender isn't a local user. - event.internal_metadata.send_on_behalf_of = self.hs.hostname + EventValidator().validate_new(event, self.config) - try: - validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context(event) - except AuthError as e: - logger.warning("Denying new third party invite %r because %s", event, e) - raise e + # We need to tell the transaction queue to send this out, even + # though the sender isn't a local user. + event.internal_metadata.send_on_behalf_of = self.hs.hostname - await self._check_signature(event, context) + try: + validate_event_for_room_version(event) + await self._event_auth_handler.check_auth_rules_from_context( + event + ) + except AuthError as e: + logger.warning( + "Denying new third party invite %r because %s", event, e + ) + raise e - # We retrieve the room member handler here as to not cause a cyclic dependency - member_handler = self.hs.get_room_member_handler() - await member_handler.send_membership_event(None, event, context) + await self._check_signature(event, context) + + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + await member_handler.send_membership_event(None, event, context) + + break + except PartialStateConflictError as e: + # Persisting couldn't happen because the room got un-partial stated + # in the meantime and context needs to be recomputed, so let's do so. + if i == max_retries - 1: + raise e + pass else: destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} @@ -1400,28 +1421,46 @@ class FederationHandler: room_version_obj, event_dict ) - event, context = await self.event_creation_handler.create_new_client_event( - builder=builder - ) - event, context = await self.add_display_name_to_third_party_invite( - room_version_obj, event_dict, event, context - ) + # Try several times, it could fail with PartialStateConflictError + # in send_membership_event, cf comment in except block. + max_retries = 5 + for i in range(max_retries): + try: + ( + event, + context, + ) = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + event, context = await self.add_display_name_to_third_party_invite( + room_version_obj, event_dict, event, context + ) - try: - validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context(event) - except AuthError as e: - logger.warning("Denying third party invite %r because %s", event, e) - raise e - await self._check_signature(event, context) + try: + validate_event_for_room_version(event) + await self._event_auth_handler.check_auth_rules_from_context(event) + except AuthError as e: + logger.warning("Denying third party invite %r because %s", event, e) + raise e + await self._check_signature(event, context) + + # We need to tell the transaction queue to send this out, even + # though the sender isn't a local user. + event.internal_metadata.send_on_behalf_of = get_domain_from_id( + event.sender + ) - # We need to tell the transaction queue to send this out, even - # though the sender isn't a local user. - event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender) + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + await member_handler.send_membership_event(None, event, context) - # We retrieve the room member handler here as to not cause a cyclic dependency - member_handler = self.hs.get_room_member_handler() - await member_handler.send_membership_event(None, event, context) + break + except PartialStateConflictError as e: + # Persisting couldn't happen because the room got un-partial stated + # in the meantime and context needs to be recomputed, so let's do so. + if i == max_retries - 1: + raise e + pass async def add_display_name_to_third_party_invite( self, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 845f683358..88fc51a4c9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -37,7 +37,6 @@ from synapse.api.errors import ( AuthError, Codes, ConsentNotGivenError, - LimitExceededError, NotFoundError, ShadowBanError, SynapseError, @@ -999,60 +998,73 @@ class EventCreationHandler: event.internal_metadata.stream_ordering, ) - event, context = await self.create_event( - requester, - event_dict, - txn_id=txn_id, - allow_no_prev_events=allow_no_prev_events, - prev_event_ids=prev_event_ids, - state_event_ids=state_event_ids, - outlier=outlier, - historical=historical, - depth=depth, - ) + # Try several times, it could fail with PartialStateConflictError + # in handle_new_client_event, cf comment in except block. + max_retries = 5 + for i in range(max_retries): + try: + event, context = await self.create_event( + requester, + event_dict, + txn_id=txn_id, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + state_event_ids=state_event_ids, + outlier=outlier, + historical=historical, + depth=depth, + ) - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, + ) - spam_check_result = await self.spam_checker.check_event_for_spam(event) - if spam_check_result != self.spam_checker.NOT_SPAM: - if isinstance(spam_check_result, tuple): - try: - [code, dict] = spam_check_result - raise SynapseError( - 403, - "This message had been rejected as probable spam", - code, - dict, - ) - except ValueError: - logger.error( - "Spam-check module returned invalid error value. Expecting [code, dict], got %s", - spam_check_result, - ) + spam_check_result = await self.spam_checker.check_event_for_spam(event) + if spam_check_result != self.spam_checker.NOT_SPAM: + if isinstance(spam_check_result, tuple): + try: + [code, dict] = spam_check_result + raise SynapseError( + 403, + "This message had been rejected as probable spam", + code, + dict, + ) + except ValueError: + logger.error( + "Spam-check module returned invalid error value. Expecting [code, dict], got %s", + spam_check_result, + ) - raise SynapseError( - 403, - "This message has been rejected as probable spam", - Codes.FORBIDDEN, - ) + raise SynapseError( + 403, + "This message has been rejected as probable spam", + Codes.FORBIDDEN, + ) - # Backwards compatibility: if the return value is not an error code, it - # means the module returned an error message to be included in the - # SynapseError (which is now deprecated). - raise SynapseError( - 403, - spam_check_result, - Codes.FORBIDDEN, + # Backwards compatibility: if the return value is not an error code, it + # means the module returned an error message to be included in the + # SynapseError (which is now deprecated). + raise SynapseError( + 403, + spam_check_result, + Codes.FORBIDDEN, + ) + + ev = await self.handle_new_client_event( + requester=requester, + events_and_context=[(event, context)], + ratelimit=ratelimit, + ignore_shadow_ban=ignore_shadow_ban, ) - ev = await self.handle_new_client_event( - requester=requester, - events_and_context=[(event, context)], - ratelimit=ratelimit, - ignore_shadow_ban=ignore_shadow_ban, - ) + break + except PartialStateConflictError as e: + # Persisting couldn't happen because the room got un-partial stated + # in the meantime and context needs to be recomputed, so let's do so. + if i == max_retries - 1: + raise e + pass # we know it was persisted, so must have a stream ordering assert ev.internal_metadata.stream_ordering @@ -1356,7 +1368,7 @@ class EventCreationHandler: Raises: ShadowBanError if the requester has been shadow-banned. - SynapseError(503) if attempting to persist a partial state event in + PartialStateConflictError if attempting to persist a partial state event in a room that has been un-partial stated. """ extra_users = extra_users or [] @@ -1418,34 +1430,23 @@ class EventCreationHandler: # We now persist the event (and update the cache in parallel, since we # don't want to block on it). event, context = events_and_context[0] - try: - result, _ = await make_deferred_yieldable( - gather_results( - ( - run_in_background( - self._persist_events, - requester=requester, - events_and_context=events_and_context, - ratelimit=ratelimit, - extra_users=extra_users, - ), - run_in_background( - self.cache_joined_hosts_for_events, events_and_context - ).addErrback( - log_failure, "cache_joined_hosts_for_event failed" - ), + result, _ = await make_deferred_yieldable( + gather_results( + ( + run_in_background( + self._persist_events, + requester=requester, + events_and_context=events_and_context, + ratelimit=ratelimit, + extra_users=extra_users, ), - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) - except PartialStateConflictError as e: - # The event context needs to be recomputed. - # Turn the error into a 429, as a hint to the client to try again. - logger.info( - "Room %s was un-partial stated while persisting client event.", - event.room_id, + run_in_background( + self.cache_joined_hosts_for_events, events_and_context + ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), + ), + consumeErrors=True, ) - raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0) + ).addErrback(unwrapFirstError) return result @@ -2012,26 +2013,39 @@ class EventCreationHandler: for user_id in members: requester = create_requester(user_id, authenticated_entity=self.server_name) try: - event, context = await self.create_event( - requester, - { - "type": EventTypes.Dummy, - "content": {}, - "room_id": room_id, - "sender": user_id, - }, - ) + # Try several times, it could fail with PartialStateConflictError + # in handle_new_client_event, cf comment in except block. + max_retries = 5 + for i in range(max_retries): + try: + event, context = await self.create_event( + requester, + { + "type": EventTypes.Dummy, + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + ) - event.internal_metadata.proactively_send = False + event.internal_metadata.proactively_send = False - # Since this is a dummy-event it is OK if it is sent by a - # shadow-banned user. - await self.handle_new_client_event( - requester, - events_and_context=[(event, context)], - ratelimit=False, - ignore_shadow_ban=True, - ) + # Since this is a dummy-event it is OK if it is sent by a + # shadow-banned user. + await self.handle_new_client_event( + requester, + events_and_context=[(event, context)], + ratelimit=False, + ignore_shadow_ban=True, + ) + + break + except PartialStateConflictError as e: + # Persisting couldn't happen because the room got un-partial stated + # in the meantime and context needs to be recomputed, so let's do so. + if i == max_retries - 1: + raise e + pass return True except AuthError: logger.info( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f81241c2b3..572c7b4db3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -62,6 +62,7 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin +from synapse.storage.databases.main.events import PartialStateConflictError from synapse.streams import EventSource from synapse.types import ( JsonDict, @@ -207,46 +208,64 @@ class RoomCreationHandler: new_room_id = self._generate_room_id() - # Check whether the user has the power level to carry out the upgrade. - # `check_auth_rules_from_context` will check that they are in the room and have - # the required power level to send the tombstone event. - ( - tombstone_event, - tombstone_context, - ) = await self.event_creation_handler.create_event( - requester, - { - "type": EventTypes.Tombstone, - "state_key": "", - "room_id": old_room_id, - "sender": user_id, - "content": { - "body": "This room has been replaced", - "replacement_room": new_room_id, - }, - }, - ) - validate_event_for_room_version(tombstone_event) - await self._event_auth_handler.check_auth_rules_from_context(tombstone_event) + # Try several times, it could fail with PartialStateConflictError + # in _upgrade_room, cf comment in except block. + max_retries = 5 + for i in range(max_retries): + try: + # Check whether the user has the power level to carry out the upgrade. + # `check_auth_rules_from_context` will check that they are in the room and have + # the required power level to send the tombstone event. + ( + tombstone_event, + tombstone_context, + ) = await self.event_creation_handler.create_event( + requester, + { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, + }, + }, + ) + validate_event_for_room_version(tombstone_event) + await self._event_auth_handler.check_auth_rules_from_context( + tombstone_event + ) - # Upgrade the room - # - # If this user has sent multiple upgrade requests for the same room - # and one of them is not complete yet, cache the response and - # return it to all subsequent requests - ret = await self._upgrade_response_cache.wrap( - (old_room_id, user_id), - self._upgrade_room, - requester, - old_room_id, - old_room, # args for _upgrade_room - new_room_id, - new_version, - tombstone_event, - tombstone_context, - ) + # Upgrade the room + # + # If this user has sent multiple upgrade requests for the same room + # and one of them is not complete yet, cache the response and + # return it to all subsequent requests + ret = await self._upgrade_response_cache.wrap( + (old_room_id, user_id), + self._upgrade_room, + requester, + old_room_id, + old_room, # args for _upgrade_room + new_room_id, + new_version, + tombstone_event, + tombstone_context, + ) - return ret + return ret + except PartialStateConflictError as e: + # Clean up the cache so we can retry properly + self._upgrade_response_cache.unset((old_room_id, user_id)) + # Persisting couldn't happen because the room got un-partial stated + # in the meantime and context needs to be recomputed, so let's do so. + if i == max_retries - 1: + raise e + pass + + # This is to satisfy mypy and should never happen + raise PartialStateConflictError() async def _upgrade_room( self, diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 411a6fb22f..c73d2adaad 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -375,6 +375,8 @@ class RoomBatchHandler: # Events are sorted by (topological_ordering, stream_ordering) # where topological_ordering is just depth. for (event, context) in reversed(events_to_persist): + # This call can't raise `PartialStateConflictError` since we forbid + # use of the historical batch API during partial state await self.event_creation_handler.handle_new_client_event( await self.create_requester_for_user_id_from_app_service( event.sender, app_service_requester.app_service diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0c39e852a1..d236cc09b5 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -34,6 +34,7 @@ from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.logging import opentracing from synapse.module_api import NOT_SPAM +from synapse.storage.databases.main.events import PartialStateConflictError from synapse.types import ( JsonDict, Requester, @@ -392,60 +393,81 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): event_pos = await self.store.get_position_for_event(existing_event_id) return existing_event_id, event_pos.stream - event, context = await self.event_creation_handler.create_event( - requester, - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "state_key": user_id, - # For backwards compatibility: - "membership": membership, - "origin_server_ts": origin_server_ts, - }, - txn_id=txn_id, - allow_no_prev_events=allow_no_prev_events, - prev_event_ids=prev_event_ids, - state_event_ids=state_event_ids, - depth=depth, - require_consent=require_consent, - outlier=outlier, - historical=historical, - ) - - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) - ) + # Try several times, it could fail with PartialStateConflictError, + # in handle_new_client_event, cf comment in except block. + max_retries = 5 + for i in range(max_retries): + try: + event, context = await self.event_creation_handler.create_event( + requester, + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": user_id, + # For backwards compatibility: + "membership": membership, + "origin_server_ts": origin_server_ts, + }, + txn_id=txn_id, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + state_event_ids=state_event_ids, + depth=depth, + require_consent=require_consent, + outlier=outlier, + historical=historical, + ) - prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types([(EventTypes.Member, None)]) + ) - if event.membership == Membership.JOIN: - newly_joined = True - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - newly_joined = prev_member_event.membership != Membership.JOIN - - # Only rate-limit if the user actually joined the room, otherwise we'll end - # up blocking profile updates. - if newly_joined and ratelimit: - await self._join_rate_limiter_local.ratelimit(requester) - await self._join_rate_per_room_limiter.ratelimit( - requester, key=room_id, update=False + prev_member_event_id = prev_state_ids.get( + (EventTypes.Member, user_id), None ) - with opentracing.start_active_span("handle_new_client_event"): - result_event = await self.event_creation_handler.handle_new_client_event( - requester, - events_and_context=[(event, context)], - extra_users=[target], - ratelimit=ratelimit, - ) - if event.membership == Membership.LEAVE: - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - if prev_member_event.membership == Membership.JOIN: - await self._user_left_room(target, room_id) + if event.membership == Membership.JOIN: + newly_joined = True + if prev_member_event_id: + prev_member_event = await self.store.get_event( + prev_member_event_id + ) + newly_joined = prev_member_event.membership != Membership.JOIN + + # Only rate-limit if the user actually joined the room, otherwise we'll end + # up blocking profile updates. + if newly_joined and ratelimit: + await self._join_rate_limiter_local.ratelimit(requester) + await self._join_rate_per_room_limiter.ratelimit( + requester, key=room_id, update=False + ) + with opentracing.start_active_span("handle_new_client_event"): + result_event = ( + await self.event_creation_handler.handle_new_client_event( + requester, + events_and_context=[(event, context)], + extra_users=[target], + ratelimit=ratelimit, + ) + ) + + if event.membership == Membership.LEAVE: + if prev_member_event_id: + prev_member_event = await self.store.get_event( + prev_member_event_id + ) + if prev_member_event.membership == Membership.JOIN: + await self._user_left_room(target, room_id) + + break + except PartialStateConflictError as e: + # Persisting couldn't happen because the room got un-partial stated + # in the meantime and context needs to be recomputed, so let's do so. + if i == max_retries - 1: + raise e + pass # we know it was persisted, so should have a stream ordering assert result_event.internal_metadata.stream_ordering @@ -1234,6 +1256,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit: Whether to rate limit this request. Raises: SynapseError if there was a problem changing the membership. + PartialStateConflictError: if attempting to persist a partial state event in + a room that has been un-partial stated. """ target_user = UserID.from_string(event.state_key) room_id = event.room_id @@ -1863,21 +1887,37 @@ class RoomMemberMasterHandler(RoomMemberHandler): list(previous_membership_event.auth_event_ids()) + prev_event_ids ) - event, context = await self.event_creation_handler.create_event( - requester, - event_dict, - txn_id=txn_id, - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - outlier=True, - ) - event.internal_metadata.out_of_band_membership = True + # Try several times, it could fail with PartialStateConflictError + # in handle_new_client_event, cf comment in except block. + max_retries = 5 + for i in range(max_retries): + try: + event, context = await self.event_creation_handler.create_event( + requester, + event_dict, + txn_id=txn_id, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + outlier=True, + ) + event.internal_metadata.out_of_band_membership = True + + result_event = ( + await self.event_creation_handler.handle_new_client_event( + requester, + events_and_context=[(event, context)], + extra_users=[UserID.from_string(target_user)], + ) + ) + + break + except PartialStateConflictError as e: + # Persisting couldn't happen because the room got un-partial stated + # in the meantime and context needs to be recomputed, so let's do so. + if i == max_retries - 1: + raise e + pass - result_event = await self.event_creation_handler.handle_new_client_event( - requester, - events_and_context=[(event, context)], - extra_users=[UserID.from_string(target_user)], - ) # we know it was persisted, so must have a stream ordering assert result_event.internal_metadata.stream_ordering diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index a3eb5f741b..340e5e9145 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -167,12 +167,10 @@ class ResponseCache(Generic[KV]): # the should_cache bit, we leave it in the cache for now and schedule # its removal later. if self.timeout_sec and context.should_cache: - self.clock.call_later( - self.timeout_sec, self._result_cache.pop, key, None - ) + self.clock.call_later(self.timeout_sec, self.unset, key) else: # otherwise, remove the result immediately. - self._result_cache.pop(key, None) + self.unset(key) return r # make sure we do this *after* adding the entry to result_cache, @@ -181,6 +179,14 @@ class ResponseCache(Generic[KV]): result.addBoth(on_complete) return entry + def unset(self, key: KV) -> None: + """Remove the cached value for this key from the cache, if any. + + Args: + key: key used to remove the cached value + """ + self._result_cache.pop(key, None) + async def wrap( self, key: KV, -- cgit 1.5.1 From 652d1669c5a103b1c20478770c4aaf18849c09a3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 16 Dec 2022 06:53:01 -0500 Subject: Add missing type hints to tests.handlers. (#14680) And do not allow untyped defs in tests.handlers. --- changelog.d/14680.misc | 1 + mypy.ini | 5 +- synapse/handlers/auth.py | 2 +- tests/handlers/test_appservice.py | 54 +++++----- tests/handlers/test_cas.py | 2 +- tests/handlers/test_directory.py | 27 ++--- tests/handlers/test_e2e_room_keys.py | 76 ++++++++------ tests/handlers/test_federation.py | 2 +- tests/handlers/test_federation_event.py | 10 +- tests/handlers/test_message.py | 26 +++-- tests/handlers/test_oidc.py | 48 ++++++--- tests/handlers/test_password_providers.py | 144 ++++++++++++------------- tests/handlers/test_presence.py | 100 ++++++++++-------- tests/handlers/test_profile.py | 4 +- tests/handlers/test_receipts.py | 6 +- tests/handlers/test_register.py | 169 +++++++++++++++++------------- tests/handlers/test_room.py | 6 +- tests/handlers/test_room_summary.py | 76 ++++++++------ tests/handlers/test_saml.py | 33 ++++-- tests/handlers/test_send_email.py | 29 +++-- tests/handlers/test_stats.py | 74 +++++++++---- tests/handlers/test_sync.py | 11 +- 22 files changed, 527 insertions(+), 378 deletions(-) create mode 100644 changelog.d/14680.misc (limited to 'synapse') diff --git a/changelog.d/14680.misc b/changelog.d/14680.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14680.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 37acf589c9..1a37414e58 100644 --- a/mypy.ini +++ b/mypy.ini @@ -95,10 +95,7 @@ disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True -[mypy-tests.handlers.test_sso] -disallow_untyped_defs = True - -[mypy-tests.handlers.test_user_directory] +[mypy-tests.handlers.*] disallow_untyped_defs = True [mypy-tests.metrics.test_background_process_metrics] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8b9ef25d29..30f2d46c3c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2031,7 +2031,7 @@ class PasswordAuthProvider: self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters - self._supported_login_types: Dict[str, Iterable[str]] = {} + self._supported_login_types: Dict[str, Tuple[str, ...]] = {} # Mapping from login type to auth checker callbacks self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 57bfbd7734..a7495ab21a 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -31,7 +31,7 @@ from synapse.appservice import ( from synapse.handlers.appservice import ApplicationServicesHandler from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.server import HomeServer -from synapse.types import RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken from synapse.util import Clock from synapse.util.stringutils import random_string @@ -44,7 +44,7 @@ from tests.utils import MockClock class AppServiceHandlerTestCase(unittest.TestCase): """Tests the ApplicationServicesHandler.""" - def setUp(self): + def setUp(self) -> None: self.mock_store = Mock() self.mock_as_api = Mock() self.mock_scheduler = Mock() @@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.handler = ApplicationServicesHandler(hs) self.event_source = hs.get_event_sources() - def test_notify_interested_services(self): + def test_notify_interested_services(self) -> None: interested_service = self._mkservice(is_interested_in_event=True) services = [ self._mkservice(is_interested_in_event=False), @@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): interested_service, events=[event] ) - def test_query_user_exists_unknown_user(self): + def test_query_user_exists_unknown_user(self) -> None: user_id = "@someone:anywhere" services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True @@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) - def test_query_user_exists_known_user(self): + def test_query_user_exists_known_user(self) -> None: user_id = "@someone:anywhere" services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True @@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): "query_user called when it shouldn't have been.", ) - def test_query_room_alias_exists(self): + def test_query_room_alias_exists(self) -> None: room_alias_str = "#foo:bar" room_alias = Mock() room_alias.to_string.return_value = room_alias_str @@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEqual(result.room_id, room_id) self.assertEqual(result.servers, servers) - def test_get_3pe_protocols_no_appservices(self): + def test_get_3pe_protocols_no_appservices(self) -> None: self.mock_store.get_app_services.return_value = [] response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) @@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_not_called() self.assertEqual(response, {}) - def test_get_3pe_protocols_no_protocols(self): + def test_get_3pe_protocols_no_protocols(self) -> None: service = self._mkservice(False, []) self.mock_store.get_app_services.return_value = [service] response = self.successResultOf( @@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_not_called() self.assertEqual(response, {}) - def test_get_3pe_protocols_protocol_no_response(self): + def test_get_3pe_protocols_protocol_no_response(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) @@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.assertEqual(response, {}) - def test_get_3pe_protocols_select_one_protocol(self): + def test_get_3pe_protocols_select_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( @@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) - def test_get_3pe_protocols_one_protocol(self): + def test_get_3pe_protocols_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( @@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) - def test_get_3pe_protocols_multiple_protocol(self): + def test_get_3pe_protocols_multiple_protocol(self) -> None: service_one = self._mkservice(False, ["my-protocol"]) service_two = self._mkservice(False, ["other-protocol"]) self.mock_store.get_app_services.return_value = [service_one, service_two] @@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): }, ) - def test_get_3pe_protocols_multiple_info(self): + def test_get_3pe_protocols_multiple_info(self) -> None: service_one = self._mkservice(False, ["my-protocol"]) service_two = self._mkservice(False, ["my-protocol"]) - async def get_3pe_protocol(service, unusedProtocol): + async def get_3pe_protocol( + service: ApplicationService, protocol: str + ) -> Optional[JsonDict]: if service == service_one: return { "x-protocol-data": 42, @@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): }, ) - def test_notify_interested_services_ephemeral(self): + def test_notify_interested_services_ephemeral(self) -> None: """ Test sending ephemeral events to the appservice handler are scheduled to be pushed out to interested appservices, and that the stream ID is @@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): 580, ) - def test_notify_interested_services_ephemeral_out_of_order(self): + def test_notify_interested_services_ephemeral_out_of_order(self) -> None: """ Test sending out of order ephemeral events to the appservice handler are ignored. @@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events @@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) - def _notify_interested_services(self): + def _notify_interested_services(self) -> None: # This is normally set in `notify_interested_services` but we need to call the # internal async version so the reactor gets pushed to completion. self.hs.get_application_service_handler().current_max += 1 @@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): ) def test_match_interesting_room_members( self, interesting_user: str, should_notify: bool - ): + ) -> None: """ Test to make sure that a interesting user (local or remote) in the room is notified as expected when someone else in the room sends a message. @@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): else: self.send_mock.assert_not_called() - def test_application_services_receive_events_sent_by_interesting_local_user(self): + def test_application_services_receive_events_sent_by_interesting_local_user( + self, + ) -> None: """ Test to make sure that a messages sent from a local user can be interesting and picked up by the appservice. @@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0]["type"], "m.room.message") self.assertEqual(events[0]["sender"], alice) - def test_sending_read_receipt_batches_to_application_services(self): + def test_sending_read_receipt_batches_to_application_services(self) -> None: """Tests that a large batch of read receipts are sent correctly to interested application services. """ @@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"experimental_features": {"msc2409_to_device_messages_enabled": True}} ) - def test_application_services_receive_local_to_device(self): + def test_application_services_receive_local_to_device(self) -> None: """ Test that when a user sends a to-device message to another user that is an application service's user namespace, the @@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"experimental_features": {"msc2409_to_device_messages_enabled": True}} ) - def test_application_services_receive_bursts_of_to_device(self): + def test_application_services_receive_bursts_of_to_device(self) -> None: """ Test that when a user sends >100 to-device messages at once, any interested AS's will receive them in separate transactions. @@ -913,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) experimental_feature_enabled: bool, as_supports_txn_extensions: bool, as_should_receive_device_list_updates: bool, - ): + ) -> None: """ Tests that an application service receives notice of changed device lists for a user, when a user changes their device lists. @@ -1070,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): and a room for the users to talk in. """ - async def preparation(): + async def preparation() -> None: await self._add_otks_for_device(self._sender_user, self._sender_device, 42) await self._add_fallback_key_for_device( self._sender_user, self._sender_device, used=True diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 2b21547d0f..2733719d82 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase): ) -def _mock_request(): +def _mock_request() -> Mock: """Returns a mock which will stand in as a SynapseRequest""" mock = Mock( spec=[ diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 3b72c4c9d0..90aec484c4 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.rest.client import directory, login, room from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, create_requester @@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def _create_alias(self, user) -> None: + def _create_alias(self, user: str) -> None: # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( @@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) return room_alias - def _set_canonical_alias(self, content) -> None: + def _set_canonical_alias(self, content: JsonDict) -> None: """Configure the canonical alias state on the room.""" self.helper.send_state( self.room_id, @@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def _get_canonical_alias(self): + def _get_canonical_alias(self) -> EventBase: """Get the canonical alias state of the room.""" - return self.get_success( + result = self.get_success( self._storage_controllers.state.get_current_state_event( self.room_id, EventTypes.CanonicalAlias, "" ) ) + assert result is not None + return result def test_remove_alias(self) -> None: """Removing an alias that is the canonical alias should remove it there too.""" @@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) - self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + self.assertEqual(data.content["alias"], self.test_alias) + self.assertEqual(data.content["alt_aliases"], [self.test_alias]) # Finally, delete the alias. self.get_success( @@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertNotIn("alias", data["content"]) - self.assertNotIn("alt_aliases", data["content"]) + self.assertNotIn("alias", data.content) + self.assertNotIn("alt_aliases", data.content) def test_remove_other_alias(self) -> None: """Removing an alias listed as in alt_aliases should remove it there too.""" @@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data.content["alias"], self.test_alias) self.assertEqual( - data["content"]["alt_aliases"], [self.test_alias, other_test_alias] + data.content["alt_aliases"], [self.test_alias, other_test_alias] ) # Delete the second alias. @@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) data = self._get_canonical_alias() - self.assertEqual(data["content"]["alias"], self.test_alias) - self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + self.assertEqual(data.content["alias"], self.test_alias) + self.assertEqual(data.content["alt_aliases"], [self.test_alias]) class TestCreateAliasACL(unittest.HomeserverTestCase): diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 9b7e7a8e9a..6c0b30de9e 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -17,7 +17,11 @@ import copy from unittest import mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import SynapseError +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -39,14 +43,14 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(replication_layer=mock.Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_room_keys_handler() self.local_user = "@boris:" + hs.hostname - def test_get_missing_current_version_info(self): + def test_get_missing_current_version_info(self) -> None: """Check that we get a 404 if we ask for info about the current version if there is no version. """ @@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_get_missing_version_info(self): + def test_get_missing_version_info(self) -> None: """Check that we get a 404 if we ask for info about a specific version if it doesn't exist. """ @@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_create_version(self): + def test_create_version(self) -> None: """Check that we can create and then retrieve versions.""" - res = self.get_success( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) ) - self.assertEqual(res, "1") + self.assertEqual(version, "1") # check we can retrieve it as the current version res = self.get_success(self.handler.get_version_info(self.local_user)) @@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) # upload a new one... - res = self.get_success( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) ) - self.assertEqual(res, "2") + self.assertEqual(version, "2") # check we can retrieve it as the current version res = self.get_success(self.handler.get_version_info(self.local_user)) @@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_update_version(self): + def test_update_version(self) -> None: """Check that we can update versions.""" version = self.get_success( self.handler.create_version( @@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_update_missing_version(self): + def test_update_missing_version(self) -> None: """Check that we get a 404 on updating nonexistent versions""" e = self.get_failure( self.handler.update_version( @@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_update_omitted_version(self): + def test_update_omitted_version(self) -> None: """Check that the update succeeds if the version is missing from the body""" version = self.get_success( self.handler.create_version( @@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_update_bad_version(self): + def test_update_bad_version(self) -> None: """Check that we get a 400 if the version in the body doesn't match""" version = self.get_success( self.handler.create_version( @@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 400) - def test_delete_missing_version(self): + def test_delete_missing_version(self) -> None: """Check that we get a 404 on deleting nonexistent versions""" e = self.get_failure( self.handler.delete_version(self.local_user, "1"), SynapseError @@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_delete_missing_current_version(self): + def test_delete_missing_current_version(self) -> None: """Check that we get a 404 on deleting nonexistent current version""" e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError) res = e.value.code self.assertEqual(res, 404) - def test_delete_version(self): + def test_delete_version(self) -> None: """Check that we can create and then delete versions.""" - res = self.get_success( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) ) - self.assertEqual(res, "1") + self.assertEqual(version, "1") # check we can delete it self.get_success(self.handler.delete_version(self.local_user, "1")) @@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_get_missing_backup(self): + def test_get_missing_backup(self) -> None: """Check that we get a 404 on querying missing backup""" e = self.get_failure( self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError @@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_get_missing_room_keys(self): + def test_get_missing_room_keys(self) -> None: """Check we get an empty response from an empty backup""" version = self.get_success( self.handler.create_version( @@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest - def test_upload_room_keys_no_versions(self): + def test_upload_room_keys_no_versions(self) -> None: """Check that we get a 404 on uploading keys when no versions are defined""" e = self.get_failure( self.handler.upload_room_keys(self.local_user, "no_version", room_keys), @@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_upload_room_keys_bogus_version(self): + def test_upload_room_keys_bogus_version(self) -> None: """Check that we get a 404 on uploading keys when an nonexistent version is specified """ @@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 404) - def test_upload_room_keys_wrong_version(self): + def test_upload_room_keys_wrong_version(self) -> None: """Check that we get a 403 on uploading keys for an old version""" version = self.get_success( self.handler.create_version( @@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 403) - def test_upload_room_keys_insert(self): + def test_upload_room_keys_insert(self) -> None: """Check that we can insert and retrieve keys for a session""" version = self.get_success( self.handler.create_version( @@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.assertDictEqual(res, room_keys) - def test_upload_room_keys_merge(self): + def test_upload_room_keys_merge(self) -> None: """Check that we can upload a new room_key for an existing session and have it correctly merged""" version = self.get_success( @@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = self.get_success(self.handler.get_room_keys(self.local_user, version)) + res_keys = self.get_success( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( - res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], + res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) @@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = self.get_success(self.handler.get_room_keys(self.local_user, version)) + res_keys = self.get_success( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( - res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" + res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], + "new", ) # the etag should NOT be equal now, since the key changed @@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = self.get_success(self.handler.get_room_keys(self.local_user, version)) + res_keys = self.get_success( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( - res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" + res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], + "new", ) # the etag should be the same since the session did not change @@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): # TODO: check edge cases as well as the common variations here - def test_delete_room_keys(self): + def test_delete_room_keys(self) -> None: """Check that we can insert and delete keys for a session""" version = self.get_success( self.handler.create_version( diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index d00c69c229..cedbb9fafc 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -439,7 +439,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") - def create_invite(): + def create_invite() -> EventBase: room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) room_version = self.get_success(self.store.get_room_version(room_id)) return event_from_pdu_json( diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index e448cb1901..70ea4d15d4 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -14,6 +14,8 @@ from typing import Optional from unittest import mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import AuthError, StoreError from synapse.api.room_versions import RoomVersion from synapse.event_auth import ( @@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import event_injection, make_awaitable @@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # mock out the federation transport client self.mock_federation_transport_client = mock.Mock( spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] @@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) else: - async def get_event(destination: str, event_id: str, timeout=None): + async def get_event( + destination: str, event_id: str, timeout: Optional[int] = None + ) -> JsonDict: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, prev_event.event_id) return {"pdus": [prev_event.get_pdu_json()]} diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 99384837d0..c4727ab917 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -14,12 +14,16 @@ import logging from typing import Tuple +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = self.hs.get_event_creation_handler() self._persist_event_storage_controller = ( self.hs.get_storage_controllers().persistence @@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): ) ) - def test_duplicated_txn_id(self): + def test_duplicated_txn_id(self) -> None: """Test that attempting to handle/persist an event with a transaction ID that has already been persisted correctly returns the old event and does *not* produce duplicate messages. @@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # rather than the new one. self.assertEqual(ret_event1.event_id, ret_event4.event_id) - def test_duplicated_txn_id_one_call(self): + def test_duplicated_txn_id_one_call(self) -> None: """Test that we correctly handle duplicates that we try and persist at the same time. """ @@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[0].event_id, events[1].event_id) - def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self): + def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events( + self, + ) -> None: """When we set allow_no_prev_events=True, should be able to create a event without any prev_events (only auth_events). """ @@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events( self, - ): + ) -> None: """When we set allow_no_prev_events=False, shouldn't be able to create a event without any prev_events even if it has auth_events. Expect an exception to be raised. @@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events( self, - ): + ) -> None: """When we set allow_no_prev_events=True, should be able to create a event without any prev_events or auth_events. Expect an exception to be raised. @@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("tester", "foobar") self.access_token = self.login("tester", "foobar") self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) - def test_allow_server_acl(self): + def test_allow_server_acl(self) -> None: """Test that sending an ACL that blocks everyone but ourselves works.""" self.helper.send_state( @@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): expect_code=200, ) - def test_deny_server_acl_block_outselves(self): + def test_deny_server_acl_block_outselves(self) -> None: """Test that sending an ACL that blocks ourselves does not work.""" self.helper.send_state( self.room_id, @@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): expect_code=400, ) - def test_deny_redact_server_acl(self): + def test_deny_redact_server_acl(self) -> None: """Test that attempting to redact an ACL is blocked.""" body = self.helper.send_state( diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5955410524..49a1842b5c 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Dict, Tuple +from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID from synapse.util import Clock from synapse.util.macaroons import get_value_from_macaroon from synapse.util.stringutils import random_string @@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config try: import authlib # noqa: F401 + from authlib.oidc.core import UserInfo + from authlib.oidc.discovery import OpenIDProviderMetadata + + from synapse.handlers.oidc import Token, UserAttributeDict HAS_OIDC = True except ImportError: @@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = { class TestMappingProvider: @staticmethod - def parse_config(config): - return + def parse_config(config: JsonDict) -> None: + return None - def __init__(self, config): + def __init__(self, config: None): pass - def get_remote_user_id(self, userinfo): + def get_remote_user_id(self, userinfo: "UserInfo") -> str: return userinfo["sub"] - async def map_user_attributes(self, userinfo, token): - return {"localpart": userinfo["username"], "display_name": None} + async def map_user_attributes( + self, userinfo: "UserInfo", token: "Token" + ) -> "UserAttributeDict": + # This is testing not providing the full map. + return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item] # Do not include get_extra_attributes to test backwards compatibility paths. class TestMappingProviderExtra(TestMappingProvider): - async def get_extra_attributes(self, userinfo, token): + async def get_extra_attributes( + self, userinfo: "UserInfo", token: "Token" + ) -> JsonDict: return {"phone": userinfo["phone"]} class TestMappingProviderFailures(TestMappingProvider): - async def map_user_attributes(self, userinfo, token, failures): - return { + # Superclass is testing the legacy interface for map_user_attributes. + async def map_user_attributes( # type: ignore[override] + self, userinfo: "UserInfo", token: "Token", failures: int + ) -> "UserAttributeDict": + return { # type: ignore[typeddict-item] "localpart": userinfo["username"] + (str(failures) if failures else ""), "display_name": None, } @@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.hs_patcher.stop() return super().tearDown() - def reset_mocks(self): + def reset_mocks(self) -> None: """Reset all the Mocks.""" self.fake_server.reset_mocks() self.render_error.reset_mock() self.complete_sso_login.reset_mock() - def metadata_edit(self, values): + def metadata_edit(self, values: dict) -> ContextManager[Mock]: """Modify the result that will be returned by the well-known query""" metadata = self.fake_server.get_metadata() @@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase): session = self._generate_oidc_session_token(state, nonce, client_redirect_url) return _build_callback_request(code, state, session), grant - def assertRenderedError(self, error, error_description=None): + def assertRenderedError( + self, error: str, error_description: Optional[str] = None + ) -> Tuple[Any, ...]: self.render_error.assert_called_once() args = self.render_error.call_args[0] self.assertEqual(args[1], error) @@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase): """Provider metadatas are extensively validated.""" h = self.provider - def force_load_metadata(): - async def force_load(): + def force_load_metadata() -> Awaitable[None]: + async def force_load() -> "OpenIDProviderMetadata": return await h.load_metadata(force=True) return get_awaitable_result(force_load()) @@ -1198,7 +1212,7 @@ def _build_callback_request( state: str, session: str, ip_address: str = "10.0.0.1", -): +) -> Mock: """Builds a fake SynapseRequest to mock the browser callback Returns a Mock object which looks like the SynapseRequest we get from a browser diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 75934b1707..0916de64f5 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -15,12 +15,13 @@ """Tests for the password_auth_provider interface""" from http import HTTPStatus -from typing import Any, Type, Union +from typing import Any, Dict, List, Optional, Type, Union from unittest.mock import Mock import synapse from synapse.api.constants import LoginType from synapse.api.errors import Codes +from synapse.handlers.account import AccountHandler from synapse.module_api import ModuleApi from synapse.rest.client import account, devices, login, logout, register from synapse.types import JsonDict, UserID @@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider: """A legacy password_provider which only implements `check_password`.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def check_password(self, *args): + def check_password(self, *args: str) -> Mock: return mock_password_provider.check_password(*args) @@ -58,16 +59,16 @@ class LegacyCustomAuthProvider: """A legacy password_provider which implements a custom login type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def get_supported_login_types(self): + def get_supported_login_types(self) -> Dict[str, List[str]]: return {"test.login_type": ["test_field"]} - def check_auth(self, *args): + def check_auth(self, *args: str) -> Mock: return mock_password_provider.check_auth(*args) @@ -75,15 +76,15 @@ class CustomAuthProvider: """A module which registers password_auth_provider callbacks for a custom login type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, api: ModuleApi): + def __init__(self, config: None, api: ModuleApi): api.register_password_auth_provider_callbacks( auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) - def check_auth(self, *args): + def check_auth(self, *args: Any) -> Mock: return mock_password_provider.check_auth(*args) @@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider: as a custom type.""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, account_handler): + def __init__(self, config: None, account_handler: AccountHandler): pass - def get_supported_login_types(self): + def get_supported_login_types(self) -> Dict[str, List[str]]: return {"m.login.password": ["password"], "test.login_type": ["test_field"]} - def check_auth(self, *args): + def check_auth(self, *args: str) -> Mock: return mock_password_provider.check_auth(*args) @@ -110,10 +111,10 @@ class PasswordCustomAuthProvider: as well as a password login""" @staticmethod - def parse_config(self): + def parse_config(config: JsonDict) -> None: pass - def __init__(self, config, api: ModuleApi): + def __init__(self, config: None, api: ModuleApi): api.register_password_auth_provider_callbacks( auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, @@ -121,10 +122,10 @@ class PasswordCustomAuthProvider: } ) - def check_auth(self, *args): + def check_auth(self, *args: Any) -> Mock: return mock_password_provider.check_auth(*args) - def check_pass(self, *args): + def check_pass(self, *args: str) -> Mock: return mock_password_provider.check_password(*args) @@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): CALLBACK_USERNAME = "get_username_for_registration" CALLBACK_DISPLAYNAME = "get_displayname_for_registration" - def setUp(self): + def setUp(self) -> None: # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() super().setUp() @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_password_only_auth_progiver_login_legacy(self): + def test_password_only_auth_progiver_login_legacy(self) -> None: self.password_only_auth_provider_login_test_body() - def password_only_auth_provider_login_test_body(self): + def password_only_auth_provider_login_test_body(self) -> None: # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth_legacy(self): + def test_password_only_auth_provider_ui_auth_legacy(self) -> None: self.password_only_auth_provider_ui_auth_test_body() - def password_only_auth_provider_ui_auth_test_body(self): + def password_only_auth_provider_ui_auth_test_body(self) -> None: """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_local_user_fallback_login_legacy(self): + def test_local_user_fallback_login_legacy(self) -> None: self.local_user_fallback_login_test_body() - def local_user_fallback_login_test_body(self): + def local_user_fallback_login_test_body(self) -> None: """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual("@localuser:test", channel.json_body["user_id"]) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth_legacy(self): + def test_local_user_fallback_ui_auth_legacy(self) -> None: self.local_user_fallback_ui_auth_test_body() - def local_user_fallback_ui_auth_test_body(self): + def local_user_fallback_ui_auth_test_body(self) -> None: """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login_legacy(self): + def test_no_local_user_fallback_login_legacy(self) -> None: self.no_local_user_fallback_login_test_body() - def no_local_user_fallback_login_test_body(self): + def no_local_user_fallback_login_test_body(self) -> None: """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth_legacy(self): + def test_no_local_user_fallback_ui_auth_legacy(self) -> None: self.no_local_user_fallback_ui_auth_test_body() - def no_local_user_fallback_ui_auth_test_body(self): + def no_local_user_fallback_ui_auth_test_body(self) -> None: """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_auth_disabled_legacy(self): + def test_password_auth_disabled_legacy(self) -> None: self.password_auth_disabled_test_body() - def password_auth_disabled_test_body(self): + def password_auth_disabled_test_body(self) -> None: """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_not_called() @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_login_legacy(self): + def test_custom_auth_provider_login_legacy(self) -> None: self.custom_auth_provider_login_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_login(self): + def test_custom_auth_provider_login(self) -> None: self.custom_auth_provider_login_test_body() - def custom_auth_provider_login_test_body(self): + def custom_auth_provider_login_test_body(self) -> None: # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_ui_auth_legacy(self): + def test_custom_auth_provider_ui_auth_legacy(self) -> None: self.custom_auth_provider_ui_auth_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_ui_auth(self): + def test_custom_auth_provider_ui_auth(self) -> None: self.custom_auth_provider_ui_auth_test_body() - def custom_auth_provider_ui_auth_test_body(self): + def custom_auth_provider_ui_auth_test_body(self) -> None: # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) @override_config(legacy_providers_config(LegacyCustomAuthProvider)) - def test_custom_auth_provider_callback_legacy(self): + def test_custom_auth_provider_callback_legacy(self) -> None: self.custom_auth_provider_callback_test_body() @override_config(providers_config(CustomAuthProvider)) - def test_custom_auth_provider_callback(self): + def test_custom_auth_provider_callback(self) -> None: self.custom_auth_provider_callback_test_body() - def custom_auth_provider_callback_test_body(self): + def custom_auth_provider_callback_test_body(self) -> None: callback = Mock(return_value=make_awaitable(None)) mock_password_provider.check_auth.return_value = make_awaitable( @@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_custom_auth_password_disabled_legacy(self): + def test_custom_auth_password_disabled_legacy(self) -> None: self.custom_auth_password_disabled_test_body() @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) - def test_custom_auth_password_disabled(self): + def test_custom_auth_password_disabled(self) -> None: self.custom_auth_password_disabled_test_body() - def custom_auth_password_disabled_test_body(self): + def custom_auth_password_disabled_test_body(self) -> None: """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False, "localdb_enabled": False}, } ) - def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None: self.custom_auth_password_disabled_localdb_enabled_test_body() @override_config( @@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False, "localdb_enabled": False}, } ) - def test_custom_auth_password_disabled_localdb_enabled(self): + def test_custom_auth_password_disabled_localdb_enabled(self) -> None: self.custom_auth_password_disabled_localdb_enabled_test_body() - def custom_auth_password_disabled_localdb_enabled_test_body(self): + def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None: """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_login_legacy(self): + def test_password_custom_auth_password_disabled_login_legacy(self) -> None: self.password_custom_auth_password_disabled_login_test_body() @override_config( @@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_login(self): + def test_password_custom_auth_password_disabled_login(self) -> None: self.password_custom_auth_password_disabled_login_test_body() - def password_custom_auth_password_disabled_login_test_body(self): + def password_custom_auth_password_disabled_login_test_body(self) -> None: """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None: self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( @@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"enabled": False}, } ) - def test_password_custom_auth_password_disabled_ui_auth(self): + def test_password_custom_auth_password_disabled_ui_auth(self) -> None: self.password_custom_auth_password_disabled_ui_auth_test_body() - def password_custom_auth_password_disabled_ui_auth_test_body(self): + def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None: """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, @@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_custom_auth_no_local_user_fallback_legacy(self): + def test_custom_auth_no_local_user_fallback_legacy(self) -> None: self.custom_auth_no_local_user_fallback_test_body() @override_config( @@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "password_config": {"localdb_enabled": False}, } ) - def test_custom_auth_no_local_user_fallback(self): + def test_custom_auth_no_local_user_fallback(self) -> None: self.custom_auth_no_local_user_fallback_test_body() - def custom_auth_no_local_user_fallback_test_body(self): + def custom_auth_no_local_user_fallback_test_body(self) -> None: """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass") @@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - def test_on_logged_out(self): + def test_on_logged_out(self) -> None: """Tests that the on_logged_out callback is called when the user logs out.""" self.register_user("rin", "password") tok = self.login("rin", "password") self.called = False - async def on_logged_out(user_id, device_id, access_token): + async def on_logged_out( + user_id: str, device_id: Optional[str], access_token: str + ) -> None: self.called = True on_logged_out = Mock(side_effect=on_logged_out) @@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): on_logged_out.assert_called_once() self.assertTrue(self.called) - def test_username(self): + def test_username(self) -> None: """Tests that the get_username_for_registration callback can define the username of a user when registering. """ @@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mxid = channel.json_body["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") - def test_username_uia(self): + def test_username_uia(self) -> None: """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ @@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # Set some email configuration so the test doesn't fail because of its absence. @override_config({"email": {"notif_from": "noreply@test"}}) - def test_3pid_allowed(self): + def test_3pid_allowed(self) -> None: """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind the 3PID. Also checks that the module is passed a boolean indicating whether the @@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) - def test_displayname(self): + def test_displayname(self) -> None: """Tests that the get_displayname_for_registration callback can define the display name of a user when registering. """ @@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(display_name, username + "-foo") - def test_displayname_uia(self): + def test_displayname_uia(self) -> None: """Tests that the get_displayname_for_registration callback is only called at the end of the UIA flow. """ @@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # Check that the callback has been called. m.assert_called_once() - def _test_3pid_allowed(self, username: str, registration: bool): + def _test_3pid_allowed(self, username: str, registration: bool) -> None: """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): client is trying to register. """ - async def callback(uia_results, params): + async def callback(uia_results: JsonDict, params: JsonDict) -> str: self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" @@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): def _send_password_login(self, user: str, password: str) -> FakeChannel: return self._send_login(type="m.login.password", user=user, password=password) - def _send_login(self, type, user, **params) -> FakeChannel: - params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) + def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel: + params = {"identifier": {"type": "m.id.user", "user": user}, "type": type} + params.update(extra_params) channel = self.make_request("POST", "/_matrix/client/r0/login", params) return channel - def _start_delete_device_session(self, access_token, device_id) -> str: + def _start_delete_device_session(self, access_token: str, device_id: str) -> str: """Make an initial delete device request, and return the UI Auth session ID""" channel = self._delete_device(access_token, device_id) self.assertEqual(channel.code, 401) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 584e7b8971..19f5322317 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, cast from unittest.mock import Mock, call from parameterized import parameterized from signedjson.key import generate_signing_key +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -35,7 +37,9 @@ from synapse.handlers.presence import ( ) from synapse.rest import admin from synapse.rest.client import room -from synapse.types import UserID, get_domain_from_id +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.util import Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase class PresenceUpdateTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main - def test_offline_to_online(self): + def test_offline_to_online(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_online_to_online(self): + def test_online_to_online(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_online_to_online_last_active_noop(self): + def test_online_to_online_last_active_noop(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_online_to_online_last_active(self): + def test_online_to_online_last_active(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_remote_ping_timer(self): + def test_remote_ping_timer(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_online_to_offline(self): + def test_online_to_offline(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertEqual(wheel_timer.insert.call_count, 0) - def test_online_to_idle(self): + def test_online_to_idle(self) -> None: wheel_timer = Mock() user_id = "@foo:bar" now = 5000000 @@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): any_order=True, ) - def test_persisting_presence_updates(self): + def test_persisting_presence_updates(self) -> None: """Tests that the latest presence state for each user is persisted correctly""" # Create some test users and presence states for them presence_states = [] @@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.get_success(self.store.update_presence(presence_states)) # Check that each update is present in the database - db_presence_states = self.get_success( + db_presence_states_raw = self.get_success( self.store.get_all_presence_updates( instance_name="master", last_id=0, @@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): ) # Extract presence update user ID and state information into lists of tuples - db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] + db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]] presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states] # Compare what we put into the storage with what we got out. @@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): class PresenceTimeoutTestCase(unittest.TestCase): """Tests different timers and that the timer does not change `status_msg` of user.""" - def test_idle_timer(self): + def test_idle_timer(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) self.assertEqual(new_state.status_msg, status_msg) - def test_busy_no_idle(self): + def test_busy_no_idle(self) -> None: """ Tests that a user setting their presence to busy but idling doesn't turn their presence state into unavailable. @@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.BUSY) self.assertEqual(new_state.status_msg, status_msg) - def test_sync_timeout(self): + def test_sync_timeout(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) - def test_sync_online(self): + def test_sync_online(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.ONLINE) self.assertEqual(new_state.status_msg, status_msg) - def test_federation_ping(self): + def test_federation_ping(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEqual(state, new_state) - def test_no_timeout(self): + def test_no_timeout(self) -> None: user_id = "@foo:bar" now = 5000000 @@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNone(new_state) - def test_federation_timeout(self): + def test_federation_timeout(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) - def test_last_active(self): + def test_last_active(self) -> None: user_id = "@foo:bar" status_msg = "I'm here!" now = 5000000 @@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase): class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() - def test_external_process_timeout(self): + def test_external_process_timeout(self) -> None: """Test that if an external process doesn't update the records for a while we time out their syncing users presence. """ - process_id = 1 + process_id = "1" user_id = "@test:server" # Notify handler that a user is now syncing. @@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): ) self.assertEqual(state.state, PresenceState.OFFLINE) - def test_user_goes_offline_by_timeout_status_msg_remain(self): + def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None: """Test that if a user doesn't update the records for a while users presence goes `OFFLINE` because of timeout and `status_msg` remains. """ @@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.status_msg, status_msg) - def test_user_goes_offline_manually_with_no_status_msg(self): + def test_user_goes_offline_manually_with_no_status_msg(self) -> None: """Test that if a user change presence manually to `OFFLINE` and no status is set, that `status_msg` is `None`. """ @@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.status_msg, None) - def test_user_goes_offline_manually_with_status_msg(self): + def test_user_goes_offline_manually_with_status_msg(self) -> None: """Test that if a user change presence manually to `OFFLINE` and a status is set, that `status_msg` appears. """ @@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): user_id, PresenceState.OFFLINE, "And now here." ) - def test_user_reset_online_with_no_status(self): + def test_user_reset_online_with_no_status(self) -> None: """Test that if a user set again the presence manually and no status is set, that `status_msg` is `None`. """ @@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(state.state, PresenceState.ONLINE) self.assertEqual(state.status_msg, None) - def test_set_presence_with_status_msg_none(self): + def test_set_presence_with_status_msg_none(self) -> None: """Test that if a user set again the presence manually and status is `None`, that `status_msg` is `None`. """ @@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): # Mark user as online and `status_msg = None` self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) - def test_set_presence_from_syncing_not_set(self): + def test_set_presence_from_syncing_not_set(self) -> None: """Test that presence is not set by syncing if affect_presence is false""" user_id = "@test:server" status_msg = "I'm here!" @@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): # and status message should still be the same self.assertEqual(state.status_msg, status_msg) - def test_set_presence_from_syncing_is_set(self): + def test_set_presence_from_syncing_is_set(self) -> None: """Test that presence is set by syncing if affect_presence is true""" user_id = "@test:server" status_msg = "I'm here!" @@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): # we should now be online self.assertEqual(state.state, PresenceState.ONLINE) - def test_set_presence_from_syncing_keeps_status(self): + def test_set_presence_from_syncing_keeps_status(self) -> None: """Test that presence set by syncing retains status message""" user_id = "@test:server" status_msg = "I'm here!" @@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): }, } ) - def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): + def test_set_presence_from_syncing_keeps_busy( + self, test_with_workers: bool + ) -> None: """Test that presence set by syncing doesn't affect busy status Args: @@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): def _set_presencestate_with_status_msg( self, user_id: str, state: str, status_msg: Optional[str] - ): + ) -> None: """Set a PresenceState and status_msg and check the result. Args: @@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.instance_name = hs.get_instance_name() self.queue = self.presence_handler.get_federation_queue() - def test_send_and_get(self): + def test_send_and_get(self) -> None: state1 = UserPresenceState.default("@user1:test") state2 = UserPresenceState.default("@user2:test") state3 = UserPresenceState.default("@user3:test") @@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): self.assertFalse(limited) self.assertCountEqual(rows, []) - def test_send_and_get_split(self): + def test_send_and_get_split(self) -> None: state1 = UserPresenceState.default("@user1:test") state2 = UserPresenceState.default("@user2:test") state3 = UserPresenceState.default("@user3:test") @@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): self.assertCountEqual(rows, expected_rows) - def test_clear_queue_all(self): + def test_clear_queue_all(self) -> None: state1 = UserPresenceState.default("@user1:test") state2 = UserPresenceState.default("@user2:test") state3 = UserPresenceState.default("@user3:test") @@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): self.assertCountEqual(rows, expected_rows) - def test_partially_clear_queue(self): + def test_partially_clear_queue(self) -> None: state1 = UserPresenceState.default("@user1:test") state2 = UserPresenceState.default("@user2:test") state3 = UserPresenceState.default("@user3:test") @@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): servlets = [room.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( "server", federation_http_client=None, @@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) return hs - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() # Enable federation sending on the main process. config["federation_sender_instances"] = None return config - def prepare(self, reactor, clock, hs): - self.federation_sender = hs.get_federation_sender() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.federation_sender = cast(Mock, hs.get_federation_sender()) self.event_builder_factory = hs.get_event_builder_factory() self.federation_event_handler = hs.get_federation_event_handler() self.presence_handler = hs.get_presence_handler() @@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # random key to use. self.random_signing_key = generate_signing_key("ver") - def test_remote_joins(self): + def test_remote_joins(self) -> None: # We advance time to something that isn't 0, as we use 0 as a special # value. self.reactor.advance(1000000000000) @@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): destinations={"server3"}, states=[expected_state] ) - def test_remote_gets_presence_when_local_user_joins(self): + def test_remote_gets_presence_when_local_user_joins(self) -> None: # We advance time to something that isn't 0, as we use 0 as a special # value. self.reactor.advance(1000000000000) @@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): destinations={"server2", "server3"}, states=[expected_state] ) - def _add_new_user(self, room_id, user_id): + def _add_new_user(self, room_id: str, user_id: str) -> None: """Add new user to the room by creating an event and poking the federation API.""" hostname = get_domain_from_id(user_id) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 675aa023ac..7c174782da 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]} ) - def test_avatar_constraint_on_local_server_with_port(self): + def test_avatar_constraint_on_local_server_with_port(self) -> None: """Test that avatar metadata is correctly fetched when the media is on a local server and the server has an explicit port. @@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc)) ) - def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None: """Stores metadata about files in the database. Args: diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index b55238650c..f60400ff8d 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,14 +15,18 @@ from copy import deepcopy from typing import List +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EduTypes, ReceiptTypes +from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest class ReceiptsTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_source = hs.get_event_sources().sources.receipt def test_filters_out_private_receipt(self) -> None: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 765df75d91..b9332d97dc 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Collection, List, Optional, Tuple from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -22,8 +25,18 @@ from synapse.api.errors import ( ResourceLimitError, SynapseError, ) +from synapse.module_api import ModuleApi +from synapse.server import HomeServer from synapse.spam_checker_api import RegistrationBehaviour -from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.types import ( + JsonDict, + Requester, + RoomAlias, + RoomID, + UserID, + create_requester, +) +from synapse.util import Clock from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -33,94 +46,98 @@ from .. import unittest class TestSpamChecker: - def __init__(self, config, api): + def __init__(self, config: None, api: ModuleApi): api.register_spam_checker_callbacks( check_registration_for_spam=self.check_registration_for_spam, ) @staticmethod - def parse_config(config): - return config + def parse_config(config: JsonDict) -> None: + return None async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - auth_provider_id, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str], + ) -> RegistrationBehaviour: pass class DenyAll(TestSpamChecker): async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - auth_provider_id, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str], + ) -> RegistrationBehaviour: return RegistrationBehaviour.DENY class BanAll(TestSpamChecker): async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - auth_provider_id, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str], + ) -> RegistrationBehaviour: return RegistrationBehaviour.SHADOW_BAN class BanBadIdPUser(TestSpamChecker): async def check_registration_for_spam( - self, email_threepid, username, request_info, auth_provider_id=None - ): + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str] = None, + ) -> RegistrationBehaviour: # Reject any user coming from CAS and whose username contains profanity - if auth_provider_id == "cas" and "flimflob" in username: + if auth_provider_id == "cas" and username and "flimflob" in username: return RegistrationBehaviour.DENY return RegistrationBehaviour.ALLOW class TestLegacyRegistrationSpamChecker: - def __init__(self, config, api): + def __init__(self, config: None, api: ModuleApi): pass async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: pass class LegacyAllowAll(TestLegacyRegistrationSpamChecker): async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: return RegistrationBehaviour.ALLOW class LegacyDenyAll(TestLegacyRegistrationSpamChecker): async def check_registration_for_spam( self, - email_threepid, - username, - request_info, - ): + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: return RegistrationBehaviour.DENY class RegistrationTestCase(unittest.HomeserverTestCase): """Tests the RegistrationHandler.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs_config = self.default_config() # some of the tests rely on us having a user consent version @@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = self.hs.get_registration_handler() self.store = self.hs.get_datastores().main self.lots_of_users = 100 @@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.requester = create_requester("@requester:test") - def test_user_is_created_and_logged_in_if_doesnt_exist(self): + def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None: frank = UserID.from_string("@frank:test") user_id = frank.to_string() requester = create_requester(user_id) @@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertIsInstance(result_token, str) self.assertGreater(len(result_token), 20) - def test_if_user_exists(self): + def test_if_user_exists(self) -> None: store = self.hs.get_datastores().main frank = UserID.from_string("@frank:test") self.get_success( @@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(result_token is not None) @override_config({"limit_usage_by_mau": False}) - def test_mau_limits_when_disabled(self): + def test_mau_limits_when_disabled(self) -> None: # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) @override_config({"limit_usage_by_mau": True}) - def test_get_or_create_user_mau_not_blocked(self): + def test_get_or_create_user_mau_not_blocked(self) -> None: self.store.count_monthly_users = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) @@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.get_success(self.get_or_create_user(self.requester, "c", "User")) @override_config({"limit_usage_by_mau": True}) - def test_get_or_create_user_mau_blocked(self): + def test_get_or_create_user_mau_blocked(self) -> None: self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) @override_config({"limit_usage_by_mau": True}) - def test_register_mau_blocked(self): + def test_register_mau_blocked(self) -> None: self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config( {"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False} ) - def test_auto_join_rooms_for_guests(self): + def test_auto_join_rooms_for_guests(self) -> None: user_id = self.get_success( self.handler.register_user(localpart="jeff", make_guest=True), ) @@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rooms), 0) @override_config({"auto_join_rooms": ["#room:test"]}) - def test_auto_create_auto_join_rooms(self): + def test_auto_create_auto_join_rooms(self) -> None: room_alias_str = "#room:test" user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rooms), 1) @override_config({"auto_join_rooms": []}) - def test_auto_create_auto_join_rooms_with_no_rooms(self): + def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None: frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) @@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rooms), 0) @override_config({"auto_join_rooms": ["#room:another"]}) - def test_auto_create_auto_join_where_room_is_another_domain(self): + def test_auto_create_auto_join_where_room_is_another_domain(self) -> None: frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) @@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config( {"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False} ) - def test_auto_create_auto_join_where_auto_create_is_false(self): + def test_auto_create_auto_join_where_auto_create_is_false(self) -> None: user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @override_config({"auto_join_rooms": ["#room:test"]}) - def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): + def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: room_alias_str = "#room:test" self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) @@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.get_failure(directory_handler.get_association(room_alias), SynapseError) @override_config({"auto_join_rooms": ["#room:test"]}) - def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): + def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: room_alias_str = "#room:test" self.store.count_real_users = Mock(return_value=make_awaitable(1)) @@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rooms), 1) @override_config({"auto_join_rooms": ["#room:test"]}) - def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): + def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( + self, + ) -> None: self.store.count_real_users = Mock(return_value=make_awaitable(2)) self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) @@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "autocreate_auto_join_rooms_federated": False, } ) - def test_auto_create_auto_join_rooms_federated(self): + def test_auto_create_auto_join_rooms_federated(self) -> None: """ Auto-created rooms that are private require an invite to go to the user (instead of directly joining it). @@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config( {"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"} ) - def test_auto_join_mxid_localpart(self): + def test_auto_join_mxid_localpart(self) -> None: """ Ensure the user still needs up in the room created by a different user. """ @@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "auto_join_mxid_localpart": "support", } ) - def test_auto_create_auto_join_room_preset(self): + def test_auto_create_auto_join_room_preset(self) -> None: """ Auto-created rooms that are private require an invite to go to the user (instead of directly joining it). @@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "auto_join_mxid_localpart": "support", } ) - def test_auto_create_auto_join_room_preset_guest(self): + def test_auto_create_auto_join_room_preset_guest(self) -> None: """ Auto-created rooms that are private require an invite to go to the user (instead of directly joining it). @@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "auto_join_mxid_localpart": "support", } ) - def test_auto_create_auto_join_room_preset_invalid_permissions(self): + def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None: """ Auto-created rooms that are private require an invite, check that registration doesn't completely break if the inviter doesn't have proper @@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): "auto_join_rooms": ["#room:test"], }, ) - def test_auto_create_auto_join_where_no_consent(self): + def test_auto_create_auto_join_where_no_consent(self) -> None: """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. """ @@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 1) - def test_register_support_user(self): + def test_register_support_user(self) -> None: user_id = self.get_success( self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT) ) d = self.store.is_support_user(user_id) self.assertTrue(self.get_success(d)) - def test_register_not_support_user(self): + def test_register_not_support_user(self) -> None: user_id = self.get_success(self.handler.register_user(localpart="user")) d = self.store.is_support_user(user_id) self.assertFalse(self.get_success(d)) - def test_invalid_user_id_length(self): + def test_invalid_user_id_length(self) -> None: invalid_user_id = "x" * 256 self.get_failure( self.handler.register_user(localpart=invalid_user_id), SynapseError @@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_deny(self): + def test_spam_checker_deny(self) -> None: """A spam checker can deny registration, which results in an error.""" self.get_failure(self.handler.register_user(localpart="user"), SynapseError) @@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_legacy_allow(self): + def test_spam_checker_legacy_allow(self) -> None: """Tests that a legacy spam checker implementing the legacy 3-arg version of the check_registration_for_spam callback is correctly called. @@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_legacy_deny(self): + def test_spam_checker_legacy_deny(self) -> None: """Tests that a legacy spam checker implementing the legacy 3-arg version of the check_registration_for_spam callback is correctly called. @@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_shadow_ban(self): + def test_spam_checker_shadow_ban(self) -> None: """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" user_id = self.get_success(self.handler.register_user(localpart="user")) @@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ] } ) - def test_spam_checker_receives_sso_type(self): + def test_spam_checker_receives_sso_type(self) -> None: """Test rejecting registration based on SSO type""" f = self.get_failure( self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), @@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) async def get_or_create_user( - self, requester, localpart, displayname, password_hash=None - ): + self, + requester: Requester, + localpart: str, + displayname: Optional[str], + password_hash: Optional[str] = None, + ) -> Tuple[str, str]: """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase): class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): """Tests auto-join on remote rooms.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.room_id = "!roomid:remotetest" - async def update_membership(*args, **kwargs): + async def update_membership(*args: Any, **kwargs: Any) -> None: pass - async def lookup_room_alias(*args, **kwargs): + async def lookup_room_alias( + *args: Any, **kwargs: Any + ) -> Tuple[RoomID, List[str]]: return RoomID.from_string(self.room_id), ["remotetest"] self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"]) @@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = self.hs.get_registration_handler() self.store = self.hs.get_datastores().main @override_config({"auto_join_rooms": ["#room:remotetest"]}) - def test_auto_create_auto_join_remote_room(self): + def test_auto_create_auto_join_remote_room(self) -> None: """Tests that we don't attempt to create remote rooms, and that we don't attempt to invite ourselves to rooms we're not in.""" diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index fcde5dab72..df95490d3b 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): ] @override_config({"encryption_enabled_by_default_for_room_type": "all"}) - def test_encrypted_by_default_config_option_all(self): + def test_encrypted_by_default_config_option_all(self) -> None: """Tests that invite-only and non-invite-only rooms have encryption enabled by default when the config option encryption_enabled_by_default_for_room_type is "all". """ @@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) @override_config({"encryption_enabled_by_default_for_room_type": "invite"}) - def test_encrypted_by_default_config_option_invite(self): + def test_encrypted_by_default_config_option_invite(self) -> None: """Tests that only new, invite-only rooms have encryption enabled by default when the config option encryption_enabled_by_default_for_room_type is "invite". """ @@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): ) @override_config({"encryption_enabled_by_default_for_room_type": "off"}) - def test_encrypted_by_default_config_option_off(self): + def test_encrypted_by_default_config_option_off(self) -> None: """Tests that neither new invite-only nor non-invite-only rooms have encryption enabled by default when the config option encryption_enabled_by_default_for_room_type is "off". diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index aa650756e4..d907fcaf04 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from unittest import mock from twisted.internet.defer import ensureDeferred +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import ( EventContentFields, @@ -34,11 +35,14 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester +from synapse.util import Clock from tests import unittest -def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0): +def _create_event( + room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0 +) -> mock.Mock: result = mock.Mock(name=room_id) result.room_id = room_id result.content = {} @@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i return result -def _order(*events): +def _order(*events: mock.Mock) -> List[mock.Mock]: return sorted(events, key=_child_events_comparison_key) class TestSpaceSummarySort(unittest.TestCase): - def test_no_order_last(self): + def test_no_order_last(self) -> None: """An event with no ordering is placed behind those with an ordering.""" ev1 = _create_event("!abc:test") ev2 = _create_event("!xyz:test", "xyz") self.assertEqual([ev2, ev1], _order(ev1, ev2)) - def test_order(self): + def test_order(self) -> None: """The ordering should be used.""" ev1 = _create_event("!abc:test", "xyz") ev2 = _create_event("!xyz:test", "abc") self.assertEqual([ev2, ev1], _order(ev1, ev2)) - def test_order_origin_server_ts(self): + def test_order_origin_server_ts(self) -> None: """Origin server is a tie-breaker for ordering.""" ev1 = _create_event("!abc:test", origin_server_ts=10) ev2 = _create_event("!xyz:test", origin_server_ts=30) self.assertEqual([ev1, ev2], _order(ev1, ev2)) - def test_order_room_id(self): + def test_order_room_id(self) -> None: """Room ID is a final tie-breaker for ordering.""" ev1 = _create_event("!abc:test") ev2 = _create_event("!xyz:test") self.assertEqual([ev1, ev2], _order(ev1, ev2)) - def test_invalid_ordering_type(self): + def test_invalid_ordering_type(self) -> None: """Invalid orderings are considered the same as missing.""" ev1 = _create_event("!abc:test", 1) ev2 = _create_event("!xyz:test", "xyz") @@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase): ev1 = _create_event("!abc:test", True) self.assertEqual([ev2, ev1], _order(ev1, ev2)) - def test_invalid_ordering_value(self): + def test_invalid_ordering_value(self) -> None: """Invalid orderings are considered the same as missing.""" ev1 = _create_event("!abc:test", "foo\n") ev2 = _create_event("!xyz:test", "xyz") @@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs self.handler = self.hs.get_room_summary_handler() @@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) ) - def test_simple_space(self): + def test_simple_space(self) -> None: """Test a simple space with a single room.""" # The result should have the space and the room in it, along with a link # from space -> room. @@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_large_space(self): + def test_large_space(self) -> None: """Test a space with a large number of rooms.""" rooms = [self.room] # Make at least 51 rooms that are part of the space. @@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result["rooms"] += result2["rooms"] self._assert_hierarchy(result, expected) - def test_visibility(self): + def test_visibility(self) -> None: """A user not in a space cannot inspect it.""" user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") @@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_hierarchy(result2, [(self.space, [self.room])]) def _create_room_with_join_rule( - self, join_rule: str, room_version: Optional[str] = None, **extra_content + self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any ) -> str: """Create a room with the given join rule and add it to the space.""" room_id = self.helper.create_room_as( @@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(self.space, room_id, self.token) return room_id - def test_filtering(self): + def test_filtering(self) -> None: """ Rooms should be properly filtered to only include rooms the user has access to. """ @@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_complex_space(self): + def test_complex_space(self) -> None: """ Create a "complex" space to see how it handles things like loops and subspaces. """ @@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_pagination(self): + def test_pagination(self) -> None: """Test simple pagination works.""" room_ids = [] for i in range(1, 10): @@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_hierarchy(result, expected) self.assertNotIn("next_batch", result) - def test_invalid_pagination_token(self): + def test_invalid_pagination_token(self) -> None: """An invalid pagination token, or changing other parameters, shoudl be rejected.""" room_ids = [] for i in range(1, 10): @@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_max_depth(self): + def test_max_depth(self) -> None: """Create a deep tree to test the max depth against.""" spaces = [self.space] rooms = [self.room] @@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ] self._assert_hierarchy(result, expected) - def test_unknown_room_version(self): + def test_unknown_room_version(self) -> None: """ If a room with an unknown room version is encountered it should not cause the entire summary to skip. @@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_fed_complex(self): + def test_fed_complex(self) -> None: """ Return data over federation and ensure that it is handled properly. """ @@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "world_readable": True, } - async def summarize_remote_room_hierarchy(_self, room, suggested_only): + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: return requested_room_entry, {subroom: child_room}, set() # Add a room to the space which is on another server. @@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_fed_filtering(self): + def test_fed_filtering(self) -> None: """ Rooms returned over federation should be properly filtered to only include rooms the user has access to. @@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ], ) - async def summarize_remote_room_hierarchy(_self, room, suggested_only): + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. @@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_fed_invited(self): + def test_fed_invited(self) -> None: """ A room which the user was invited to should be included in the response. @@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): }, ) - async def summarize_remote_room_hierarchy(_self, room, suggested_only): + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: return fed_room_entry, {}, set() # Add a room to the space which is on another server. @@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) - def test_fed_caching(self): + def test_fed_caching(self) -> None: """ Federation `/hierarchy` responses should be cached. """ @@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs self.handler = self.hs.get_room_summary_handler() @@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) - def test_own_room(self): + def test_own_room(self) -> None: """Test a simple room created by the requester.""" result = self.get_success(self.handler.get_room_summary(self.user, self.room)) self.assertEqual(result.get("room_id"), self.room) - def test_visibility(self): + def test_visibility(self) -> None: """A user not in a private room cannot get its summary.""" user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") @@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_room_summary(user2, self.room)) self.assertEqual(result.get("room_id"), self.room) - def test_fed(self): + def test_fed(self) -> None: """ Return data over federation and ensure that it is handled properly. """ @@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): {"room_id": fed_room, "world_readable": True}, ) - async def summarize_remote_room_hierarchy(_self, room, suggested_only): + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: return requested_room_entry, {}, set() with mock.patch( diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index a0f84e2940..9b1b8b9f13 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Set, Tuple from unittest.mock import Mock import attr @@ -20,7 +20,9 @@ import attr from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import RedirectException +from synapse.module_api import ModuleApi from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests.test_utils import simple_async_mock @@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. try: import saml2.config + import saml2.response from saml2.sigver import SigverError has_saml2 = True @@ -56,31 +59,39 @@ class FakeAuthnResponse: class TestMappingProvider: - def __init__(self, config, module): + def __init__(self, config: None, module: ModuleApi): pass @staticmethod - def parse_config(config): - return + def parse_config(config: JsonDict) -> None: + return None @staticmethod - def get_saml_attributes(config): + def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]: return {"uid"}, {"displayName"} - def get_remote_user_id(self, saml_response, client_redirect_url): + def get_remote_user_id( + self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str + ) -> str: return saml_response.ava["uid"] def saml_response_to_user_attributes( - self, saml_response, failures, client_redirect_url - ): + self, + saml_response: "saml2.response.AuthnResponse", + failures: int, + client_redirect_url: str, + ) -> dict: localpart = saml_response.ava["username"] + (str(failures) if failures else "") return {"mxid_localpart": localpart, "displayname": None} class TestRedirectMappingProvider(TestMappingProvider): def saml_response_to_user_attributes( - self, saml_response, failures, client_redirect_url - ): + self, + saml_response: "saml2.response.AuthnResponse", + failures: int, + client_redirect_url: str, + ) -> dict: raise RedirectException(b"https://custom-saml-redirect/") @@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase): ) -def _mock_request(): +def _mock_request() -> Mock: """Returns a mock which will stand in as a SynapseRequest""" mock = Mock( spec=[ diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index da4bf8b582..8b6e4a40b6 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import List, Tuple +from typing import Callable, List, Tuple from zope.interface import implementer @@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config @implementer(interfaces.IMessageDelivery) class _DummyMessageDelivery: - def __init__(self): + def __init__(self) -> None: # (recipient, message) tuples self.messages: List[Tuple[smtp.Address, bytes]] = [] - def receivedHeader(self, helo, origin, recipients): + def receivedHeader( + self, + helo: Tuple[bytes, bytes], + origin: smtp.Address, + recipients: List[smtp.User], + ) -> None: return None - def validateFrom(self, helo, origin): + def validateFrom( + self, helo: Tuple[bytes, bytes], origin: smtp.Address + ) -> smtp.Address: return origin - def record_message(self, recipient: smtp.Address, message: bytes): + def record_message(self, recipient: smtp.Address, message: bytes) -> None: self.messages.append((recipient, message)) - def validateTo(self, user: smtp.User): + def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]: return lambda: _DummyMessage(self, user) @@ -56,20 +63,20 @@ class _DummyMessage: self._user = user self._buffer: List[bytes] = [] - def lineReceived(self, line): + def lineReceived(self, line: bytes) -> None: self._buffer.append(line) - def eomReceived(self): + def eomReceived(self) -> "defer.Deferred[bytes]": message = b"\n".join(self._buffer) + b"\n" self._delivery.record_message(self._user.dest, message) return defer.succeed(b"saved") - def connectionLost(self): + def connectionLost(self) -> None: pass class SendEmailHandlerTestCase(HomeserverTestCase): - def test_send_email(self): + def test_send_email(self) -> None: """Happy-path test that we can send email to a non-TLS server.""" h = self.hs.get_send_email_handler() d = ensureDeferred( @@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase): }, } ) - def test_send_email_force_tls(self): + def test_send_email_force_tls(self) -> None: """Happy-path test that we can send email to an Implicit TLS server.""" h = self.hs.get_send_email_handler() d = ensureDeferred( diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 05f9ec3c51..f1a50c5bcb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Optional + +from twisted.test.proto_helpers import MemoryReactor + from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.storage.databases.main import stats +from synapse.util import Clock from tests import unittest @@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = self.hs.get_stats_handler() - def _add_background_updates(self): + def _add_background_updates(self) -> None: """ Add the background updates we need to run. """ @@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - async def get_all_room_state(self): + async def get_all_room_state(self) -> List[Dict[str, Any]]: return await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) - def _get_current_stats(self, stats_type, stat_id): + def _get_current_stats( + self, stats_type: str, stat_id: str + ) -> Optional[Dict[str, Any]]: table, id_col = stats.TYPE_TO_TABLE[stats_type] cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) @@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - def _perform_background_initial_update(self): + def _perform_background_initial_update(self) -> None: # Do the initial population of the stats via the background update self._add_background_updates() self.wait_for_background_updates() - def test_initial_room(self): + def test_initial_room(self) -> None: """ The background updates will build the table from scratch. """ @@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(len(r), 1) self.assertEqual(r[0]["topic"], "foo") - def test_create_user(self): + def test_create_user(self) -> None: """ When we create a user, it should have statistics already ready. """ @@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): u1stats = self._get_current_stats("user", u1) - self.assertIsNotNone(u1stats) + assert u1stats is not None # not in any rooms by default self.assertEqual(u1stats["joined_rooms"], 0) - def test_create_room(self): + def test_create_room(self) -> None: """ When we create a room, it should have statistics already ready. """ @@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False) r2stats = self._get_current_stats("room", r2) - self.assertIsNotNone(r1stats) - self.assertIsNotNone(r2stats) + assert r1stats is not None + assert r2stats is not None self.assertEqual( r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM @@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(r2stats["invited_members"], 0) self.assertEqual(r2stats["banned_members"], 0) - def test_updating_profile_information_does_not_increase_joined_members_count(self): + def test_updating_profile_information_does_not_increase_joined_members_count( + self, + ) -> None: """ Check that the joined_members count does not increase when a user changes their profile information (which is done by sending another join membership event into @@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Get the current room stats r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None # Send a profile update into the room new_profile = {"displayname": "bob"} @@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Get the new room stats r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None # Ensure that the user count did not changed self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"]) @@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"] ) - def test_send_state_event_nonoverwriting(self): + def test_send_state_event_nonoverwriting(self) -> None: """ When we send a non-overwriting state event, it increments current_state_events """ @@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.send_state( r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy" ) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, ) - def test_join_first_time(self): + def test_join_first_time(self) -> None: """ When a user joins a room for the first time, current_state_events and joined_members should increase by exactly 1. @@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): u2token = self.login("u2", "pass") r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.join(r1, u2, tok=u2token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1 ) - def test_join_after_leave(self): + def test_join_after_leave(self) -> None: """ When a user joins a room after being previously left, joined_members should increase by exactly 1. @@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.leave(r1, u2, tok=u2token) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.join(r1, u2, tok=u2token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["left_members"] - r1stats_ante["left_members"], -1 ) - def test_invited(self): + def test_invited(self) -> None: """ When a user invites another user, current_state_events and invited_members should increase by exactly 1. @@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): u2 = self.register_user("u2", "pass") r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.invite(r1, u1, u2, tok=u1token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1 ) - def test_join_after_invite(self): + def test_join_after_invite(self) -> None: """ When a user joins a room after being invited and joined_members should increase by exactly 1. @@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.invite(r1, u1, u2, tok=u1token) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.join(r1, u2, tok=u2token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1 ) - def test_left(self): + def test_left(self) -> None: """ When a user leaves a room after joining and left_members should increase by exactly 1. @@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.join(r1, u2, tok=u2token) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.leave(r1, u2, tok=u2token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 ) - def test_banned(self): + def test_banned(self) -> None: """ When a user is banned from a room after joining and left_members should increase by exactly 1. @@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.join(r1, u2, tok=u2token) r1stats_ante = self._get_current_stats("room", r1) + assert r1stats_ante is not None self.helper.change_membership(r1, u1, u2, "ban", tok=u1token) r1stats_post = self._get_current_stats("room", r1) + assert r1stats_post is not None self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], @@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 ) - def test_initial_background_update(self): + def test_initial_background_update(self) -> None: """ Test that statistics can be generated by the initial background update handler. @@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats = self._get_current_stats("room", r1) u1stats = self._get_current_stats("user", u1) + assert r1stats is not None + assert u1stats is not None + self.assertEqual(r1stats["joined_members"], 1) self.assertEqual( r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM @@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(u1stats["joined_rooms"], 1) - def test_incomplete_stats(self): + def test_incomplete_stats(self) -> None: """ This tests that we track incomplete statistics. @@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.wait_for_background_updates() r1stats_complete = self._get_current_stats("room", r1) + assert r1stats_complete is not None u1stats_complete = self._get_current_stats("user", u1) + assert u1stats_complete is not None u2stats_complete = self._get_current_stats("user", u2) + assert u2stats_complete is not None # now we make our assertions diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index ab5c101eb7..0d9a3de92a 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -14,6 +14,8 @@ from typing import Optional from unittest.mock import MagicMock, Mock, patch +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import Filtering @@ -23,6 +25,7 @@ from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util import Clock import tests.unittest import tests.utils @@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastores().main @@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # modify its config instead of the hs' self.auth_blocking = self.hs.get_auth_blocking() - def test_wait_for_sync_for_user_auth_blocking(self): + def test_wait_for_sync_for_user_auth_blocking(self) -> None: user_id1 = "@user1:test" user_id2 = "@user2:test" sync_config = generate_sync_config(user_id1) @@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def test_unknown_room_version(self): + def test_unknown_room_version(self) -> None: """ A room with an unknown room version should not break sync (and should be excluded). """ @@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.assertNotIn(invite_room, [r.room_id for r in result.invited]) self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) - def test_ban_wins_race_with_join(self): + def test_ban_wins_race_with_join(self) -> None: """Rooms shouldn't appear under "joined" if a join loses a race to a ban. A complicated edge case. Imagine the following scenario: -- cgit 1.5.1 From 864c3f85b0c420f755a064a3c50a45716db3f8af Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 16 Dec 2022 13:04:54 +0000 Subject: Improve type annotations for the helper methods on a `CachedFunction`. (#14685) --- changelog.d/14685.misc | 1 + synapse/util/caches/descriptors.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14685.misc (limited to 'synapse') diff --git a/changelog.d/14685.misc b/changelog.d/14685.misc new file mode 100644 index 0000000000..3ba2270100 --- /dev/null +++ b/changelog.d/14685.misc @@ -0,0 +1 @@ +Improve type annotations for the helper methods on a `CachedFunction`. \ No newline at end of file diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 72227359b9..81df71a0c5 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -53,9 +53,9 @@ F = TypeVar("F", bound=Callable[..., Any]) class CachedFunction(Generic[F]): - invalidate: Any = None - invalidate_all: Any = None - prefill: Any = None + invalidate: Callable[[Tuple[Any, ...]], None] + invalidate_all: Callable[[], None] + prefill: Callable[[Tuple[Any, ...], Any], None] cache: Any = None num_args: Any = None -- cgit 1.5.1 From 3aeca2588b79111a48a6083c88efc4d68a2cea19 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 16 Dec 2022 08:53:28 -0500 Subject: Add missing type hints to tests.config. (#14681) --- changelog.d/14681.misc | 1 + mypy.ini | 4 +-- synapse/config/cache.py | 4 +-- synapse/util/caches/lrucache.py | 9 ++--- tests/config/test___main__.py | 6 ++-- tests/config/test_background_update.py | 4 +-- tests/config/test_base.py | 10 +++--- tests/config/test_cache.py | 57 ++++++++++++++++---------------- tests/config/test_database.py | 2 +- tests/config/test_generate.py | 8 ++--- tests/config/test_load.py | 12 +++---- tests/config/test_ratelimiting.py | 2 +- tests/config/test_registration_config.py | 4 +-- tests/config/test_room_directory.py | 4 +-- tests/config/test_server.py | 18 +++++----- tests/config/test_tls.py | 53 +++++++++++++++++------------ tests/config/test_util.py | 2 +- tests/config/utils.py | 11 +++--- 18 files changed, 108 insertions(+), 103 deletions(-) create mode 100644 changelog.d/14681.misc (limited to 'synapse') diff --git a/changelog.d/14681.misc b/changelog.d/14681.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14681.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 1a37414e58..80fbcdfeab 100644 --- a/mypy.ini +++ b/mypy.ini @@ -36,8 +36,6 @@ exclude = (?x) |tests/api/test_ratelimiting.py |tests/app/test_openid_listener.py |tests/appservice/test_scheduler.py - |tests/config/test_cache.py - |tests/config/test_tls.py |tests/crypto/test_keyring.py |tests/events/test_presence_router.py |tests/events/test_utils.py @@ -89,7 +87,7 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False -[mypy-tests.config.test_api] +[mypy-tests.config.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/synapse/config/cache.py b/synapse/config/cache.py index eb4194a5a9..015b2a138e 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -16,7 +16,7 @@ import logging import os import re import threading -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Mapping, Optional import attr @@ -94,7 +94,7 @@ def add_resizable_cache( class CacheConfig(Config): section = "caches" - _environ = os.environ + _environ: Mapping[str, str] = os.environ event_cache_size: int cache_factors: Dict[str, float] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index dcf0eac3bf..452d5d04c1 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -788,26 +788,21 @@ class LruCache(Generic[KT, VT]): def __contains__(self, key: KT) -> bool: return self.contains(key) - def set_cache_factor(self, factor: float) -> bool: + def set_cache_factor(self, factor: float) -> None: """ Set the cache factor for this individual cache. This will trigger a resize if it changes, which may require evicting items from the cache. - - Returns: - Whether the cache changed size or not. """ if not self.apply_cache_factor_from_config: - return False + return new_size = int(self._original_max_size * factor) if new_size != self.max_size: self.max_size = new_size if self._on_resize: self._on_resize() - return True - return False def __del__(self) -> None: # We're about to be deleted, so we make sure to clear up all the nodes diff --git a/tests/config/test___main__.py b/tests/config/test___main__.py index b1c73d3612..cb5d4b05c3 100644 --- a/tests/config/test___main__.py +++ b/tests/config/test___main__.py @@ -17,15 +17,15 @@ from tests.config.utils import ConfigFileTestCase class ConfigMainFileTestCase(ConfigFileTestCase): - def test_executes_without_an_action(self): + def test_executes_without_an_action(self) -> None: self.generate_config() main(["", "-c", self.config_file]) - def test_read__error_if_key_not_found(self): + def test_read__error_if_key_not_found(self) -> None: self.generate_config() with self.assertRaises(SystemExit): main(["", "read", "foo.bar.hello", "-c", self.config_file]) - def test_read__passes_if_key_found(self): + def test_read__passes_if_key_found(self) -> None: self.generate_config() main(["", "read", "server.server_name", "-c", self.config_file]) diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py index 0c32c1ca29..e4bad2ba6e 100644 --- a/tests/config/test_background_update.py +++ b/tests/config/test_background_update.py @@ -22,7 +22,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase): # Tests that the default values in the config are correctly loaded. Note that the default # values are loaded when the corresponding config options are commented out, which is why there isn't # a config specified here. - def test_default_configuration(self): + def test_default_configuration(self) -> None: background_updater = BackgroundUpdater( self.hs, self.hs.get_datastores().main.db_pool ) @@ -46,7 +46,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase): """ ) ) - def test_custom_configuration(self): + def test_custom_configuration(self) -> None: background_updater = BackgroundUpdater( self.hs, self.hs.get_datastores().main.db_pool ) diff --git a/tests/config/test_base.py b/tests/config/test_base.py index 6a52f862f4..3fbfe6c1da 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -24,13 +24,13 @@ from tests import unittest class BaseConfigTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # The root object needs a server property with a public_baseurl. root = Mock() root.server.public_baseurl = "http://test" self.config = Config(root) - def test_loading_missing_templates(self): + def test_loading_missing_templates(self) -> None: # Use a temporary directory that exists on the system, but that isn't likely to # contain template files with tempfile.TemporaryDirectory() as tmp_dir: @@ -50,7 +50,7 @@ class BaseConfigTestCase(unittest.TestCase): "Template file did not contain our test string", ) - def test_loading_custom_templates(self): + def test_loading_custom_templates(self) -> None: # Use a temporary directory that exists on the system with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary bogus template file @@ -79,7 +79,7 @@ class BaseConfigTestCase(unittest.TestCase): "Template file did not contain our test string", ) - def test_multiple_custom_template_directories(self): + def test_multiple_custom_template_directories(self) -> None: """Tests that directories are searched in the right order if multiple custom template directories are provided. """ @@ -137,7 +137,7 @@ class BaseConfigTestCase(unittest.TestCase): for td in tempdirs: td.cleanup() - def test_loading_template_from_nonexistent_custom_directory(self): + def test_loading_template_from_nonexistent_custom_directory(self) -> None: with self.assertRaises(ConfigError): self.config.read_templates( ["some_filename.html"], ("a_nonexistent_directory",) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index d2b3c299e3..96f66af328 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -13,26 +13,27 @@ # limitations under the License. from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.types import JsonDict from synapse.util.caches.lrucache import LruCache from tests.unittest import TestCase class CacheConfigTests(TestCase): - def setUp(self): + def setUp(self) -> None: # Reset caches before each test since there's global state involved. self.config = CacheConfig() self.config.reset() - def tearDown(self): + def tearDown(self) -> None: # Also reset the caches after each test to leave state pristine. self.config.reset() - def test_individual_caches_from_environ(self): + def test_individual_caches_from_environ(self) -> None: """ Individual cache factors will be loaded from the environment. """ - config = {} + config: JsonDict = {} self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_NOT_CACHE": "BLAH", @@ -42,15 +43,15 @@ class CacheConfigTests(TestCase): self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) - def test_config_overrides_environ(self): + def test_config_overrides_environ(self) -> None: """ Individual cache factors defined in the environment will take precedence over those in the config. """ - config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", - "SYNAPSE_CACHE_FACTOR_FOO": 1, + "SYNAPSE_CACHE_FACTOR_FOO": "1", } self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() @@ -60,104 +61,104 @@ class CacheConfigTests(TestCase): {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, ) - def test_individual_instantiated_before_config_load(self): + def test_individual_instantiated_before_config_load(self) -> None: """ If a cache is instantiated before the config is read, it will be given the default cache size in the interim, and then resized once the config is loaded. """ - cache = LruCache(100) + cache: LruCache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 50) - config = {"caches": {"per_cache_factors": {"foo": 3}}} + config: JsonDict = {"caches": {"per_cache_factors": {"foo": 3}}} self.config.read_config(config) self.config.resize_all_caches() self.assertEqual(cache.max_size, 300) - def test_individual_instantiated_after_config_load(self): + def test_individual_instantiated_after_config_load(self) -> None: """ If a cache is instantiated after the config is read, it will be immediately resized to the correct size given the per_cache_factor if there is one. """ - config = {"caches": {"per_cache_factors": {"foo": 2}}} + config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2}}} self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache = LruCache(100) + cache: LruCache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 200) - def test_global_instantiated_before_config_load(self): + def test_global_instantiated_before_config_load(self) -> None: """ If a cache is instantiated before the config is read, it will be given the default cache size in the interim, and then resized to the new default cache size once the config is loaded. """ - cache = LruCache(100) + cache: LruCache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 50) - config = {"caches": {"global_factor": 4}} + config: JsonDict = {"caches": {"global_factor": 4}} self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() self.assertEqual(cache.max_size, 400) - def test_global_instantiated_after_config_load(self): + def test_global_instantiated_after_config_load(self) -> None: """ If a cache is instantiated after the config is read, it will be immediately resized to the correct size given the global factor if there is no per-cache factor. """ - config = {"caches": {"global_factor": 1.5}} + config: JsonDict = {"caches": {"global_factor": 1.5}} self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache = LruCache(100) + cache: LruCache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 150) - def test_cache_with_asterisk_in_name(self): + def test_cache_with_asterisk_in_name(self) -> None: """Some caches have asterisks in their name, test that they are set correctly.""" - config = { + config: JsonDict = { "caches": { "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} } } self.config._environ = { "SYNAPSE_CACHE_FACTOR_CACHE_A": "2", - "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, + "SYNAPSE_CACHE_FACTOR_CACHE_B": "3", } self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache_a = LruCache(100) + cache_a: LruCache = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) self.assertEqual(cache_a.max_size, 200) - cache_b = LruCache(100) + cache_b: LruCache = LruCache(100) add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor) self.assertEqual(cache_b.max_size, 300) - cache_c = LruCache(100) + cache_c: LruCache = LruCache(100) add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) self.assertEqual(cache_c.max_size, 200) - def test_apply_cache_factor_from_config(self): + def test_apply_cache_factor_from_config(self) -> None: """Caches can disable applying cache factor updates, mainly used by event cache size. """ - config = {"caches": {"event_cache_size": "10k"}} + config: JsonDict = {"caches": {"event_cache_size": "10k"}} self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache = LruCache( + cache: LruCache = LruCache( max_size=self.config.event_cache_size, apply_cache_factor_from_config=False, ) diff --git a/tests/config/test_database.py b/tests/config/test_database.py index 9eca10bbe9..240277bcc6 100644 --- a/tests/config/test_database.py +++ b/tests/config/test_database.py @@ -20,7 +20,7 @@ from tests import unittest class DatabaseConfigTestCase(unittest.TestCase): - def test_database_configured_correctly(self): + def test_database_configured_correctly(self) -> None: conf = yaml.safe_load( DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path") ) diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index fdfbb0e38e..3a02366932 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -25,14 +25,14 @@ from tests import unittest class ConfigGenerationTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.dir = tempfile.mkdtemp() self.file = os.path.join(self.dir, "homeserver.yaml") - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.dir) - def test_generate_config_generates_files(self): + def test_generate_config_generates_files(self) -> None: with redirect_stdout(StringIO()): HomeServerConfig.load_or_generate_config( "", @@ -56,7 +56,7 @@ class ConfigGenerationTestCase(unittest.TestCase): os.path.join(os.getcwd(), "homeserver.log"), ) - def assert_log_filename_is(self, log_config_file, expected): + def assert_log_filename_is(self, log_config_file: str, expected: str) -> None: with open(log_config_file) as f: config = f.read() # find the 'filename' line diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 69a4e9413b..fcbe79cc7a 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -21,14 +21,14 @@ from tests.config.utils import ConfigFileTestCase class ConfigLoadingFileTestCase(ConfigFileTestCase): - def test_load_fails_if_server_name_missing(self): + def test_load_fails_if_server_name_missing(self) -> None: self.generate_config_and_remove_lines_containing("server_name") with self.assertRaises(ConfigError): HomeServerConfig.load_config("", ["-c", self.config_file]) with self.assertRaises(ConfigError): HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) - def test_generates_and_loads_macaroon_secret_key(self): + def test_generates_and_loads_macaroon_secret_key(self) -> None: self.generate_config() with open(self.config_file) as f: @@ -58,7 +58,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): "was: %r" % (config2.key.macaroon_secret_key,) ) - def test_load_succeeds_if_macaroon_secret_key_missing(self): + def test_load_succeeds_if_macaroon_secret_key_missing(self) -> None: self.generate_config_and_remove_lines_containing("macaroon") config1 = HomeServerConfig.load_config("", ["-c", self.config_file]) config2 = HomeServerConfig.load_config("", ["-c", self.config_file]) @@ -73,7 +73,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): config1.key.macaroon_secret_key, config3.key.macaroon_secret_key ) - def test_disable_registration(self): + def test_disable_registration(self) -> None: self.generate_config() self.add_lines_to_config( ["enable_registration: true", "disable_registration: true"] @@ -93,7 +93,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): assert config3 is not None self.assertTrue(config3.registration.enable_registration) - def test_stats_enabled(self): + def test_stats_enabled(self) -> None: self.generate_config_and_remove_lines_containing("enable_metrics") self.add_lines_to_config(["enable_metrics: true"]) @@ -101,7 +101,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.metrics.metrics_flags.known_servers) - def test_depreciated_identity_server_flag_throws_error(self): + def test_depreciated_identity_server_flag_throws_error(self) -> None: self.generate_config() # Needed to ensure that actual key/value pair added below don't end up on a line with a comment self.add_lines_to_config([" "]) diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py index 1b63e1adfd..f12147eaa0 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py @@ -18,7 +18,7 @@ from tests.utils import default_config class RatelimitConfigTestCase(TestCase): - def test_parse_rc_federation(self): + def test_parse_rc_federation(self) -> None: config_dict = default_config("test") config_dict["rc_federation"] = { "window_size": 20000, diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py index 33d7b70e32..f6869d7f06 100644 --- a/tests/config/test_registration_config.py +++ b/tests/config/test_registration_config.py @@ -21,7 +21,7 @@ from tests.utils import default_config class RegistrationConfigTestCase(ConfigFileTestCase): - def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): + def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self) -> None: """ session_lifetime should logically be larger than, or at least as large as, all the different token lifetimes. @@ -91,7 +91,7 @@ class RegistrationConfigTestCase(ConfigFileTestCase): "", ) - def test_refuse_to_start_if_open_registration_and_no_verification(self): + def test_refuse_to_start_if_open_registration_and_no_verification(self) -> None: self.generate_config() self.add_lines_to_config( [ diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py index db745815ef..297ab37792 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py @@ -20,7 +20,7 @@ from tests import unittest class RoomDirectoryConfigTestCase(unittest.TestCase): - def test_alias_creation_acl(self): + def test_alias_creation_acl(self) -> None: config = yaml.safe_load( """ alias_creation_rules: @@ -78,7 +78,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): ) ) - def test_room_publish_acl(self): + def test_room_publish_acl(self) -> None: config = yaml.safe_load( """ alias_creation_rules: [] diff --git a/tests/config/test_server.py b/tests/config/test_server.py index 1f27a54701..41a3fb0b6d 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -21,7 +21,7 @@ from tests import unittest class ServerConfigTestCase(unittest.TestCase): - def test_is_threepid_reserved(self): + def test_is_threepid_reserved(self) -> None: user1 = {"medium": "email", "address": "user1@example.com"} user2 = {"medium": "email", "address": "user2@example.com"} user3 = {"medium": "email", "address": "user3@example.com"} @@ -32,7 +32,7 @@ class ServerConfigTestCase(unittest.TestCase): self.assertFalse(is_threepid_reserved(config, user3)) self.assertFalse(is_threepid_reserved(config, user1_msisdn)) - def test_unsecure_listener_no_listeners_open_private_ports_false(self): + def test_unsecure_listener_no_listeners_open_private_ports_false(self) -> None: conf = yaml.safe_load( ServerConfig().generate_config_section( "CONFDIR", "/data_dir_path", "che.org", False, None @@ -52,7 +52,7 @@ class ServerConfigTestCase(unittest.TestCase): self.assertEqual(conf["listeners"], expected_listeners) - def test_unsecure_listener_no_listeners_open_private_ports_true(self): + def test_unsecure_listener_no_listeners_open_private_ports_true(self) -> None: conf = yaml.safe_load( ServerConfig().generate_config_section( "CONFDIR", "/data_dir_path", "che.org", True, None @@ -71,7 +71,7 @@ class ServerConfigTestCase(unittest.TestCase): self.assertEqual(conf["listeners"], expected_listeners) - def test_listeners_set_correctly_open_private_ports_false(self): + def test_listeners_set_correctly_open_private_ports_false(self) -> None: listeners = [ { "port": 8448, @@ -95,7 +95,7 @@ class ServerConfigTestCase(unittest.TestCase): self.assertEqual(conf["listeners"], listeners) - def test_listeners_set_correctly_open_private_ports_true(self): + def test_listeners_set_correctly_open_private_ports_true(self) -> None: listeners = [ { "port": 8448, @@ -131,14 +131,14 @@ class ServerConfigTestCase(unittest.TestCase): class GenerateIpSetTestCase(unittest.TestCase): - def test_empty(self): + def test_empty(self) -> None: ip_set = generate_ip_set(()) self.assertFalse(ip_set) ip_set = generate_ip_set((), ()) self.assertFalse(ip_set) - def test_generate(self): + def test_generate(self) -> None: """Check adding IPv4 and IPv6 addresses.""" # IPv4 address ip_set = generate_ip_set(("1.2.3.4",)) @@ -160,7 +160,7 @@ class GenerateIpSetTestCase(unittest.TestCase): ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4")) self.assertEqual(len(ip_set.iter_cidrs()), 4) - def test_extra(self): + def test_extra(self) -> None: """Extra IP addresses are treated the same.""" ip_set = generate_ip_set((), ("1.2.3.4",)) self.assertEqual(len(ip_set.iter_cidrs()), 4) @@ -172,7 +172,7 @@ class GenerateIpSetTestCase(unittest.TestCase): ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",)) self.assertEqual(len(ip_set.iter_cidrs()), 4) - def test_bad_value(self): + def test_bad_value(self) -> None: """An error should be raised if a bad value is passed in.""" with self.assertRaises(ConfigError): generate_ip_set(("not-an-ip",)) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 9ba5781573..7510fc4643 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -13,13 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import cast + import idna from OpenSSL import SSL from synapse.config._base import Config, RootConfig +from synapse.config.homeserver import HomeServerConfig from synapse.config.tls import ConfigError, TlsConfig -from synapse.crypto.context_factory import FederationPolicyForHTTPS +from synapse.crypto.context_factory import ( + FederationPolicyForHTTPS, + SSLClientConnectionCreator, +) +from synapse.types import JsonDict from tests.unittest import TestCase @@ -27,7 +34,7 @@ from tests.unittest import TestCase class FakeServer(Config): section = "server" - def has_tls_listener(self): + def has_tls_listener(self) -> bool: return False @@ -36,21 +43,21 @@ class TestConfig(RootConfig): class TLSConfigTests(TestCase): - def test_tls_client_minimum_default(self): + def test_tls_client_minimum_default(self) -> None: """ The default client TLS version is 1.0. """ - config = {} + config: JsonDict = {} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") - def test_tls_client_minimum_set(self): + def test_tls_client_minimum_set(self) -> None: """ The default client TLS version can be set to 1.0, 1.1, and 1.2. """ - config = {"federation_client_minimum_tls_version": 1} + config: JsonDict = {"federation_client_minimum_tls_version": 1} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") @@ -76,7 +83,7 @@ class TLSConfigTests(TestCase): t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") - def test_tls_client_minimum_1_point_3_missing(self): + def test_tls_client_minimum_1_point_3_missing(self) -> None: """ If TLS 1.3 support is missing and it's configured, it will raise a ConfigError. @@ -88,7 +95,7 @@ class TLSConfigTests(TestCase): self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3) assert not hasattr(SSL, "OP_NO_TLSv1_3") - config = {"federation_client_minimum_tls_version": 1.3} + config: JsonDict = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() with self.assertRaises(ConfigError) as e: t.tls.read_config(config, config_dir_path="", data_dir_path="") @@ -100,7 +107,7 @@ class TLSConfigTests(TestCase): ), ) - def test_tls_client_minimum_1_point_3_exists(self): + def test_tls_client_minimum_1_point_3_exists(self) -> None: """ If TLS 1.3 support exists and it's configured, it will be settable. """ @@ -110,20 +117,20 @@ class TLSConfigTests(TestCase): self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3")) assert hasattr(SSL, "OP_NO_TLSv1_3") - config = {"federation_client_minimum_tls_version": 1.3} + config: JsonDict = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3") - def test_tls_client_minimum_set_passed_through_1_2(self): + def test_tls_client_minimum_set_passed_through_1_2(self) -> None: """ The configured TLS version is correctly configured by the ContextFactory. """ - config = {"federation_client_minimum_tls_version": 1.2} + config: JsonDict = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") - cf = FederationPolicyForHTTPS(t) + cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t)) options = _get_ssl_context_options(cf._verify_ssl_context) # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 @@ -131,15 +138,15 @@ class TLSConfigTests(TestCase): self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) - def test_tls_client_minimum_set_passed_through_1_0(self): + def test_tls_client_minimum_set_passed_through_1_0(self) -> None: """ The configured TLS version is correctly configured by the ContextFactory. """ - config = {"federation_client_minimum_tls_version": 1} + config: JsonDict = {"federation_client_minimum_tls_version": 1} t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") - cf = FederationPolicyForHTTPS(t) + cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t)) options = _get_ssl_context_options(cf._verify_ssl_context) # The context has not had any of the NO_TLS set. @@ -147,11 +154,11 @@ class TLSConfigTests(TestCase): self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) - def test_whitelist_idna_failure(self): + def test_whitelist_idna_failure(self) -> None: """ The federation certificate whitelist will not allow IDNA domain names. """ - config = { + config: JsonDict = { "federation_certificate_verification_whitelist": [ "example.com", "*.ドメイン.テスト", @@ -163,11 +170,11 @@ class TLSConfigTests(TestCase): ) self.assertIn("IDNA domain names", str(e)) - def test_whitelist_idna_result(self): + def test_whitelist_idna_result(self) -> None: """ The federation certificate whitelist will match on IDNA encoded names. """ - config = { + config: JsonDict = { "federation_certificate_verification_whitelist": [ "example.com", "*.xn--eckwd4c7c.xn--zckzah", @@ -176,14 +183,16 @@ class TLSConfigTests(TestCase): t = TestConfig() t.tls.read_config(config, config_dir_path="", data_dir_path="") - cf = FederationPolicyForHTTPS(t) + cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t)) # Not in the whitelist opts = cf.get_options(b"notexample.com") + assert isinstance(opts, SSLClientConnectionCreator) self.assertTrue(opts._verifier._verify_certs) # Caught by the wildcard opts = cf.get_options(idna.encode("テスト.ドメイン.テスト")) + assert isinstance(opts, SSLClientConnectionCreator) self.assertFalse(opts._verifier._verify_certs) @@ -191,4 +200,4 @@ def _get_ssl_context_options(ssl_context: SSL.Context) -> int: """get the options bits from an openssl context object""" # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to # use the low-level interface - return SSL._lib.SSL_CTX_get_options(ssl_context._context) + return SSL._lib.SSL_CTX_get_options(ssl_context._context) # type: ignore[attr-defined] diff --git a/tests/config/test_util.py b/tests/config/test_util.py index 3d4929daac..7073654832 100644 --- a/tests/config/test_util.py +++ b/tests/config/test_util.py @@ -21,7 +21,7 @@ from tests.unittest import TestCase class ValidateConfigTestCase(TestCase): """Test cases for synapse.config._util.validate_config""" - def test_bad_object_in_array(self): + def test_bad_object_in_array(self) -> None: """malformed objects within an array should be validated correctly""" # consider a structure: diff --git a/tests/config/utils.py b/tests/config/utils.py index 94c18a052b..4c0e8a064a 100644 --- a/tests/config/utils.py +++ b/tests/config/utils.py @@ -17,19 +17,20 @@ import tempfile import unittest from contextlib import redirect_stdout from io import StringIO +from typing import List from synapse.config.homeserver import HomeServerConfig class ConfigFileTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.dir = tempfile.mkdtemp() self.config_file = os.path.join(self.dir, "homeserver.yaml") - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.dir) - def generate_config(self): + def generate_config(self) -> None: with redirect_stdout(StringIO()): HomeServerConfig.load_or_generate_config( "", @@ -43,7 +44,7 @@ class ConfigFileTestCase(unittest.TestCase): ], ) - def generate_config_and_remove_lines_containing(self, needle): + def generate_config_and_remove_lines_containing(self, needle: str) -> None: self.generate_config() with open(self.config_file) as f: @@ -52,7 +53,7 @@ class ConfigFileTestCase(unittest.TestCase): with open(self.config_file, "w") as f: f.write("".join(contents)) - def add_lines_to_config(self, lines): + def add_lines_to_config(self, lines: List[str]) -> None: with open(self.config_file, "a") as f: for line in lines: f.write(line + "\n") -- cgit 1.5.1 From 2888d7ec83b33b3ce848d9219c921ffe0b88ffbf Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 19 Dec 2022 14:57:51 +0000 Subject: Faster remote room joins: invalidate caches and unblock requests when receiving un-partial-stated event notifications over replication. [rei:frrj/streams/unpsr] (#14546) --- changelog.d/14546.misc | 1 + synapse/replication/tcp/client.py | 14 ++++++++++++- synapse/storage/databases/main/events_worker.py | 27 ++++++++++++++----------- synapse/storage/databases/main/state.py | 18 ++++++++++++++++- 4 files changed, 46 insertions(+), 14 deletions(-) create mode 100644 changelog.d/14546.misc (limited to 'synapse') diff --git a/changelog.d/14546.misc b/changelog.d/14546.misc new file mode 100644 index 0000000000..60b6761a51 --- /dev/null +++ b/changelog.d/14546.misc @@ -0,0 +1 @@ +Faster remote room joins: stream the un-partial-stating of events over replication. \ No newline at end of file diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b4dad47b45..658d89210d 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -36,6 +36,7 @@ from synapse.replication.tcp.streams import ( TagAccountDataStream, ToDeviceStream, TypingStream, + UnPartialStatedEventStream, UnPartialStatedRoomStream, ) from synapse.replication.tcp.streams.events import ( @@ -43,7 +44,10 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) -from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStreamRow +from synapse.replication.tcp.streams.partial_state import ( + UnPartialStatedEventStreamRow, + UnPartialStatedRoomStreamRow, +) from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID from synapse.util.async_helpers import Linearizer, timeout_deferred from synapse.util.metrics import Measure @@ -247,6 +251,14 @@ class ReplicationDataHandler: self._state_storage_controller.notify_room_un_partial_stated( row.room_id ) + elif stream_name == UnPartialStatedEventStream.NAME: + for row in rows: + assert isinstance(row, UnPartialStatedEventStreamRow) + + # Wake up any tasks waiting for the event to be un-partial-stated. + self._state_storage_controller.notify_event_un_partial_stated( + row.event_id + ) await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e19b16064b..761b15a815 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,8 +59,9 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.tcp.streams import BackfillStream +from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream from synapse.replication.tcp.streams.events import EventsStream +from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -391,6 +392,16 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: self._backfill_id_gen.advance(instance_name, -token) + elif stream_name == UnPartialStatedEventStream.NAME: + for row in rows: + assert isinstance(row, UnPartialStatedEventStreamRow) + + self.is_partial_state_event.invalidate((row.event_id,)) + + if row.rejection_status_changed: + # If the partial-stated event became rejected or unrejected + # when it wasn't before, we need to invalidate this cache. + self._invalidate_local_get_event_cache(row.event_id) super().process_replication_rows(stream_name, instance_name, token, rows) @@ -2380,6 +2391,9 @@ class EventsWorkerStore(SQLBaseStore): This can happen, for example, when resyncing state during a faster join. + It is the caller's responsibility to ensure that other workers are + sent a notification so that they call `_invalidate_local_get_event_cache()`. + Args: txn: event_id: ID of event to update @@ -2418,14 +2432,3 @@ class EventsWorkerStore(SQLBaseStore): ) self.invalidate_get_event_cache_after_txn(txn, event_id) - - # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just - # call '_send_invalidation_to_replication', but we actually need the other - # end to call _invalidate_local_get_event_cache() rather than (just) - # _get_event_cache.invalidate(). - # - # One solution might be to (somehow) get the workers to call - # _invalidate_caches_for_event() (though that will invalidate more than - # strictly necessary). - # - # https://github.com/matrix-org/synapse/issues/12994 diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index f855903c39..f32cbb2dec 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,7 @@ # limitations under the License. import collections.abc import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple import attr @@ -24,6 +24,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.opentracing import trace +from synapse.replication.tcp.streams import UnPartialStatedEventStream +from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -82,6 +84,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): super().__init__(database, db_conn, hs) self._instance_name: str = hs.get_instance_name() + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == UnPartialStatedEventStream.NAME: + for row in rows: + assert isinstance(row, UnPartialStatedEventStreamRow) + self._get_state_group_for_event.invalidate((row.event_id,)) + + super().process_replication_rows(stream_name, instance_name, token, rows) + async def get_room_version(self, room_id: str) -> RoomVersion: """Get the room_version of a given room Raises: -- cgit 1.5.1 From 7010a3d0151b88b3a9a7451201eaf9c5bbe48d64 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 21 Dec 2022 13:05:21 -0500 Subject: Switch to ruff instead of flake8. (#14633) ruff is a flake8-compatible Python linter written in Rust. It supports the flake8 plugins that we use and is significantly faster in testing. --- .flake8 | 18 ----- .github/workflows/tests.yml | 2 +- changelog.d/14633.misc | 1 + poetry.lock | 119 +++++++++------------------------- pyproject.toml | 46 +++++++++++-- scripts-dev/lint.sh | 5 +- stubs/frozendict.pyi | 2 + stubs/icu.pyi | 2 + stubs/sortedcontainers/sorteddict.pyi | 2 + stubs/sortedcontainers/sortedlist.pyi | 2 + stubs/sortedcontainers/sortedset.pyi | 2 + synapse/config/_base.pyi | 2 + 12 files changed, 87 insertions(+), 116 deletions(-) delete mode 100644 .flake8 create mode 100644 changelog.d/14633.misc (limited to 'synapse') diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 4c6a4d5843..0000000000 --- a/.flake8 +++ /dev/null @@ -1,18 +0,0 @@ -# TODO: incorporate this into pyproject.toml if flake8 supports it in the future. -# See https://github.com/PyCQA/flake8/issues/234 -[flake8] -# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes -# for error codes. The ones we ignore are: -# W503: line break before binary operator -# W504: line break after binary operator -# E203: whitespace before ':' (which is contrary to pep8?) -# E731: do not assign a lambda expression, use a def -# E501: Line too long (black enforces this for us) -# -# flake8-bugbear runs extra checks. Its error codes are described at -# https://github.com/PyCQA/flake8-bugbear#list-of-warnings -# B019: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks -# B023: Functions defined inside a loop must not use variables redefined in the loop -# B024: Abstract base class with no abstract method. - -ignore=W503,W504,E203,E731,E501,B019,B023,B024 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f07655d982..5a0c0a0d65 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -53,7 +53,7 @@ jobs: - run: scripts-dev/check_schema_delta.py --force-colors lint: - uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v1" + uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v2" with: typechecking-extras: "all" diff --git a/changelog.d/14633.misc b/changelog.d/14633.misc new file mode 100644 index 0000000000..def187b12b --- /dev/null +++ b/changelog.d/14633.misc @@ -0,0 +1 @@ +Use [ruff](https://github.com/charliermarsh/ruff/) instead of flake8. diff --git a/poetry.lock b/poetry.lock index 9a9a141a14..c83cad3e1a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -244,47 +244,6 @@ python-versions = ">=3.7" [package.extras] dev = ["Sphinx", "coverage", "flake8", "lxml", "memory-profiler", "mypy (==0.910)", "tox", "xmlschema (>=1.8.0)"] -[[package]] -name = "flake8" -version = "5.0.4" -description = "the modular source code checker: pep8 pyflakes and co" -category = "dev" -optional = false -python-versions = ">=3.6.1" - -[package.dependencies] -importlib-metadata = {version = ">=1.1.0,<4.3", markers = "python_version < \"3.8\""} -mccabe = ">=0.7.0,<0.8.0" -pycodestyle = ">=2.9.0,<2.10.0" -pyflakes = ">=2.5.0,<2.6.0" - -[[package]] -name = "flake8-bugbear" -version = "22.12.6" -description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." -category = "dev" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -attrs = ">=19.2.0" -flake8 = ">=3.0.0" - -[package.extras] -dev = ["coverage", "hypothesis", "hypothesmith (>=0.2)", "pre-commit", "tox"] - -[[package]] -name = "flake8-comprehensions" -version = "3.10.1" -description = "A flake8 plugin to help you write better list/set/dict comprehensions." -category = "dev" -optional = false -python-versions = ">=3.7" - -[package.dependencies] -flake8 = ">=3.0,<3.2.0 || >3.2.0" -importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} - [[package]] name = "frozendict" version = "2.3.4" @@ -553,14 +512,6 @@ Twisted = ">=15.1.0" [package.extras] dev = ["black (==22.3.0)", "flake8 (==4.0.1)", "isort (==5.9.3)", "ldaptor", "matrix-synapse", "mypy (==0.910)", "tox", "types-setuptools"] -[[package]] -name = "mccabe" -version = "0.7.0" -description = "McCabe checker, plugin for flake8" -category = "dev" -optional = false -python-versions = ">=3.6" - [[package]] name = "msgpack" version = "1.0.4" @@ -770,14 +721,6 @@ python-versions = "*" [package.dependencies] pyasn1 = ">=0.4.6,<0.5.0" -[[package]] -name = "pycodestyle" -version = "2.9.1" -description = "Python style guide checker" -category = "dev" -optional = false -python-versions = ">=3.6" - [[package]] name = "pycparser" version = "2.21" @@ -801,14 +744,6 @@ typing-extensions = ">=4.1.0" dotenv = ["python-dotenv (>=0.10.4)"] email = ["email-validator (>=1.0.3)"] -[[package]] -name = "pyflakes" -version = "2.5.0" -description = "passive checker of Python programs" -category = "dev" -optional = false -python-versions = ">=3.6" - [[package]] name = "pygithub" version = "1.57" @@ -1044,6 +979,14 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] +[[package]] +name = "ruff" +version = "0.0.189" +description = "An extremely fast Python linter, written in Rust." +category = "dev" +optional = false +python-versions = ">=3.7" + [[package]] name = "secretstorage" version = "3.3.1" @@ -1635,7 +1578,7 @@ user-search = ["pyicu"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "f20007013f33bc35a01e412c48adc62a936030f3074e06286674c5ad7f44d300" +content-hash = "d20b6aea682a74e6a161080bb459e73160b8eb79526f5d17a525639ac3fe3e9e" [metadata.files] attrs = [ @@ -1827,18 +1770,6 @@ elementpath = [ {file = "elementpath-2.5.0-py3-none-any.whl", hash = "sha256:2a432775e37a19e4362443078130a7dbfc457d7d093cd421c03958d9034cc08b"}, {file = "elementpath-2.5.0.tar.gz", hash = "sha256:3a27aaf3399929fccda013899cb76d3ff111734abf4281e5f9d3721ba0b9ffa3"}, ] -flake8 = [ - {file = "flake8-5.0.4-py2.py3-none-any.whl", hash = "sha256:7a1cf6b73744f5806ab95e526f6f0d8c01c66d7bbe349562d22dfca20610b248"}, - {file = "flake8-5.0.4.tar.gz", hash = "sha256:6fbe320aad8d6b95cec8b8e47bc933004678dc63095be98528b7bdd2a9f510db"}, -] -flake8-bugbear = [ - {file = "flake8-bugbear-22.12.6.tar.gz", hash = "sha256:4cdb2c06e229971104443ae293e75e64c6107798229202fbe4f4091427a30ac0"}, - {file = "flake8_bugbear-22.12.6-py3-none-any.whl", hash = "sha256:b69a510634f8a9c298dfda2b18a8036455e6b19ecac4fe582e4d7a0abfa50a30"}, -] -flake8-comprehensions = [ - {file = "flake8-comprehensions-3.10.1.tar.gz", hash = "sha256:412052ac4a947f36b891143430fef4859705af11b2572fbb689f90d372cf26ab"}, - {file = "flake8_comprehensions-3.10.1-py3-none-any.whl", hash = "sha256:d763de3c74bc18a79c039a7ec732e0a1985b0c79309ceb51e56401ad0a2cd44e"}, -] frozendict = [ {file = "frozendict-2.3.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4a3b32d47282ae0098b9239a6d53ec539da720258bd762d62191b46f2f87c5fc"}, {file = "frozendict-2.3.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84c9887179a245a66a50f52afa08d4d92ae0f269839fab82285c70a0fa0dd782"}, @@ -2046,6 +1977,7 @@ lxml = [ {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"}, {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"}, {file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"}, + {file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"}, {file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"}, @@ -2055,6 +1987,7 @@ lxml = [ {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"}, {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"}, {file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"}, + {file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"}, {file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"}, @@ -2147,10 +2080,6 @@ matrix-synapse-ldap3 = [ {file = "matrix-synapse-ldap3-0.2.2.tar.gz", hash = "sha256:b388d95693486eef69adaefd0fd9e84463d52fe17b0214a00efcaa669b73cb74"}, {file = "matrix_synapse_ldap3-0.2.2-py3-none-any.whl", hash = "sha256:66ee4c85d7952c6c27fd04c09cdfdf4847b8e8b7d6a7ada6ba1100013bda060f"}, ] -mccabe = [ - {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, - {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, -] msgpack = [ {file = "msgpack-1.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4ab251d229d10498e9a2f3b1e68ef64cb393394ec477e3370c457f9430ce9250"}, {file = "msgpack-1.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:112b0f93202d7c0fef0b7810d465fde23c746a2d482e1e2de2aafd2ce1492c88"}, @@ -2370,10 +2299,6 @@ pyasn1-modules = [ {file = "pyasn1-modules-0.2.8.tar.gz", hash = "sha256:905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e"}, {file = "pyasn1_modules-0.2.8-py2.py3-none-any.whl", hash = "sha256:a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74"}, ] -pycodestyle = [ - {file = "pycodestyle-2.9.1-py2.py3-none-any.whl", hash = "sha256:d1735fc58b418fd7c5f658d28d943854f8a849b01a5d0a1e6f3f3fdd0166804b"}, - {file = "pycodestyle-2.9.1.tar.gz", hash = "sha256:2c9607871d58c76354b697b42f5d57e1ada7d261c261efac224b664affdc5785"}, -] pycparser = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, @@ -2416,10 +2341,6 @@ pydantic = [ {file = "pydantic-1.10.2-py3-none-any.whl", hash = "sha256:1b6ee725bd6e83ec78b1aa32c5b1fa67a3a65badddde3976bca5fe4568f27709"}, {file = "pydantic-1.10.2.tar.gz", hash = "sha256:91b8e218852ef6007c2b98cd861601c6a09f1aa32bbbb74fab5b1c33d4a1e410"}, ] -pyflakes = [ - {file = "pyflakes-2.5.0-py2.py3-none-any.whl", hash = "sha256:4579f67d887f804e67edb544428f264b7b24f435b263c4614f384135cea553d2"}, - {file = "pyflakes-2.5.0.tar.gz", hash = "sha256:491feb020dca48ccc562a8c0cbe8df07ee13078df59813b83959cbdada312ea3"}, -] pygithub = [ {file = "PyGithub-1.57-py3-none-any.whl", hash = "sha256:5822febeac2391f1306c55a99af2bc8f86c8bf82ded000030cd02c18f31b731f"}, {file = "PyGithub-1.57.tar.gz", hash = "sha256:c273f252b278fb81f1769505cc6921bdb6791e1cebd6ac850cc97dad13c31ff3"}, @@ -2560,6 +2481,24 @@ rich = [ {file = "rich-12.6.0-py3-none-any.whl", hash = "sha256:a4eb26484f2c82589bd9a17c73d32a010b1e29d89f1604cd9bf3a2097b81bb5e"}, {file = "rich-12.6.0.tar.gz", hash = "sha256:ba3a3775974105c221d31141f2c116f4fd65c5ceb0698657a11e9f295ec93fd0"}, ] +ruff = [ + {file = "ruff-0.0.189-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:07c947b42d3c5efc6761214acdb6b71a49b833ad9fb9b320454244a6fe01f212"}, + {file = "ruff-0.0.189-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:76e6161d021bde5738bf9d123ae445cb3a22fa60f14958ce64961d8af16141a0"}, + {file = "ruff-0.0.189-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c27f51e5b48cd483459cdd1c95a6bd989adcf7653ccc440ca437f4993fe4b812"}, + {file = "ruff-0.0.189-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e89f488a16ce2b21d940fc6271ed161affec788955f7b41761a9693a92e994bb"}, + {file = "ruff-0.0.189-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fee593d8d470811c316ff2eb0124ac74668a3d637ab3fb237aa3fa8561fb89aa"}, + {file = "ruff-0.0.189-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:bc3a73683a5b3b4b7bf951bbd4aa7d79b993c8c2e608a68de120c342ebe510f2"}, + {file = "ruff-0.0.189-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e5d73877558651f48c86d958afe0f662b6c3639990c230a6b9d82ac6093484db"}, + {file = "ruff-0.0.189-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d1e6e9813f59ba54e7cb6f28c1f2a9a756197f6e321bd68519afe57f8522fce"}, + {file = "ruff-0.0.189-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d177090cf03004b14814b0aad530758f5186d391250afb737570edd55beabc6"}, + {file = "ruff-0.0.189-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:48de3253856a0a85f9b53a0ca1982946c7fd343c796cdc76ece0ae359d5b71b5"}, + {file = "ruff-0.0.189-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e935bb5a213030de312ad00df477f38c78ac97af58b0e6a4ae5762705a5113da"}, + {file = "ruff-0.0.189-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bdb8173d6efff96e0cc5fe38f5fc4daa0d28fb11553482b9989d372fdafc7708"}, + {file = "ruff-0.0.189-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:14486fd8632bc4c7f926137a9c6a8c45993ff6667ddb7a88192c369c3afd86e9"}, + {file = "ruff-0.0.189-py3-none-win32.whl", hash = "sha256:e281080e2ed04f01275b3df5baa0afe2802ab145349298e24700cdd09c0afddc"}, + {file = "ruff-0.0.189-py3-none-win_amd64.whl", hash = "sha256:c552ff0b0587a5e13f935131d2a19782c0baf8b59175cf3160a76545fbdbdd76"}, + {file = "ruff-0.0.189.tar.gz", hash = "sha256:90a3031461ed83686ff78f96e58d28cdee835110c51bdfa0968a2d5892610c71"}, +] secretstorage = [ {file = "SecretStorage-3.3.1-py3-none-any.whl", hash = "sha256:422d82c36172d88d6a0ed5afdec956514b189ddbfb72fefab0c8a1cee4eaf71f"}, {file = "SecretStorage-3.3.1.tar.gz", hash = "sha256:fd666c51a6bf200643495a04abb261f83229dcb6fd8472ec393df7ffc8b6f195"}, diff --git a/pyproject.toml b/pyproject.toml index 3281441534..37b9ab3a77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,46 @@ target-version = ['py37', 'py38', 'py39', 'py310'] # https://black.readthedocs.io/en/stable/usage_and_configuration/file_collection_and_discovery.html#gitignore # Use `extend-exclude` if you want to exclude something in addition to this. +[tool.ruff] +line-length = 88 + +# See https://github.com/charliermarsh/ruff/#pycodestyle +# for error codes. The ones we ignore are: +# E731: do not assign a lambda expression, use a def +# E501: Line too long (black enforces this for us) +# +# See https://github.com/charliermarsh/ruff/#pyflakes +# F401: unused import +# F811: Redefinition of unused +# F821: Undefined name +# +# flake8-bugbear compatible checks. Its error codes are described at +# https://github.com/charliermarsh/ruff/#flake8-bugbear +# B019: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks +# B023: Functions defined inside a loop must not use variables redefined in the loop +# B024: Abstract base class with no abstract method. +ignore = [ + "B019", + "B023", + "B024", + "E501", + "E731", + "F401", + "F811", + "F821", +] +select = [ + # pycodestyle checks. + "E", + "W", + # pyflakes checks. + "F", + # flake8-bugbear checks. + "B0", + # flake8-comprehensions checks. + "C4", +] + [tool.isort] line_length = 88 sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TWISTED", "FIRSTPARTY", "TESTS", "LOCALFOLDER"] @@ -274,12 +314,10 @@ all = [ ] [tool.poetry.dev-dependencies] -## We pin black so that our tests don't start failing on new releases. +# We pin black so that our tests don't start failing on new releases. isort = ">=5.10.1" black = ">=22.3.0" -flake8-comprehensions = "*" -flake8-bugbear = ">=21.3.2" -flake8 = "*" +ruff = "0.0.189" # Typechecking mypy = "*" diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index bf900645b1..f6b81013c3 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -1,9 +1,8 @@ #!/usr/bin/env bash # # Runs linting scripts over the local Synapse checkout -# isort - sorts import statements # black - opinionated code formatter -# flake8 - lints and finds mistakes +# ruff - lints and finds mistakes set -e @@ -105,6 +104,6 @@ set -x isort "${files[@]}" python3 -m black "${files[@]}" ./scripts-dev/config-lint.sh -flake8 "${files[@]}" +ruff "${files[@]}" ./scripts-dev/check_pydantic_models.py lint mypy diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi index 24c6f3af77..196dee4461 100644 --- a/stubs/frozendict.pyi +++ b/stubs/frozendict.pyi @@ -14,6 +14,8 @@ # Stub for frozendict. +from __future__ import annotations + from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload _KT = TypeVar("_KT", bound=Hashable) # Key type. diff --git a/stubs/icu.pyi b/stubs/icu.pyi index efeda7938a..7736df8a92 100644 --- a/stubs/icu.pyi +++ b/stubs/icu.pyi @@ -14,6 +14,8 @@ # Stub for PyICU. +from __future__ import annotations + class Locale: @staticmethod def getDefault() -> Locale: ... diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index 7c399ab38d..81f581b034 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -2,6 +2,8 @@ # https://github.com/grantjenks/python-sortedcontainers/blob/eea42df1f7bad2792e8da77335ff888f04b9e5ae/sortedcontainers/sorteddict.pyi # (from https://github.com/grantjenks/python-sortedcontainers/pull/107) +from __future__ import annotations + from typing import ( Any, Callable, diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi index 403897e391..cd4c969849 100644 --- a/stubs/sortedcontainers/sortedlist.pyi +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -2,6 +2,8 @@ # https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi # (from https://github.com/grantjenks/python-sortedcontainers/pull/107) +from __future__ import annotations + from typing import ( Any, Callable, diff --git a/stubs/sortedcontainers/sortedset.pyi b/stubs/sortedcontainers/sortedset.pyi index 43c860f422..d761c438f7 100644 --- a/stubs/sortedcontainers/sortedset.pyi +++ b/stubs/sortedcontainers/sortedset.pyi @@ -2,6 +2,8 @@ # https://github.com/grantjenks/python-sortedcontainers/blob/d0a225d7fd0fb4c54532b8798af3cbeebf97e2d5/sortedcontainers/sortedset.pyi # (from https://github.com/grantjenks/python-sortedcontainers/pull/107) +from __future__ import annotations + from typing import ( AbstractSet, Any, diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 01ea2b4dab..bd265de536 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse from typing import ( Any, -- cgit 1.5.1 From 5c9be9c76021ac54f425f10e8f935532d3197de5 Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Thu, 22 Dec 2022 13:26:37 -0500 Subject: Check sqlite database file exists before porting. (#14692) To avoid creating an empty SQLite file if the given path is incorrect. --- changelog.d/14692.misc | 1 + synapse/_scripts/synapse_port_db.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14692.misc (limited to 'synapse') diff --git a/changelog.d/14692.misc b/changelog.d/14692.misc new file mode 100644 index 0000000000..0edac253b7 --- /dev/null +++ b/changelog.d/14692.misc @@ -0,0 +1 @@ +Check that the SQLite database file exists before porting to PostgreSQL. \ No newline at end of file diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index d850e54e17..c463b60b26 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -1307,7 +1307,7 @@ def main() -> None: sqlite_config = { "name": "sqlite3", "args": { - "database": args.sqlite_database, + "database": "file:{}?mode=rw".format(args.sqlite_database), "cp_min": 1, "cp_max": 1, "check_same_thread": False, -- cgit 1.5.1 From a52822d39c866b4d5e6d2a0176f29ae49bf3f8e9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:04:50 +0000 Subject: Log to-device msgids when we return them over /sync (#14724) --- changelog.d/14724.misc | 1 + synapse/handlers/sync.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14724.misc (limited to 'synapse') diff --git a/changelog.d/14724.misc b/changelog.d/14724.misc new file mode 100644 index 0000000000..270e5ed188 --- /dev/null +++ b/changelog.d/14724.misc @@ -0,0 +1 @@ +If debug logging is enabled, log the `msgid`s of any to-device messages that are returned over `/sync`. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7d6a653747..4fa480262b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.handlers.relations import BundledAggregations +from synapse.logging import issue9533_logger from synapse.logging.context import current_context from synapse.logging.opentracing import ( SynapseTags, @@ -1623,13 +1624,18 @@ class SyncHandler: } ) - logger.debug( - "Returning %d to-device messages between %d and %d (current token: %d)", - len(messages), - since_stream_id, - stream_id, - now_token.to_device_key, - ) + if messages and issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Returning to-device messages with stream_ids (%d, %d]; now: %d;" + " msgids: %s", + since_stream_id, + stream_id, + now_token.to_device_key, + [ + message["content"].get(EventContentFields.TO_DEVICE_MSGID) + for message in messages + ], + ) sync_result_builder.now_token = now_token.copy_and_replace( StreamKeyType.TO_DEVICE, stream_id ) -- cgit 1.5.1 From 3854d0f94947ddd5a9ee98198af8d7ae839962c9 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 28 Dec 2022 14:48:21 +0100 Subject: Add a `cached` helper to the module API (#14663) --- changelog.d/14663.feature | 1 + synapse/module_api/__init__.py | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14663.feature (limited to 'synapse') diff --git a/changelog.d/14663.feature b/changelog.d/14663.feature new file mode 100644 index 0000000000..b03f3ee54e --- /dev/null +++ b/changelog.d/14663.feature @@ -0,0 +1 @@ +Add a `cached` function to `synapse.module_api` that returns a decorator to cache return values of functions. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 0092a03c59..6f4a934b05 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -18,6 +18,7 @@ from typing import ( TYPE_CHECKING, Any, Callable, + Collection, Dict, Generator, Iterable, @@ -126,7 +127,7 @@ from synapse.types import ( from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable -from synapse.util.caches.descriptors import CachedFunction, cached +from synapse.util.caches.descriptors import CachedFunction, cached as _cached from synapse.util.frozenutils import freeze if TYPE_CHECKING: @@ -136,6 +137,7 @@ if TYPE_CHECKING: T = TypeVar("T") P = ParamSpec("P") +F = TypeVar("F", bound=Callable[..., Any]) """ This package defines the 'stable' API which can be used by extension modules which @@ -185,6 +187,42 @@ class UserIpAndAgent: last_seen: int +def cached( + *, + max_entries: int = 1000, + num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, +) -> Callable[[F], CachedFunction[F]]: + """Returns a decorator that applies a memoizing cache around the function. This + decorator behaves similarly to functools.lru_cache. + + Example: + + @cached() + def foo('a', 'b'): + ... + + Added in Synapse v1.74.0. + + Args: + max_entries: The maximum number of entries in the cache. If the cache is full + and a new entry is added, the least recently accessed entry will be evicted + from the cache. + num_args: The number of positional arguments (excluding `self`) to use as cache + keys. Defaults to all named args of the function. + uncached_args: A list of argument names to not use as the cache key. (`self` is + always ignored.) Cannot be used with num_args. + + Returns: + A decorator that applies a memoizing cache around the function. + """ + return _cached( + max_entries=max_entries, + num_args=num_args, + uncached_args=uncached_args, + ) + + class ModuleApi: """A proxy object that gets passed to various plugin modules so they can register new users etc if necessary. -- cgit 1.5.1 From 044fa1a1de3c954f247a98c0ce8f734c675a5efb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 29 Dec 2022 12:18:06 -0500 Subject: Actually use the picture_claim as configured in OIDC config. (#14751) Previously it was only using the default value ("picture") when fetching the picture from the user info. --- changelog.d/14751.bugfix | 1 + synapse/handlers/oidc.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14751.bugfix (limited to 'synapse') diff --git a/changelog.d/14751.bugfix b/changelog.d/14751.bugfix new file mode 100644 index 0000000000..56ef852288 --- /dev/null +++ b/changelog.d/14751.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.73.0 where the `picture_claim` configured under `oidc_providers` was unused (the default value of `"picture"` was used instead). diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 03de6a4ba6..23fb00c9c9 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1615,7 +1615,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) - picture = userinfo.get("picture") + picture = userinfo.get(self._config.picture_claim) return UserAttributeDict( localpart=localpart, -- cgit 1.5.1 From c4456114e1a5471bb61cb45605e782263dc8233c Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Sun, 1 Jan 2023 03:40:46 +0000 Subject: Add experimental support for MSC3391: deleting account data (#14714) --- changelog.d/14714.feature | 1 + .../complement/conf/workers-shared-extra.yaml.j2 | 2 + scripts-dev/complement.sh | 2 +- synapse/config/experimental.py | 3 + synapse/handlers/account_data.py | 111 ++++++++++- synapse/replication/http/account_data.py | 92 ++++++++- synapse/rest/client/account_data.py | 115 +++++++++++ synapse/storage/database.py | 33 +++- synapse/storage/databases/main/account_data.py | 219 +++++++++++++++++++-- 9 files changed, 547 insertions(+), 31 deletions(-) create mode 100644 changelog.d/14714.feature (limited to 'synapse') diff --git a/changelog.d/14714.feature b/changelog.d/14714.feature new file mode 100644 index 0000000000..5f3a20b7a7 --- /dev/null +++ b/changelog.d/14714.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3391](https://github.com/matrix-org/matrix-spec-proposals/pull/3391) (removing account data). \ No newline at end of file diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index ca640c343b..cb839fed07 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -102,6 +102,8 @@ experimental_features: {% endif %} # Filtering /messages by relation type. msc3874_enabled: true + # Enable removing account data support + msc3391_enabled: true server_notices: system_mxid_localpart: _server diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 8741ba3e34..51d1bac618 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -190,7 +190,7 @@ fi extra_test_args=() -test_tags="synapse_blacklist,msc3787,msc3874" +test_tags="synapse_blacklist,msc3787,msc3874,msc3391" # All environment variables starting with PASS_ will be shared. # (The prefix is stripped off before reaching the container.) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 573fa0386f..0f3870bfe1 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -136,3 +136,6 @@ class ExperimentalConfig(Config): # Enable room version (and thus applicable push rules from MSC3931/3932) version_id = RoomVersions.MSC1767v10.identifier KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10 + + # MSC3391: Removing account data. + self.msc3391_enabled = experimental.get("msc3391_enabled", False) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index fc21d58001..aba7315cf7 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -17,10 +17,12 @@ import random from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple from synapse.replication.http.account_data import ( + ReplicationAddRoomAccountDataRestServlet, ReplicationAddTagRestServlet, + ReplicationAddUserAccountDataRestServlet, + ReplicationRemoveRoomAccountDataRestServlet, ReplicationRemoveTagRestServlet, - ReplicationRoomAccountDataRestServlet, - ReplicationUserAccountDataRestServlet, + ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource from synapse.types import JsonDict, StreamKeyType, UserID @@ -41,8 +43,18 @@ class AccountDataHandler: self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() - self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) - self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) + self._add_user_data_client = ( + ReplicationAddUserAccountDataRestServlet.make_client(hs) + ) + self._remove_user_data_client = ( + ReplicationRemoveUserAccountDataRestServlet.make_client(hs) + ) + self._add_room_data_client = ( + ReplicationAddRoomAccountDataRestServlet.make_client(hs) + ) + self._remove_room_data_client = ( + ReplicationRemoveRoomAccountDataRestServlet.make_client(hs) + ) self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) self._account_data_writers = hs.config.worker.writers.account_data @@ -112,7 +124,7 @@ class AccountDataHandler: return max_stream_id else: - response = await self._room_data_client( + response = await self._add_room_data_client( instance_name=random.choice(self._account_data_writers), user_id=user_id, room_id=room_id, @@ -121,15 +133,59 @@ class AccountDataHandler: ) return response["max_stream_id"] + async def remove_account_data_for_room( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[int]: + """ + Deletes the room account data for the given user and account data type. + + "Deleting" account data merely means setting the content of the account data + to an empty JSON object: {}. + + Args: + user_id: The user ID to remove room account data for. + room_id: The room ID to target. + account_data_type: The account data type to remove. + + Returns: + The maximum stream ID, or None if the room account data item did not exist. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + if max_stream_id is None: + # The referenced account data did not exist, so no delete occurred. + return None + + self._notifier.on_new_event( + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] + ) + + # Notify Synapse modules that the content of the type has changed to an + # empty dictionary. + await self._notify_modules(user_id, room_id, account_data_type, {}) + + return max_stream_id + else: + response = await self._remove_room_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + account_data_type=account_data_type, + content={}, + ) + return response["max_stream_id"] + async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: """Add some global account_data for a user. Args: - user_id: The user to add a tag for. + user_id: The user to add some account data for. account_data_type: The type of account_data to add. - content: A json object to associate with the tag. + content: The content json dictionary. Returns: The maximum stream ID. @@ -148,7 +204,7 @@ class AccountDataHandler: return max_stream_id else: - response = await self._user_data_client( + response = await self._add_user_data_client( instance_name=random.choice(self._account_data_writers), user_id=user_id, account_data_type=account_data_type, @@ -156,6 +212,45 @@ class AccountDataHandler: ) return response["max_stream_id"] + async def remove_account_data_for_user( + self, user_id: str, account_data_type: str + ) -> Optional[int]: + """Removes a piece of global account_data for a user. + + Args: + user_id: The user to remove account data for. + account_data_type: The type of account_data to remove. + + Returns: + The maximum stream ID, or None if the room account data item did not exist. + """ + + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_account_data_for_user( + user_id, account_data_type + ) + if max_stream_id is None: + # The referenced account data did not exist, so no delete occurred. + return None + + self._notifier.on_new_event( + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] + ) + + # Notify Synapse modules that the content of the type has changed to an + # empty dictionary. + await self._notify_modules(user_id, None, account_data_type, {}) + + return max_stream_id + else: + response = await self._remove_user_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + account_data_type=account_data_type, + content={}, + ) + return response["max_stream_id"] + async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict ) -> int: diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 310f609153..0edc95977b 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): +class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint): """Add user account data on the appropriate account data worker. Request format: @@ -49,7 +49,6 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -73,7 +72,45 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} -class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): +class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint): + """Remove user account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_user_account_data/:user_id/:type + + { + "content": { ... }, + } + + """ + + NAME = "remove_user_account_data" + PATH_ARGS = ("user_id", "account_data_type") + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, account_data_type: str + ) -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: + max_stream_id = await self.handler.remove_account_data_for_user( + user_id, account_data_type + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint): """Add room account data on the appropriate account data worker. Request format: @@ -94,7 +131,6 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -118,6 +154,44 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} +class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint): + """Remove room account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_room_account_data/:user_id/:room_id/:account_data_type + + { + "content": { ... }, + } + + """ + + NAME = "remove_room_account_data" + PATH_ARGS = ("user_id", "room_id", "account_data_type") + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, room_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: + max_stream_id = await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + + return 200, {"max_stream_id": max_stream_id} + + class ReplicationAddTagRestServlet(ReplicationEndpoint): """Add tag on the appropriate account data worker. @@ -139,7 +213,6 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -186,7 +259,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] @@ -206,7 +278,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - ReplicationUserAccountDataRestServlet(hs).register(http_server) - ReplicationRoomAccountDataRestServlet(hs).register(http_server) + ReplicationAddUserAccountDataRestServlet(hs).register(http_server) + ReplicationAddRoomAccountDataRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server) ReplicationRemoveTagRestServlet(hs).register(http_server) + + if hs.config.experimental.msc3391_enabled: + ReplicationRemoveUserAccountDataRestServlet(hs).register(http_server) + ReplicationRemoveRoomAccountDataRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index f13970b898..e805196fec 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -41,6 +41,7 @@ class AccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() @@ -54,6 +55,16 @@ class AccountDataServlet(RestServlet): body = parse_json_object_from_request(request) + # If experimental support for MSC3391 is enabled, then providing an empty dict + # as the value for an account data type should be functionally equivalent to + # calling the DELETE method on the same type. + if self._hs.config.experimental.msc3391_enabled: + if body == {}: + await self.handler.remove_account_data_for_user( + user_id, account_data_type + ) + return 200, {} + await self.handler.add_account_data_for_user(user_id, account_data_type, body) return 200, {} @@ -72,9 +83,48 @@ class AccountDataServlet(RestServlet): if event is None: raise NotFoundError("Account data not found") + # If experimental support for MSC3391 is enabled, then this endpoint should + # return a 404 if the content for an account data type is an empty dict. + if self._hs.config.experimental.msc3391_enabled and event == {}: + raise NotFoundError("Account data not found") + return 200, event +class UnstableAccountDataServlet(RestServlet): + """ + Contains an unstable endpoint for removing user account data, as specified by + MSC3391. If that MSC is accepted, this code should have unstable prefixes removed + and become incorporated into AccountDataServlet above. + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3391/user/(?P[^/]*)" + "/account_data/(?P[^/]*)", + unstable=True, + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.handler = hs.get_account_data_handler() + + async def on_DELETE( + self, + request: SynapseRequest, + user_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot delete account data for other users.") + + await self.handler.remove_account_data_for_user(user_id, account_data_type) + + return 200, {} + + class RoomAccountDataServlet(RestServlet): """ PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 @@ -89,6 +139,7 @@ class RoomAccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() @@ -121,6 +172,16 @@ class RoomAccountDataServlet(RestServlet): Codes.BAD_JSON, ) + # If experimental support for MSC3391 is enabled, then providing an empty dict + # as the value for an account data type should be functionally equivalent to + # calling the DELETE method on the same type. + if self._hs.config.experimental.msc3391_enabled: + if body == {}: + await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + return 200, {} + await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, body ) @@ -152,9 +213,63 @@ class RoomAccountDataServlet(RestServlet): if event is None: raise NotFoundError("Room account data not found") + # If experimental support for MSC3391 is enabled, then this endpoint should + # return a 404 if the content for an account data type is an empty dict. + if self._hs.config.experimental.msc3391_enabled and event == {}: + raise NotFoundError("Room account data not found") + return 200, event +class UnstableRoomAccountDataServlet(RestServlet): + """ + Contains an unstable endpoint for removing room account data, as specified by + MSC3391. If that MSC is accepted, this code should have unstable prefixes removed + and become incorporated into RoomAccountDataServlet above. + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3391/user/(?P[^/]*)" + "/rooms/(?P[^/]*)" + "/account_data/(?P[^/]*)", + unstable=True, + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.handler = hs.get_account_data_handler() + + async def on_DELETE( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot delete account data for other users.") + + if not RoomID.is_valid(room_id): + raise SynapseError( + 400, + f"{room_id} is not a valid room ID", + Codes.INVALID_PARAM, + ) + + await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + + return 200, {} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server) + + if hs.config.experimental.msc3391_enabled: + UnstableAccountDataServlet(hs).register(http_server) + UnstableRoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0b29e67b94..88479a16db 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1762,7 +1762,8 @@ class DatabasePool: desc: description of the transaction, for logging and metrics Returns: - A list of dictionaries. + A list of dictionaries, one per result row, each a mapping between the + column names from `retcols` and that column's value for the row. """ return await self.runInteraction( desc, @@ -1791,6 +1792,10 @@ class DatabasePool: column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + + Returns: + A list of dictionaries, one per result row, each a mapping between the + column names from `retcols` and that column's value for the row. """ if keyvalues: sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1898,6 +1903,19 @@ class DatabasePool: updatevalues: Dict[str, Any], desc: str, ) -> int: + """ + Update rows in the given database table. + If the given keyvalues don't match anything, nothing will be updated. + + Args: + table: The database table to update. + keyvalues: A mapping of column name to value to match rows on. + updatevalues: A mapping of column name to value to replace in any matched rows. + desc: description of the transaction, for logging and metrics. + + Returns: + The number of rows that were updated. Will be 0 if no matching rows were found. + """ return await self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @@ -1909,6 +1927,19 @@ class DatabasePool: keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], ) -> int: + """ + Update rows in the given database table. + If the given keyvalues don't match anything, nothing will be updated. + + Args: + txn: The database transaction object. + table: The database table to update. + keyvalues: A mapping of column name to value to match rows on. + updatevalues: A mapping of column name to value to replace in any matched rows. + + Returns: + The number of rows that were updated. Will be 0 if no matching rows were found. + """ if keyvalues: where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) else: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 07908c41d9..e59776f434 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -123,7 +123,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_account_data_for_user( self, user_id: str ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - """Get all the client account_data for a user. + """ + Get all the client account_data for a user. + + If experimental MSC3391 support is enabled, any entries with an empty + content body are excluded; as this means they have been deleted. Args: user_id: The user to get the account_data for. @@ -135,27 +139,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_account_data_for_user_txn( txn: LoggingTransaction, ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - rows = self.db_pool.simple_select_list_txn( - txn, - "account_data", - {"user_id": user_id}, - ["account_data_type", "content"], - ) + # The 'content != '{}' condition below prevents us from using + # `simple_select_list_txn` here, as it doesn't support conditions + # other than 'equals'. + sql = """ + SELECT account_data_type, content FROM account_data + WHERE user_id = ? + """ + + # If experimental MSC3391 support is enabled, then account data entries + # with an empty content are considered "deleted". So skip adding them to + # the results. + if self.hs.config.experimental.msc3391_enabled: + sql += " AND content != '{}'" + + txn.execute(sql, (user_id,)) + rows = self.db_pool.cursor_to_dict(txn) global_account_data = { row["account_data_type"]: db_to_json(row["content"]) for row in rows } - rows = self.db_pool.simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id}, - ["room_id", "account_data_type", "content"], - ) + # The 'content != '{}' condition below prevents us from using + # `simple_select_list_txn` here, as it doesn't support conditions + # other than 'equals'. + sql = """ + SELECT room_id, account_data_type, content FROM room_account_data + WHERE user_id = ? + """ + + # If experimental MSC3391 support is enabled, then account data entries + # with an empty content are considered "deleted". So skip adding them to + # the results. + if self.hs.config.experimental.msc3391_enabled: + sql += " AND content != '{}'" + + txn.execute(sql, (user_id,)) + rows = self.db_pool.cursor_to_dict(txn) by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room @@ -469,6 +494,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return self._account_data_id_gen.get_current_token() + async def remove_account_data_for_room( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[int]: + """Delete the room account data for the user of a given type. + + Args: + user_id: The user to remove account_data for. + room_id: The room ID to scope the request to. + account_data_type: The account data type to delete. + + Returns: + The maximum stream position, or None if there was no matching room account + data to delete. + """ + assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) + + def _remove_account_data_for_room_txn( + txn: LoggingTransaction, next_id: int + ) -> bool: + """ + Args: + txn: The transaction object. + next_id: The stream_id to update any existing rows to. + + Returns: + True if an entry in room_account_data had its content set to '{}', + otherwise False. This informs callers of whether there actually was an + existing room account data entry to delete, or if the call was a no-op. + """ + # We can't use `simple_update` as it doesn't have the ability to specify + # where clauses other than '=', which we need for `content != '{}'` below. + sql = """ + UPDATE room_account_data + SET stream_id = ?, content = '{}' + WHERE user_id = ? + AND room_id = ? + AND account_data_type = ? + AND content != '{}' + """ + txn.execute( + sql, + (next_id, user_id, room_id, account_data_type), + ) + # Return true if any rows were updated. + return txn.rowcount != 0 + + async with self._account_data_id_gen.get_next() as next_id: + row_updated = await self.db_pool.runInteraction( + "remove_account_data_for_room", + _remove_account_data_for_room_txn, + next_id, + ) + + if not row_updated: + return None + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), {} + ) + + return self._account_data_id_gen.get_current_token() + async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: @@ -569,6 +660,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) + async def remove_account_data_for_user( + self, + user_id: str, + account_data_type: str, + ) -> Optional[int]: + """ + Delete a single piece of user account data by type. + + A "delete" is performed by updating a potentially existing row in the + "account_data" database table for (user_id, account_data_type) and + setting its content to "{}". + + Args: + user_id: The user ID to modify the account data of. + account_data_type: The type to remove. + + Returns: + The maximum stream position, or None if there was no matching account data + to delete. + """ + assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) + + def _remove_account_data_for_user_txn( + txn: LoggingTransaction, next_id: int + ) -> bool: + """ + Args: + txn: The transaction object. + next_id: The stream_id to update any existing rows to. + + Returns: + True if an entry in account_data had its content set to '{}', otherwise + False. This informs callers of whether there actually was an existing + account data entry to delete, or if the call was a no-op. + """ + # We can't use `simple_update` as it doesn't have the ability to specify + # where clauses other than '=', which we need for `content != '{}'` below. + sql = """ + UPDATE account_data + SET stream_id = ?, content = '{}' + WHERE user_id = ? + AND account_data_type = ? + AND content != '{}' + """ + txn.execute(sql, (next_id, user_id, account_data_type)) + if txn.rowcount == 0: + # We didn't update any rows. This means that there was no matching room + # account data entry to delete in the first place. + return False + + # Ignored users get denormalized into a separate table as an optimisation. + if account_data_type == AccountDataTypes.IGNORED_USER_LIST: + # If this method was called with the ignored users account data type, we + # simply delete all ignored users. + + # First pull all the users that this user ignores. + previously_ignored_users = set( + self.db_pool.simple_select_onecol_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + ) + ) + + # Then delete them from the database. + self.db_pool.simple_delete_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + ) + + # Invalidate the cache for ignored users which were removed. + for ignored_user_id in previously_ignored_users: + self._invalidate_cache_and_stream( + txn, self.ignored_by, (ignored_user_id,) + ) + + # Invalidate for this user the cache tracking ignored users. + self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) + + return True + + async with self._account_data_id_gen.get_next() as next_id: + row_updated = await self.db_pool.runInteraction( + "remove_account_data_for_user", + _remove_account_data_for_user_txn, + next_id, + ) + + if not row_updated: + return None + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.prefill( + (user_id, account_data_type), {} + ) + + return self._account_data_id_gen.get_current_token() + async def purge_account_data_for_user(self, user_id: str) -> None: """ Removes ALL the account data for a user. -- cgit 1.5.1 From db1cfe9c80a707995fcad8f3faa839acb247068a Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 4 Jan 2023 11:49:26 +0000 Subject: Update all stream IDs after processing replication rows (#14723) This creates a new store method, `process_replication_position` that is called after `process_replication_rows`. By moving stream ID advances here this guarantees any relevant cache invalidations will have been applied before the stream is advanced. This avoids race conditions where Python switches between threads mid way through processing the `process_replication_rows` method where stream IDs may be advanced before caches are invalidated due to class resolution ordering. See this comment/issue for further discussion: https://github.com/matrix-org/synapse/issues/14158#issuecomment-1344048703 --- changelog.d/14723.bugfix | 1 + synapse/replication/tcp/client.py | 3 +++ synapse/storage/_base.py | 17 ++++++++++++++++- synapse/storage/databases/main/account_data.py | 14 ++++++++++---- synapse/storage/databases/main/cache.py | 11 ++++++++--- synapse/storage/databases/main/deviceinbox.py | 7 +++++++ synapse/storage/databases/main/devices.py | 11 +++++++++-- synapse/storage/databases/main/events_worker.py | 15 ++++++++++----- synapse/storage/databases/main/presence.py | 8 +++++++- synapse/storage/databases/main/push_rule.py | 7 +++++++ synapse/storage/databases/main/pusher.py | 6 +++--- synapse/storage/databases/main/receipts.py | 7 +++++++ synapse/storage/databases/main/tags.py | 8 +++++++- 13 files changed, 95 insertions(+), 20 deletions(-) create mode 100644 changelog.d/14723.bugfix (limited to 'synapse') diff --git a/changelog.d/14723.bugfix b/changelog.d/14723.bugfix new file mode 100644 index 0000000000..e1f89cee35 --- /dev/null +++ b/changelog.d/14723.bugfix @@ -0,0 +1 @@ +Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 658d89210d..b5e40da533 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -152,6 +152,9 @@ class ReplicationDataHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + # NOTE: this must be called after process_replication_rows to ensure any + # cache invalidations are first handled before any stream ID advances. + self.store.process_replication_position(stream_name, instance_name, token) if self.send_handler: await self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 69abf6fa87..41d9111019 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta): token: int, rows: Iterable[Any], ) -> None: - pass + """ + Used by storage classes to invalidate caches based on incoming replication data. These + must not update any ID generators, use `process_replication_position`. + """ + + def process_replication_position( # noqa: B027 (no-op by design) + self, + stream_name: str, + instance_name: str, + token: int, + ) -> None: + """ + Used by storage classes to advance ID generators based on incoming replication data. This + is called after process_replication_rows such that caches are invalidated before any token + positions advance. + """ def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index e59776f434..86032897f5 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) + if stream_name == AccountDataStream.NAME: for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( @@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a58668a380..2179a8bf59 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=True, ) elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.keys is None: @@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 48a54d9cb8..713be91c5d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a5bb4d404e..db877e3f13 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 761b15a815..d150fa8a94 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore): token: int, rows: Iterable[Any], ) -> None: - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(instance_name, token) - elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(instance_name, -token) - elif stream_name == UnPartialStatedEventStream.NAME: + if stream_name == UnPartialStatedEventStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) @@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(instance_name, token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(instance_name, -token) + super().process_replication_position(stream_name, instance_name, token) + async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased from the database due to a redaction. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9769a18a9d..7b60815043 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) rows: Iterable[Any], ) -> None: if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super().process_replication_rows(stream_name, instance_name, token, rows) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d4c64c46ad..d4e4b777da 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -154,6 +154,13 @@ class PushRulesWorkerStore( self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 40fd781a6a..7f24a3b6ec 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + def process_replication_position( + self, stream_name: str, instance_name: str, token: int ) -> None: if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + super().process_replication_position(stream_name, instance_name, token) async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e06725f69c..86f5bce5f0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index b0f5de67a3..e23c927e02 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore): rows: Iterable[Any], ) -> None: if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + class TagsStore(TagsWorkerStore): pass -- cgit 1.5.1 From 906dfaa2cf5a79ed9c18529b1a370ffd49c0204e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 4 Jan 2023 08:26:10 -0500 Subject: Support non-OpenID compliant user info endpoints (#14753) OpenID specifies the format of the user info endpoint and some OAuth 2.0 IdPs do not follow it, e.g. NextCloud and Twitter. This adds subject_template and picture_template options to the default mapping provider for more flexibility in matching those user info responses. --- changelog.d/14753.feature | 1 + docs/usage/configuration/config_documentation.md | 18 ++++++++++++++ synapse/handlers/oidc.py | 31 ++++++++++++++++++------ 3 files changed, 42 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14753.feature (limited to 'synapse') diff --git a/changelog.d/14753.feature b/changelog.d/14753.feature new file mode 100644 index 0000000000..38b4d6af4b --- /dev/null +++ b/changelog.d/14753.feature @@ -0,0 +1 @@ +Support non-OpenID compliant userinfo claims for subject and picture. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 67e0acc910..23f9dcbea2 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3098,10 +3098,26 @@ Options for each entry include: For the default provider, the following settings are available: + * `subject_template`: Jinja2 template for a unique identifier for the user. + Defaults to `{{ user.sub }}`, which OpenID Connect compliant providers should provide. + + This replaces and overrides `subject_claim`. + * `subject_claim`: name of the claim containing a unique identifier for the user. Defaults to 'sub', which OpenID Connect compliant providers should provide. + *Deprecated in Synapse v1.75.0.* + + * `picture_template`: Jinja2 template for an url for the user's profile picture. + Defaults to `{{ user.picture }}`, which OpenID Connect compliant providers should + provide and has to refer to a direct image file such as PNG, JPEG, or GIF image file. + + This replaces and overrides `picture_claim`. + + Currently only supported in monolithic (single-process) server configurations + where the media repository runs within the Synapse process. + * `picture_claim`: name of the claim containing an url for the user's profile picture. Defaults to 'picture', which OpenID Connect compliant providers should provide and has to refer to a direct image file such as PNG, JPEG, or GIF image file. @@ -3109,6 +3125,8 @@ Options for each entry include: Currently only supported in monolithic (single-process) server configurations where the media repository runs within the Synapse process. + *Deprecated in Synapse v1.75.0.* + * `localpart_template`: Jinja2 template for the localpart of the MXID. If this is not set, the user will be prompted to choose their own username (see the documentation for the `sso_auth_account_details.html` diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 23fb00c9c9..24e1cec5b6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1520,8 +1520,8 @@ env.filters.update( @attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: - subject_claim: str - picture_claim: str + subject_template: Template + picture_template: Template localpart_template: Optional[Template] display_name_template: Optional[Template] email_template: Optional[Template] @@ -1540,8 +1540,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @staticmethod def parse_config(config: dict) -> JinjaOidcMappingConfig: - subject_claim = config.get("subject_claim", "sub") - picture_claim = config.get("picture_claim", "picture") + def parse_template_config_with_claim( + option_name: str, default_claim: str + ) -> Template: + template_name = f"{option_name}_template" + template = config.get(template_name) + if not template: + # Convert the legacy subject_claim into a template. + claim = config.get(f"{option_name}_claim", default_claim) + template = "{{ user.%s }}" % (claim,) + + try: + return env.from_string(template) + except Exception as e: + raise ConfigError("invalid jinja template", path=[template_name]) from e + + subject_template = parse_template_config_with_claim("subject", "sub") + picture_template = parse_template_config_with_claim("picture", "picture") def parse_template_config(option_name: str) -> Optional[Template]: if option_name not in config: @@ -1574,8 +1589,8 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): raise ConfigError("must be a bool", path=["confirm_localpart"]) return JinjaOidcMappingConfig( - subject_claim=subject_claim, - picture_claim=picture_claim, + subject_template=subject_template, + picture_template=picture_template, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, @@ -1584,7 +1599,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): ) def get_remote_user_id(self, userinfo: UserInfo) -> str: - return userinfo[self._config.subject_claim] + return self._config.subject_template.render(user=userinfo).strip() async def map_user_attributes( self, userinfo: UserInfo, token: Token, failures: int @@ -1615,7 +1630,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) - picture = userinfo.get(self._config.picture_claim) + picture = self._config.picture_template.render(user=userinfo).strip() return UserAttributeDict( localpart=localpart, -- cgit 1.5.1 From 630d0aeaf607b4016e67895d81b0402a5dfcc769 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 4 Jan 2023 14:58:08 -0500 Subject: Support RFC7636 PKCE in the OAuth 2.0 flow. (#14750) PKCE can protect against certain attacks and is enabled by default. Support can be controlled manually by setting the pkce_method of each oidc_providers entry to 'auto' (default), 'always', or 'never'. This is required by Twitter OAuth 2.0 support. --- changelog.d/14750.feature | 1 + docs/usage/configuration/config_documentation.md | 7 +- synapse/config/oidc.py | 6 + synapse/handlers/oidc.py | 54 ++++++-- synapse/util/macaroons.py | 7 ++ tests/handlers/test_oidc.py | 152 +++++++++++++++++++++-- tests/util/test_macaroons.py | 1 + 7 files changed, 212 insertions(+), 16 deletions(-) create mode 100644 changelog.d/14750.feature (limited to 'synapse') diff --git a/changelog.d/14750.feature b/changelog.d/14750.feature new file mode 100644 index 0000000000..cfed64ee80 --- /dev/null +++ b/changelog.d/14750.feature @@ -0,0 +1 @@ +Support [RFC7636](https://datatracker.ietf.org/doc/html/rfc7636) Proof Key for Code Exchange for OAuth single sign-on. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 23f9dcbea2..ec8403c7e9 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3053,8 +3053,13 @@ Options for each entry include: values are `client_secret_basic` (default), `client_secret_post` and `none`. +* `pkce_method`: Whether to use proof key for code exchange when requesting + and exchanging the token. Valid values are: `auto`, `always`, or `never`. Defaults + to `auto`, which uses PKCE if supported during metadata discovery. Set to `always` + to force enable PKCE or `never` to force disable PKCE. + * `scopes`: list of scopes to request. This should normally include the "openid" - scope. Defaults to ["openid"]. + scope. Defaults to `["openid"]`. * `authorization_endpoint`: the oauth2 authorization endpoint. Required if provider discovery is disabled. diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 0bd83f4010..df8c422043 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -117,6 +117,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { # to avoid importing authlib here. "enum": ["client_secret_basic", "client_secret_post", "none"], }, + "pkce_method": {"type": "string", "enum": ["auto", "always", "never"]}, "scopes": {"type": "array", "items": {"type": "string"}}, "authorization_endpoint": {"type": "string"}, "token_endpoint": {"type": "string"}, @@ -289,6 +290,7 @@ def _parse_oidc_config_dict( client_secret=oidc_config.get("client_secret"), client_secret_jwt_key=client_secret_jwt_key, client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), + pkce_method=oidc_config.get("pkce_method", "auto"), scopes=oidc_config.get("scopes", ["openid"]), authorization_endpoint=oidc_config.get("authorization_endpoint"), token_endpoint=oidc_config.get("token_endpoint"), @@ -357,6 +359,10 @@ class OidcProviderConfig: # 'none'. client_auth_method: str + # Whether to enable PKCE when exchanging the authorization & token. + # Valid values are 'auto', 'always', and 'never'. + pkce_method: str + # list of scopes to request scopes: Collection[str] diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 24e1cec5b6..0fc829acf7 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -36,6 +36,7 @@ from authlib.jose import JsonWebToken, JWTClaims from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri +from authlib.oauth2.rfc7636.challenge import create_s256_code_challenge from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template @@ -475,6 +476,16 @@ class OidcProvider: ) ) + # If PKCE support is advertised ensure the wanted method is available. + if m.get("code_challenge_methods_supported") is not None: + m.validate_code_challenge_methods_supported() + if "S256" not in m["code_challenge_methods_supported"]: + raise ValueError( + '"S256" not in "code_challenge_methods_supported" ({supported!r})'.format( + supported=m["code_challenge_methods_supported"], + ) + ) + if m.get("response_types_supported") is not None: m.validate_response_types_supported() @@ -602,6 +613,11 @@ class OidcProvider: if self._config.jwks_uri: metadata["jwks_uri"] = self._config.jwks_uri + if self._config.pkce_method == "always": + metadata["code_challenge_methods_supported"] = ["S256"] + elif self._config.pkce_method == "never": + metadata.pop("code_challenge_methods_supported", None) + self._validate_metadata(metadata) return metadata @@ -653,7 +669,7 @@ class OidcProvider: return jwk_set - async def _exchange_code(self, code: str) -> Token: + async def _exchange_code(self, code: str, code_verifier: str) -> Token: """Exchange an authorization code for a token. This calls the ``token_endpoint`` with the authorization code we @@ -666,6 +682,7 @@ class OidcProvider: Args: code: The authorization code we got from the callback. + code_verifier: The PKCE code verifier to send, blank if unused. Returns: A dict containing various tokens. @@ -696,6 +713,8 @@ class OidcProvider: "code": code, "redirect_uri": self._callback_url, } + if code_verifier: + args["code_verifier"] = code_verifier body = urlencode(args, True) # Fill the body/headers with credentials @@ -914,11 +933,14 @@ class OidcProvider: - ``scope``: the list of scopes set in ``oidc_config.scopes`` - ``state``: a random string - ``nonce``: a random string + - ``code_challenge``: a RFC7636 code challenge (if PKCE is supported) - In addition generating a redirect URL, we are setting a cookie with - a signed macaroon token containing the state, the nonce and the - client_redirect_url params. Those are then checked when the client - comes back from the provider. + In addition to generating a redirect URL, we are setting a cookie with + a signed macaroon token containing the state, the nonce, the + client_redirect_url, and (optionally) the code_verifier params. The state, + nonce, and client_redirect_url are then checked when the client comes back + from the provider. The code_verifier is passed back to the server during + the token exchange and compared to the code_challenge sent in this request. Args: request: the incoming request from the browser. @@ -935,10 +957,25 @@ class OidcProvider: state = generate_token() nonce = generate_token() + code_verifier = "" if not client_redirect_url: client_redirect_url = b"" + metadata = await self.load_metadata() + + # Automatically enable PKCE if it is supported. + extra_grant_values = {} + if metadata.get("code_challenge_methods_supported"): + code_verifier = generate_token(48) + + # Note that we verified the server supports S256 earlier (in + # OidcProvider._validate_metadata). + extra_grant_values = { + "code_challenge_method": "S256", + "code_challenge": create_s256_code_challenge(code_verifier), + } + cookie = self._macaroon_generaton.generate_oidc_session_token( state=state, session_data=OidcSessionData( @@ -946,6 +983,7 @@ class OidcProvider: nonce=nonce, client_redirect_url=client_redirect_url.decode(), ui_auth_session_id=ui_auth_session_id or "", + code_verifier=code_verifier, ), ) @@ -966,7 +1004,6 @@ class OidcProvider: ) ) - metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") return prepare_grant_uri( authorization_endpoint, @@ -976,6 +1013,7 @@ class OidcProvider: scope=self._scopes, state=state, nonce=nonce, + **extra_grant_values, ) async def handle_oidc_callback( @@ -1003,7 +1041,9 @@ class OidcProvider: # Exchange the code with the provider try: logger.debug("Exchanging OAuth2 code for a token") - token = await self._exchange_code(code) + token = await self._exchange_code( + code, code_verifier=session_data.code_verifier + ) except OidcError as e: logger.warning("Could not exchange OAuth2 code: %s", e) self._sso_handler.render_error(request, e.error, e.error_description) diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index 5df03d3ddc..644c341e8c 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -110,6 +110,9 @@ class OidcSessionData: ui_auth_session_id: str """The session ID of the ongoing UI Auth ("" if this is a login)""" + code_verifier: str + """The random string used in the RFC7636 code challenge ("" if PKCE is not being used).""" + class MacaroonGenerator: def __init__(self, clock: Clock, location: str, secret_key: bytes): @@ -187,6 +190,7 @@ class MacaroonGenerator: macaroon.add_first_party_caveat( f"ui_auth_session_id = {session_data.ui_auth_session_id}" ) + macaroon.add_first_party_caveat(f"code_verifier = {session_data.code_verifier}") macaroon.add_first_party_caveat(f"time < {expiry}") return macaroon.serialize() @@ -278,6 +282,7 @@ class MacaroonGenerator: v.satisfy_general(lambda c: c.startswith("idp_id = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) + v.satisfy_general(lambda c: c.startswith("code_verifier = ")) satisfy_expiry(v, self._clock.time_msec) v.verify(macaroon, self._secret_key) @@ -287,11 +292,13 @@ class MacaroonGenerator: idp_id = get_value_from_macaroon(macaroon, "idp_id") client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url") ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id") + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") return OidcSessionData( nonce=nonce, idp_id=idp_id, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, + code_verifier=code_verifier, ) def _generate_base_macaroon(self, type: MacaroonType) -> pymacaroons.Macaroon: diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 49a1842b5c..adddbd002f 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -396,6 +396,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(params["client_id"], [CLIENT_ID]) self.assertEqual(len(params["state"]), 1) self.assertEqual(len(params["nonce"]), 1) + self.assertNotIn("code_challenge", params) # Check what is in the cookies self.assertEqual(len(req.cookies), 2) # two cookies @@ -411,12 +412,117 @@ class OidcHandlerTestCase(HomeserverTestCase): macaroon = pymacaroons.Macaroon.deserialize(cookie) state = get_value_from_macaroon(macaroon, "state") nonce = get_value_from_macaroon(macaroon, "nonce") + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") redirect = get_value_from_macaroon(macaroon, "client_redirect_url") self.assertEqual(params["state"], [state]) self.assertEqual(params["nonce"], [nonce]) + self.assertEqual(code_verifier, "") self.assertEqual(redirect, "http://client/redirect") + @override_config({"oidc_config": DEFAULT_CONFIG}) + def test_redirect_request_with_code_challenge(self) -> None: + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["cookies"]) + req.cookies = [] + + with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}): + url = urlparse( + self.get_success( + self.provider.handle_redirect_request( + req, b"http://client/redirect" + ) + ) + ) + + # Ensure the code_challenge param is added to the redirect. + params = parse_qs(url.query) + self.assertEqual(len(params["code_challenge"]), 1) + + # Check what is in the cookies + self.assertEqual(len(req.cookies), 2) # two cookies + cookie_header = req.cookies[0] + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + parts = [p.strip() for p in cookie_header.split(b";")] + self.assertIn(b"Path=/_synapse/client/oidc", parts) + name, cookie = parts[0].split(b"=") + self.assertEqual(name, b"oidc_session") + + # Ensure the code_verifier is set in the cookie. + macaroon = pymacaroons.Macaroon.deserialize(cookie) + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") + self.assertNotEqual(code_verifier, "") + + @override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "always"}}) + def test_redirect_request_with_forced_code_challenge(self) -> None: + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["cookies"]) + req.cookies = [] + + url = urlparse( + self.get_success( + self.provider.handle_redirect_request(req, b"http://client/redirect") + ) + ) + + # Ensure the code_challenge param is added to the redirect. + params = parse_qs(url.query) + self.assertEqual(len(params["code_challenge"]), 1) + + # Check what is in the cookies + self.assertEqual(len(req.cookies), 2) # two cookies + cookie_header = req.cookies[0] + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + parts = [p.strip() for p in cookie_header.split(b";")] + self.assertIn(b"Path=/_synapse/client/oidc", parts) + name, cookie = parts[0].split(b"=") + self.assertEqual(name, b"oidc_session") + + # Ensure the code_verifier is set in the cookie. + macaroon = pymacaroons.Macaroon.deserialize(cookie) + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") + self.assertNotEqual(code_verifier, "") + + @override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "never"}}) + def test_redirect_request_with_disabled_code_challenge(self) -> None: + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["cookies"]) + req.cookies = [] + + # The metadata should state that PKCE is enabled. + with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}): + url = urlparse( + self.get_success( + self.provider.handle_redirect_request( + req, b"http://client/redirect" + ) + ) + ) + + # Ensure the code_challenge param is added to the redirect. + params = parse_qs(url.query) + self.assertNotIn("code_challenge", params) + + # Check what is in the cookies + self.assertEqual(len(req.cookies), 2) # two cookies + cookie_header = req.cookies[0] + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + parts = [p.strip() for p in cookie_header.split(b";")] + self.assertIn(b"Path=/_synapse/client/oidc", parts) + name, cookie = parts[0].split(b"=") + self.assertEqual(name, b"oidc_session") + + # Ensure the code_verifier is blank in the cookie. + macaroon = pymacaroons.Macaroon.deserialize(cookie) + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") + self.assertEqual(code_verifier, "") + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_error(self) -> None: """Errors from the provider returned in the callback are displayed.""" @@ -601,7 +707,7 @@ class OidcHandlerTestCase(HomeserverTestCase): payload=token ) code = "code" - ret = self.get_success(self.provider._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code, code_verifier="")) kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) @@ -615,13 +721,34 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args["client_secret"], [CLIENT_SECRET]) self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + # Test providing a code verifier. + code_verifier = "code_verifier" + ret = self.get_success( + self.provider._exchange_code(code, code_verifier=code_verifier) + ) + kwargs = self.fake_server.request.call_args[1] + + self.assertEqual(ret, token) + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) + + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["client_secret"], [CLIENT_SECRET]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + self.assertEqual(args["code_verifier"], [code_verifier]) + # Test error handling self.fake_server.post_token_handler.return_value = FakeResponse.json( code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "foo") self.assertEqual(exc.value.error_description, "bar") @@ -629,7 +756,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.fake_server.post_token_handler.return_value = FakeResponse( code=500, body=b"Not JSON" ) - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body @@ -637,21 +766,27 @@ class OidcHandlerTestCase(HomeserverTestCase): code=500, payload={"error": "internal_server_error"} ) - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.fake_server.post_token_handler.return_value = FakeResponse.json( code=400, payload={} ) - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field self.fake_server.post_token_handler.return_value = FakeResponse.json( code=200, payload={"error": "some_error"} ) - exc = self.get_failure(self.provider._exchange_code(code), OidcError) + exc = self.get_failure( + self.provider._exchange_code(code, code_verifier=""), OidcError + ) self.assertEqual(exc.value.error, "some_error") @override_config( @@ -688,7 +823,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # timestamps. self.reactor.advance(1000) start_time = self.reactor.seconds() - ret = self.get_success(self.provider._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code, code_verifier="")) self.assertEqual(ret, token) @@ -739,7 +874,7 @@ class OidcHandlerTestCase(HomeserverTestCase): payload=token ) code = "code" - ret = self.get_success(self.provider._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code, code_verifier="")) self.assertEqual(ret, token) @@ -1203,6 +1338,7 @@ class OidcHandlerTestCase(HomeserverTestCase): nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, + code_verifier="", ), ) diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py index f68377a05a..e56ec2c860 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py @@ -92,6 +92,7 @@ class MacaroonGeneratorTestCase(TestCase): nonce="nonce", client_redirect_url="https://example.com/", ui_auth_session_id="", + code_verifier="", ) token = self.macaroon_generator.generate_oidc_session_token( state, session_data, duration_in_ms=2 * 60 * 1000 -- cgit 1.5.1 From 5e0888076fea8c70ab84114e1c261dd46330c1d6 Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Mon, 9 Jan 2023 06:12:03 -0500 Subject: Disable sending confirmation email when 3pid is disabled #14682 (#14725) * Fixes #12277 :Disable sending confirmation email when 3pid is disabled * Fix test_add_email_if_disabled test case to reflect changes to enable_3pid_changes flag * Add changelog file * Rename newsfragment. Co-authored-by: Patrick Cloke --- changelog.d/14725.misc | 1 + synapse/rest/client/account.py | 5 +++++ tests/rest/client/test_account.py | 30 +++++------------------------- 3 files changed, 11 insertions(+), 25 deletions(-) create mode 100644 changelog.d/14725.misc (limited to 'synapse') diff --git a/changelog.d/14725.misc b/changelog.d/14725.misc new file mode 100644 index 0000000000..a86c4f8c05 --- /dev/null +++ b/changelog.d/14725.misc @@ -0,0 +1 @@ +Disable sending confirmation email when 3pid is disabled. diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index b4b92f0c99..4373c73662 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -338,6 +338,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + if not self.hs.config.registration.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + if not self.config.email.can_verify_email: logger.warning( "Adding emails have been disabled due to lack of an email config" diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index c1a7fb2f8a..88f255c9ee 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -690,41 +690,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEqual(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - channel = self.make_request( "POST", - b"/_matrix/client/unstable/account/3pid/add", + b"/_matrix/client/unstable/account/3pid/email/requestToken", { "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, + "email": "test@example.com", + "send_attempt": 1, }, - access_token=self.user_id_tok, ) + self.assertEqual( HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] ) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_delete_email(self) -> None: """Test deleting an email from profile""" -- cgit 1.5.1 From 7e582a25f8f350df29d7d83ca902bdb522d1bbaf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 9 Jan 2023 08:43:50 -0500 Subject: Improve /sync performance of when passing filters with empty arrays. (#14786) This has two related changes: * It enables fast-path processing for an empty filter (`[]`) which was previously only used for wildcard not-filters (`["*"]`). * It special cases a `/sync` filter with no-rooms to skip all room processing, previously we would partially skip processing, but would generally still calculate intermediate values for each room which were then unused. Future changes might consider further optimizations: * Skip calculating per-room account data when all rooms are filtered (currently this is thrown away). * Make similar improvements to other endpoints which support filters. --- changelog.d/14786.feature | 1 + synapse/api/filtering.py | 13 ++++++++----- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 14 +++++++++++--- 4 files changed, 21 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14786.feature (limited to 'synapse') diff --git a/changelog.d/14786.feature b/changelog.d/14786.feature new file mode 100644 index 0000000000..008d61ab03 --- /dev/null +++ b/changelog.d/14786.feature @@ -0,0 +1 @@ +Improve performance of `/sync` when filtering all rooms, message types, or senders. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index a9888381b4..2b5af264b4 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -283,6 +283,9 @@ class FilterCollection: await self._room_filter.filter(events) ) + def blocks_all_rooms(self) -> bool: + return self._room_filter.filters_all_rooms() + def blocks_all_presence(self) -> bool: return ( self._presence_filter.filters_all_types() @@ -351,13 +354,13 @@ class Filter: self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: - return "*" in self.not_types + return self.types == [] or "*" in self.not_types def filters_all_senders(self) -> bool: - return "*" in self.not_senders + return self.senders == [] or "*" in self.not_senders def filters_all_rooms(self) -> bool: - return "*" in self.not_rooms + return self.rooms == [] or "*" in self.not_rooms def _check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. @@ -450,8 +453,8 @@ class Filter: if any(map(match_func, disallowed_values)): return False - # Other the event does not match at least one of the allowed values, - # reject it. + # Otherwise if the event does not match at least one of the allowed + # values, reject it. allowed_values = getattr(self, name) if allowed_values is not None: if not any(map(match_func, allowed_values)): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 33115ce488..40f4635c4e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -275,7 +275,7 @@ class SearchHandler: ) room_ids = {r.room_id for r in rooms} - # If doing a subset of all rooms seearch, check if any of the rooms + # If doing a subset of all rooms search, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: historical_room_ids: List[str] = [] diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4fa480262b..6942e06c77 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1403,11 +1403,14 @@ class SyncHandler: logger.debug("Fetching room data") - res = await self._generate_sync_entry_for_rooms( + ( + newly_joined_rooms, + newly_joined_or_invited_or_knocked_users, + newly_left_rooms, + newly_left_users, + ) = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res - _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( since_token is None and sync_config.filter_collection.blocks_all_presence() @@ -1789,6 +1792,11 @@ class SyncHandler: - newly_left_rooms - newly_left_users """ + + # If the request doesn't care about rooms then nothing to do! + if sync_result_builder.sync_config.filter_collection.blocks_all_rooms(): + return set(), set(), set(), set() + since_token = sync_result_builder.since_token # 1. Start by fetching all ephemeral events in rooms we've joined (if required). -- cgit 1.5.1 From babeeb4e7a6f5b5c643b837bf724d674805546f6 Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Mon, 9 Jan 2023 09:22:02 -0500 Subject: Unescape HTML entities in oEmbed titles. (#14781) It doesn't seem valid that HTML entities should appear in the title field of oEmbed responses, but a popular WordPress plug-in seems to do it. There should not be harm in unescaping these. --- changelog.d/14781.misc | 1 + synapse/rest/media/v1/oembed.py | 15 +++++++++------ tests/rest/media/v1/test_oembed.py | 10 ++++++++++ 3 files changed, 20 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14781.misc (limited to 'synapse') diff --git a/changelog.d/14781.misc b/changelog.d/14781.misc new file mode 100644 index 0000000000..04f565b410 --- /dev/null +++ b/changelog.d/14781.misc @@ -0,0 +1 @@ +Unescape HTML entities in URL preview titles making use of oEmbed responses. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 827afd868d..a3738a6250 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import html import logging import urllib.parse from typing import TYPE_CHECKING, List, Optional @@ -161,7 +162,9 @@ class OEmbedProvider: title = oembed.get("title") if title and isinstance(title, str): - open_graph_response["og:title"] = title + # A common WordPress plug-in seems to incorrectly escape entities + # in the oEmbed response. + open_graph_response["og:title"] = html.unescape(title) author_name = oembed.get("author_name") if not isinstance(author_name, str): @@ -180,9 +183,9 @@ class OEmbedProvider: # Process each type separately. oembed_type = oembed.get("type") if oembed_type == "rich": - html = oembed.get("html") - if isinstance(html, str): - calc_description_and_urls(open_graph_response, html) + html_str = oembed.get("html") + if isinstance(html_str, str): + calc_description_and_urls(open_graph_response, html_str) elif oembed_type == "photo": # If this is a photo, use the full image, not the thumbnail. @@ -192,8 +195,8 @@ class OEmbedProvider: elif oembed_type == "video": open_graph_response["og:type"] = "video.other" - html = oembed.get("html") - if html and isinstance(html, str): + html_str = oembed.get("html") + if html_str and isinstance(html_str, str): calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py index 319ae8b1cc..3f7f1dbab9 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/rest/media/v1/test_oembed.py @@ -150,3 +150,13 @@ class OEmbedTests(HomeserverTestCase): result = self.parse_response({"type": "link"}) self.assertIn("og:type", result.open_graph_result) self.assertEqual(result.open_graph_result["og:type"], "website") + + def test_title_html_entities(self) -> None: + """Test HTML entities in title""" + result = self.parse_response( + {"title": "Why JSON isn’t a Good Configuration Language"} + ) + self.assertEqual( + result.open_graph_result["og:title"], + "Why JSON isn’t a Good Configuration Language", + ) -- cgit 1.5.1 From 58d2adc3da6a988452dbb9c6c4202a5ea19c4ca9 Mon Sep 17 00:00:00 2001 From: Jeyachandran Rathnam Date: Mon, 9 Jan 2023 12:17:24 -0500 Subject: Remove undocumented device from pushrules (#14727) * Remove undocumented device from pushrules * Add changelog * Update changelog.d/14727.misc * Rename 14727.misc to 14727.bugfix Co-authored-by: David Robertson --- changelog.d/14727.bugfix | 1 + synapse/push/clientformat.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14727.bugfix (limited to 'synapse') diff --git a/changelog.d/14727.bugfix b/changelog.d/14727.bugfix new file mode 100644 index 0000000000..25079496e4 --- /dev/null +++ b/changelog.d/14727.bugfix @@ -0,0 +1 @@ +Remove the unspecced `device` field from `/pushrules` responses. diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 622a1e35c5..bb76c169c6 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -26,10 +26,7 @@ def format_push_rules_for_user( """Converts a list of rawrules and a enabled map into nested dictionaries to match the Matrix client-server format for push rules""" - rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = { - "global": {}, - "device": {}, - } + rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {"global": {}} rules["global"] = _add_empty_priority_class_arrays(rules["global"]) -- cgit 1.5.1 From ba4ea7d13ffae53644b206222af95a5171faa27c Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 10 Jan 2023 11:17:59 +0000 Subject: Batch up replication requests to request the resyncing of remote users's devices. (#14716) --- changelog.d/14716.misc | 1 + synapse/handlers/device.py | 124 +++++++++++++++++++++++------- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/e2e_keys.py | 93 +++++++++++++--------- synapse/handlers/federation_event.py | 2 +- synapse/replication/http/devices.py | 74 +++++++++++++++++- synapse/storage/databases/main/devices.py | 30 ++++++-- synapse/types/__init__.py | 4 + synapse/util/async_helpers.py | 55 ++++++++++++- 9 files changed, 306 insertions(+), 79 deletions(-) create mode 100644 changelog.d/14716.misc (limited to 'synapse') diff --git a/changelog.d/14716.misc b/changelog.d/14716.misc new file mode 100644 index 0000000000..ef9522e01d --- /dev/null +++ b/changelog.d/14716.misc @@ -0,0 +1 @@ +Batch up replication requests to request the resyncing of remote users's devices. \ No newline at end of file diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d4750a32e6..89864e1119 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -33,6 +34,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, HttpResponseException, + InvalidAPICallError, RequestSendFailed, SynapseError, ) @@ -45,6 +47,7 @@ from synapse.types import ( JsonDict, StreamKeyType, StreamToken, + UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -893,12 +896,47 @@ class DeviceListWorkerUpdater: def __init__(self, hs: "HomeServer"): from synapse.replication.http.devices import ( + ReplicationMultiUserDevicesResyncRestServlet, ReplicationUserDevicesResyncRestServlet, ) self._user_device_resync_client = ( ReplicationUserDevicesResyncRestServlet.make_client(hs) ) + self._multi_user_device_resync_client = ( + ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) + ) + + async def multi_user_device_resync( + self, user_ids: List[str], mark_failed_as_stale: bool = True + ) -> Dict[str, Optional[JsonDict]]: + """ + Like `user_device_resync` but operates on multiple users **from the same origin** + at once. + + Returns: + Dict from User ID to the same Dict as `user_device_resync`. + """ + # mark_failed_as_stale is not sent. Ensure this doesn't break expectations. + assert mark_failed_as_stale + + if not user_ids: + # Shortcut empty requests + return {} + + try: + return await self._multi_user_device_resync_client(user_ids=user_ids) + except SynapseError as err: + if not ( + err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED + ): + raise + + # Fall back to single requests + result: Dict[str, Optional[JsonDict]] = {} + for user_id in user_ids: + result[user_id] = await self._user_device_resync_client(user_id=user_id) + return result async def user_device_resync( self, user_id: str, mark_failed_as_stale: bool = True @@ -913,8 +951,10 @@ class DeviceListWorkerUpdater: A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + None when we weren't able to fetch the device info for some reason, + e.g. due to a connection problem. """ - return await self._user_device_resync_client(user_id=user_id) + return (await self.multi_user_device_resync([user_id]))[user_id] class DeviceListUpdater(DeviceListWorkerUpdater): @@ -1160,19 +1200,66 @@ class DeviceListUpdater(DeviceListWorkerUpdater): # Allow future calls to retry resyncinc out of sync device lists. self._resync_retry_in_progress = False + async def multi_user_device_resync( + self, user_ids: List[str], mark_failed_as_stale: bool = True + ) -> Dict[str, Optional[JsonDict]]: + """ + Like `user_device_resync` but operates on multiple users **from the same origin** + at once. + + Returns: + Dict from User ID to the same Dict as `user_device_resync`. + """ + if not user_ids: + return {} + + origins = {UserID.from_string(user_id).domain for user_id in user_ids} + + if len(origins) != 1: + raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}") + + result = {} + failed = set() + # TODO(Perf): Actually batch these up + for user_id in user_ids: + user_result, user_failed = await self._user_device_resync_returning_failed( + user_id + ) + result[user_id] = user_result + if user_failed: + failed.add(user_id) + + if mark_failed_as_stale: + await self.store.mark_remote_users_device_caches_as_stale(failed) + + return result + async def user_device_resync( self, user_id: str, mark_failed_as_stale: bool = True ) -> Optional[JsonDict]: + result, failed = await self._user_device_resync_returning_failed(user_id) + + if failed and mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + await self.store.mark_remote_users_device_caches_as_stale((user_id,)) + + return result + + async def _user_device_resync_returning_failed( + self, user_id: str + ) -> Tuple[Optional[JsonDict], bool]: """Fetches all devices for a user and updates the device cache with them. Args: user_id: The user's id whose device_list will be updated. - mark_failed_as_stale: Whether to mark the user's device list as stale - if the attempt to resync failed. Returns: - A dict with device info as under the "devices" in the result of this - request: - https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + - A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + None when we weren't able to fetch the device info for some reason, + e.g. due to a connection problem. + - True iff the resync failed and the device list should be marked as stale. """ logger.debug("Attempting to resync the device list for %s", user_id) log_kv({"message": "Doing resync to update device list."}) @@ -1181,12 +1268,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): try: result = await self.federation.query_user_devices(origin, user_id) except NotRetryingDestination: - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - - return None + return None, True except (RequestSendFailed, HttpResponseException) as e: logger.warning( "Failed to handle device list update for %s: %s", @@ -1194,23 +1276,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater): e, ) - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync # next time we get a device list update for this user_id. # This makes it more likely that the device lists will # eventually become consistent. - return None + return None, True except FederationDeniedError as e: set_tag("error", True) log_kv({"reason": "FederationDeniedError"}) logger.info(e) - return None + return None, False except Exception as e: set_tag("error", True) log_kv( @@ -1218,12 +1295,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): ) logger.exception("Failed to handle device list update for %s", user_id) - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - - return None + return None, True log_kv({"result": result}) stream_id = result["stream_id"] devices = result["devices"] @@ -1305,7 +1377,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): # point. self._seen_updates[user_id] = {stream_id} - return result + return result, False async def process_cross_signing_key_update( self, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 75e89850f5..00c403db49 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -195,7 +195,7 @@ class DeviceMessageHandler: sender_user_id, unknown_devices, ) - await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) + await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,)) # Immediately attempt a resync in the background run_in_background(self._user_device_resync, user_id=sender_user_id) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5fe102e2f2..d2188ca08f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -36,8 +36,8 @@ from synapse.types import ( get_domain_from_id, get_verify_key_from_cross_signing_key, ) -from synapse.util import json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer, delay_cancellation +from synapse.util import json_decoder +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.cancellation import cancellable from synapse.util.retryutils import NotRetryingDestination @@ -238,24 +238,28 @@ class E2eKeysHandler: # Now fetch any devices that we don't have in our cache # TODO It might make sense to propagate cancellations into the # deferreds which are querying remote homeservers. - await make_deferred_yieldable( - delay_cancellation( - defer.gatherResults( - [ - run_in_background( - self._query_devices_for_destination, - results, - cross_signing_keys, - failures, - destination, - queries, - timeout, - ) - for destination, queries in remote_queries_not_in_cache.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + logger.debug( + "%d destinations to query devices for", len(remote_queries_not_in_cache) + ) + + async def _query( + destination_queries: Tuple[str, Dict[str, Iterable[str]]] + ) -> None: + destination, queries = destination_queries + return await self._query_devices_for_destination( + results, + cross_signing_keys, + failures, + destination, + queries, + timeout, ) + + await concurrently_execute( + _query, + remote_queries_not_in_cache.items(), + 10, + delay_cancellation=True, ) ret = {"device_keys": results, "failures": failures} @@ -300,28 +304,41 @@ class E2eKeysHandler: # queries. We use the more efficient batched query_client_keys for all # remaining users user_ids_updated = [] - for (user_id, device_list) in destination_query.items(): - if user_id in user_ids_updated: - continue - if device_list: - continue + # Perform a user device resync for each user only once and only as long as: + # - they have an empty device_list + # - they are in some rooms that this server can see + users_to_resync_devices = { + user_id + for (user_id, device_list) in destination_query.items() + if (not device_list) and (await self.store.get_rooms_for_user(user_id)) + } - room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: - continue + logger.debug( + "%d users to resync devices for from destination %s", + len(users_to_resync_devices), + destination, + ) - # We've decided we're sharing a room with this user and should - # probably be tracking their device lists. However, we haven't - # done an initial sync on the device list so we do it now. - try: - resync_results = ( - await self.device_handler.device_list_updater.user_device_resync( - user_id - ) + try: + user_resync_results = ( + await self.device_handler.device_list_updater.multi_user_device_resync( + list(users_to_resync_devices) ) + ) + for user_id in users_to_resync_devices: + resync_results = user_resync_results[user_id] + if resync_results is None: - raise ValueError("Device resync failed") + # TODO: It's weird that we'll store a failure against a + # destination, yet continue processing users from that + # destination. + # We might want to consider changing this, but for now + # I'm leaving it as I found it. + failures[destination] = _exception_to_failure( + ValueError(f"Device resync failed for {user_id!r}") + ) + continue # Add the device keys to the results. user_devices = resync_results["devices"] @@ -339,8 +356,8 @@ class E2eKeysHandler: if self_signing_key: cross_signing_keys["self_signing_keys"][user_id] = self_signing_key - except Exception as e: - failures[destination] = _exception_to_failure(e) + except Exception as e: + failures[destination] = _exception_to_failure(e) if len(destination_query) == len(user_ids_updated): # We've updated all the users in the query and we do not need to diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 31df7f55cc..6df000faaf 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1423,7 +1423,7 @@ class FederationEventHandler: """ try: - await self._store.mark_remote_user_device_cache_as_stale(sender) + await self._store.mark_remote_users_device_caches_as_stale((sender,)) # Immediately attempt a resync in the background if self._config.worker.worker_app: diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 7c4941c3d3..ea5c08e6cf 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,12 +13,13 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request +from synapse.logging.opentracing import active_span from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -84,6 +85,76 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return 200, user_devices +class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): + """Ask master to resync the device list for multiple users from the same + remote server by contacting their server. + + This must happen on master so that the results can be correctly cached in + the database and streamed to workers. + + Request format: + + POST /_synapse/replication/multi_user_device_resync + + { + "user_ids": ["@alice:example.org", "@bob:example.org", ...] + } + + Response is roughly equivalent to ` /_matrix/federation/v1/user/devices/:user_id` + response, but there is a map from user ID to response, e.g.: + + { + "@alice:example.org": { + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": { ... }, + "device_display_name": "Alice's Mobile Phone" + } + ] + }, + ... + } + """ + + NAME = "multi_user_device_resync" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[override] + return {"user_ids": user_ids} + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, Dict[str, Optional[JsonDict]]]: + content = parse_json_object_from_request(request) + user_ids: List[str] = content["user_ids"] + + logger.info("Resync for %r", user_ids) + span = active_span() + if span: + span.set_tag("user_ids", f"{user_ids!r}") + + multi_user_devices = await self.device_list_updater.multi_user_device_resync( + user_ids + ) + + return 200, multi_user_devices + + class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): """Ask master to upload keys for the user and send them out over federation to update other servers. @@ -151,4 +222,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationUserDevicesResyncRestServlet(hs).register(http_server) + ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server) ReplicationUploadKeysForUserRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index db877e3f13..b067664473 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -54,7 +54,7 @@ from synapse.storage.util.id_generators import ( AbstractStreamIdTracker, StreamIdGenerator, ) -from synapse.types import JsonDict, get_verify_key_from_cross_signing_key +from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache @@ -1069,16 +1069,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return {row["user_id"] for row in rows} - async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: + async def mark_remote_users_device_caches_as_stale( + self, user_ids: StrCollection + ) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - await self.db_pool.simple_upsert( - table="device_lists_remote_resync", - keyvalues={"user_id": user_id}, - values={}, - insertion_values={"added_ts": self._clock.time_msec()}, - desc="mark_remote_user_device_cache_as_stale", + + def _mark_remote_users_device_caches_as_stale_txn( + txn: LoggingTransaction, + ) -> None: + # TODO add insertion_values support to simple_upsert_many and use + # that! + for user_id in user_ids: + self.db_pool.simple_upsert_txn( + txn, + table="device_lists_remote_resync", + keyvalues={"user_id": user_id}, + values={}, + insertion_values={"added_ts": self._clock.time_msec()}, + ) + + await self.db_pool.runInteraction( + "mark_remote_users_device_caches_as_stale", + _mark_remote_users_device_caches_as_stale_txn, ) async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None: diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index f2d436ddc3..0c725eb967 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -77,6 +77,10 @@ JsonMapping = Mapping[str, Any] # A JSON-serialisable object. JsonSerializable = object +# Collection[str] that does not include str itself; str being a Sequence[str] +# is very misleading and results in bugs. +StrCollection = Union[Tuple[str, ...], List[str], Set[str]] + # Note that this seems to require inheriting *directly* from Interface in order # for mypy-zope to realize it is an interface. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index d24c4f68c4..01e3cd46f6 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -205,7 +205,10 @@ T = TypeVar("T") async def concurrently_execute( - func: Callable[[T], Any], args: Iterable[T], limit: int + func: Callable[[T], Any], + args: Iterable[T], + limit: int, + delay_cancellation: bool = False, ) -> None: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -215,6 +218,8 @@ async def concurrently_execute( args: List of arguments to pass to func, each invocation of func gets a single argument. limit: Maximum number of conccurent executions. + delay_cancellation: Whether to delay cancellation until after the invocations + have finished. Returns: None, when all function invocations have finished. The return values @@ -233,9 +238,16 @@ async def concurrently_execute( # We use `itertools.islice` to handle the case where the number of args is # less than the limit, avoiding needlessly spawning unnecessary background # tasks. - await yieldable_gather_results( - _concurrently_execute_inner, (value for value in itertools.islice(it, limit)) - ) + if delay_cancellation: + await yieldable_gather_results_delaying_cancellation( + _concurrently_execute_inner, + (value for value in itertools.islice(it, limit)), + ) + else: + await yieldable_gather_results( + _concurrently_execute_inner, + (value for value in itertools.islice(it, limit)), + ) P = ParamSpec("P") @@ -292,6 +304,41 @@ async def yieldable_gather_results( raise dfe.subFailure.value from None +async def yieldable_gather_results_delaying_cancellation( + func: Callable[Concatenate[T, P], Awaitable[R]], + iter: Iterable[T], + *args: P.args, + **kwargs: P.kwargs, +) -> List[R]: + """Executes the function with each argument concurrently. + Cancellation is delayed until after all the results have been gathered. + + See `yieldable_gather_results`. + + Args: + func: Function to execute that returns a Deferred + iter: An iterable that yields items that get passed as the first + argument to the function + *args: Arguments to be passed to each call to func + **kwargs: Keyword arguments to be passed to each call to func + + Returns + A list containing the results of the function + """ + try: + return await make_deferred_yieldable( + delay_cancellation( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type] + consumeErrors=True, + ) + ) + ) + except defer.FirstError as dfe: + assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None + + T1 = TypeVar("T1") T2 = TypeVar("T2") T3 = TypeVar("T3") -- cgit 1.5.1 From 06ab64f201dffcb93b826546e20be53cc712c8b8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 10 Jan 2023 16:31:28 +0000 Subject: Implement MSC3925: changes to bundling of edits (#14811) Two parts to this: * Bundle the whole of the replacement with any edited events. This is backwards-compatible so I haven't put it behind a flag. * Optionally, inhibit server-side replacement of edited events. This has scope to break things, so it is currently disabled by default. --- changelog.d/14811.feature | 1 + synapse/config/experimental.py | 3 + synapse/events/utils.py | 31 ++++-- synapse/server.py | 2 +- tests/rest/client/test_relations.py | 185 +++++++++++++++++++++++++----------- 5 files changed, 159 insertions(+), 63 deletions(-) create mode 100644 changelog.d/14811.feature (limited to 'synapse') diff --git a/changelog.d/14811.feature b/changelog.d/14811.feature new file mode 100644 index 0000000000..87542835c3 --- /dev/null +++ b/changelog.d/14811.feature @@ -0,0 +1 @@ +Per [MSC3925](https://github.com/matrix-org/matrix-spec-proposals/pull/3925), bundle the whole of the replacement with any edited events, and optionally inhibit server-side replacement. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 0f3870bfe1..a8b2db372d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -139,3 +139,6 @@ class ExperimentalConfig(Config): # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) + + # MSC3925: do not replace events with their edits + self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 13fa93afb8..ae57a4df5e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -403,6 +403,14 @@ class EventClientSerializer: clients. """ + def __init__(self, inhibit_replacement_via_edits: bool = False): + """ + Args: + inhibit_replacement_via_edits: If this is set to True, then events are + never replaced by their edits. + """ + self._inhibit_replacement_via_edits = inhibit_replacement_via_edits + def serialize_event( self, event: Union[JsonDict, EventBase], @@ -422,6 +430,8 @@ class EventClientSerializer: into the event. apply_edits: Whether the content of the event should be modified to reflect any replacement in `bundle_aggregations[].replace`. + See also the `inhibit_replacement_via_edits` constructor arg: if that is + set to True, then this argument is ignored. Returns: The serialized event """ @@ -495,7 +505,8 @@ class EventClientSerializer: again for additional events in a recursive manner. serialized_event: The serialized event which may be modified. apply_edits: Whether the content of the event should be modified to reflect - any replacement in `aggregations.replace`. + any replacement in `aggregations.replace` (subject to the + `inhibit_replacement_via_edits` constructor arg). """ # We have already checked that aggregations exist for this event. @@ -518,15 +529,21 @@ class EventClientSerializer: if event_aggregations.replace: # If there is an edit, optionally apply it to the event. edit = event_aggregations.replace - if apply_edits: + if apply_edits and not self._inhibit_replacement_via_edits: self._apply_edit(event, serialized_event, edit) # Include information about it in the relations dict. - serialized_aggregations[RelationTypes.REPLACE] = { - "event_id": edit.event_id, - "origin_server_ts": edit.origin_server_ts, - "sender": edit.sender, - } + # + # Matrix spec v1.5 (https://spec.matrix.org/v1.5/client-server-api/#server-side-aggregation-of-mreplace-relationships) + # said that we should only include the `event_id`, `origin_server_ts` and + # `sender` of the edit; however MSC3925 proposes extending it to the whole + # of the edit, which is what we do here. + serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event( + edit, + time_now, + config=config, + apply_edits=False, + ) # Include any threaded replies to this event. if event_aggregations.thread: diff --git a/synapse/server.py b/synapse/server.py index 5baae2325e..f4ab94c4f3 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -743,7 +743,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer() + return EventClientSerializer(self.config.experimental.msc3925_inhibit_edit) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index b86f341ff5..c8a6911d5e 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -30,6 +30,7 @@ from tests import unittest from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_event +from tests.unittest import override_config class BaseRelationsTestCase(unittest.HomeserverTestCase): @@ -355,30 +356,67 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def _assert_edit_bundle( + self, event_json: JsonDict, edit_event_id: str, edit_event_content: JsonDict + ) -> None: + """ + Assert that the given event has a correctly-serialised edit event in its + bundled aggregations + + Args: + event_json: the serialised event to be checked + edit_event_id: the ID of the edit event that we expect to be bundled + edit_event_content: the content of that event, excluding the 'm.relates_to` + property + """ + relations_dict = event_json["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in [ + "event_id", + "sender", + "origin_server_ts", + "content", + "type", + "unsigned", + ]: + self.assertIn(key, m_replace_dict) + + expected_edit_content = { + "m.relates_to": { + "event_id": event_json["event_id"], + "rel_type": "m.replace", + } + } + expected_edit_content.update(edit_event_content) + + self.assert_dict( + { + "event_id": edit_event_id, + "sender": self.user_id, + "content": expected_edit_content, + "type": "m.room.message", + }, + m_replace_dict, + ) + def test_edit(self) -> None: """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} + edit_event_content = { + "msgtype": "m.text", + "body": "foo", + "m.new_content": new_body, + } channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + content=edit_event_content, ) edit_event_id = channel.json_body["event_id"] - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - # /event should return the *original* event channel = self.make_request( "GET", @@ -389,7 +427,7 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual( channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"} ) - assert_bundle(channel.json_body) + self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content) # Request the room messages. channel = self.make_request( @@ -398,7 +436,11 @@ class RelationsTestCase(BaseRelationsTestCase): access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + self._assert_edit_bundle( + self._find_event_in_chunk(channel.json_body["chunk"]), + edit_event_id, + edit_event_content, + ) # Request the room context. # /context should return the edited event. @@ -408,7 +450,9 @@ class RelationsTestCase(BaseRelationsTestCase): access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) + self._assert_edit_bundle( + channel.json_body["event"], edit_event_id, edit_event_content + ) self.assertEqual(channel.json_body["event"]["content"], new_body) # Request sync, but limit the timeline so it becomes limited (and includes @@ -420,7 +464,11 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + self._assert_edit_bundle( + self._find_event_in_chunk(room_timeline["events"]), + edit_event_id, + edit_event_content, + ) # Request search. channel = self.make_request( @@ -437,7 +485,45 @@ class RelationsTestCase(BaseRelationsTestCase): "results" ] ] - assert_bundle(self._find_event_in_chunk(chunk)) + self._assert_edit_bundle( + self._find_event_in_chunk(chunk), + edit_event_id, + edit_event_content, + ) + + @override_config({"experimental_features": {"msc3925_inhibit_edit": True}}) + def test_edit_inhibit_replace(self) -> None: + """ + If msc3925_inhibit_edit is enabled, then the original event should not be + replaced. + """ + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + edit_event_content = { + "msgtype": "m.text", + "body": "foo", + "m.new_content": new_body, + } + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content=edit_event_content, + ) + edit_event_id = channel.json_body["event_id"] + + # /context should return the *original* event. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"} + ) + self._assert_edit_bundle( + channel.json_body["event"], edit_event_id, edit_event_content + ) def test_multi_edit(self) -> None: """Test that multiple edits, including attempts by people who @@ -455,10 +541,15 @@ class RelationsTestCase(BaseRelationsTestCase): ) new_body = {"msgtype": "m.text", "body": "I've been edited!"} + edit_event_content = { + "msgtype": "m.text", + "body": "foo", + "m.new_content": new_body, + } channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + content=edit_event_content, ) edit_event_id = channel.json_body["event_id"] @@ -480,16 +571,8 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["event"]["content"], new_body) - - relations_dict = channel.json_body["event"]["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + self._assert_edit_bundle( + channel.json_body["event"], edit_event_id, edit_event_content ) def test_edit_reply(self) -> None: @@ -502,11 +585,15 @@ class RelationsTestCase(BaseRelationsTestCase): ) reply = channel.json_body["event_id"] - new_body = {"msgtype": "m.text", "body": "I've been edited!"} + edit_event_content = { + "msgtype": "m.text", + "body": "foo", + "m.new_content": {"msgtype": "m.text", "body": "I've been edited!"}, + } channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + content=edit_event_content, parent_id=reply, ) edit_event_id = channel.json_body["event_id"] @@ -549,28 +636,22 @@ class RelationsTestCase(BaseRelationsTestCase): # We expect that the edit relation appears in the unsigned relations # section. - relations_dict = result_event_dict["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict, desc) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict, desc) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + self._assert_edit_bundle( + result_event_dict, edit_event_id, edit_event_content ) def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} + edit_event_content = { + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": new_body, + } channel = self._send_relation( RelationTypes.REPLACE, "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": new_body, - }, + content=edit_event_content, ) edit_event_id = channel.json_body["event_id"] @@ -599,8 +680,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) # The relations information should not include the edit to the edit. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content) # /context should return the event updated for the *first* edit # (The edit to the edit should be ignored.) @@ -611,13 +691,8 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["event"]["content"], new_body) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + self._assert_edit_bundle( + channel.json_body["event"], edit_event_id, edit_event_content ) # Directly requesting the edit should not have the edit to the edit applied. -- cgit 1.5.1 From 73f097888eedaad05eda6b2453b6558158c0b032 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 11 Jan 2023 13:00:38 +0100 Subject: Add listener `health` (#14747) Fixes: #8780 --- changelog.d/14747.feature | 1 + docs/usage/configuration/config_documentation.md | 6 ++++++ synapse/app/generic_worker.py | 3 +++ synapse/app/homeserver.py | 3 +++ 4 files changed, 13 insertions(+) create mode 100644 changelog.d/14747.feature (limited to 'synapse') diff --git a/changelog.d/14747.feature b/changelog.d/14747.feature new file mode 100644 index 0000000000..0b8066159c --- /dev/null +++ b/changelog.d/14747.feature @@ -0,0 +1 @@ +Add a dedicated listener configuration for `health` endpoint. \ No newline at end of file diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index a355eef529..294dd6eddd 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -480,6 +480,12 @@ Valid resource names are: * `static`: static resources under synapse/static (/_matrix/static). (Mostly useful for 'fallback authentication'.) +* `health`: the [health check endpoint](../../reverse_proxy.md#health-check-endpoint). This endpoint + is by default active for all other resources and does not have to be activated separately. + This is only useful if you want to use the health endpoint explicitly on a dedicated port or + for [workers](../../workers.md) and containers without listener e.g. + [application services](../../workers.md#notifying-application-services). + Example configuration #1: ```yaml listeners: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index bcc8abe20c..8108b1e98f 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -199,6 +199,9 @@ class GenericWorkerServer(HomeServer): "A 'media' listener is configured but the media" " repository is disabled. Ignoring." ) + elif name == "health": + # Skip loading, health resource is always included + continue if name == "openid" and "federation" not in res.names: # Only load the openid resource separately if federation resource diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b9be558c7e..6176a70eb2 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -96,6 +96,9 @@ class SynapseHomeServer(HomeServer): # Skip loading openid resource if federation is defined # since federation resource will include openid continue + if name == "health": + # Skip loading, health resource is always included + continue resources.update(self._configure_named_resource(name, res.compress)) additional_resources = listener_config.http_options.additional_resources -- cgit 1.5.1 From 3952297f6f39906a65e70bce7becc1acd300a287 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Jan 2023 07:16:41 -0500 Subject: Calculate rooms changed for device lists to work. (#14810) Back-out some changes from 7e582a25f8f350df29d7d83ca902bdb522d1bbaf (#14786) which skipped necessary logic to calculate device lists properly. --- changelog.d/14810.bugfix | 1 + synapse/api/filtering.py | 3 --- synapse/handlers/sync.py | 4 ---- 3 files changed, 1 insertion(+), 7 deletions(-) create mode 100644 changelog.d/14810.bugfix (limited to 'synapse') diff --git a/changelog.d/14810.bugfix b/changelog.d/14810.bugfix new file mode 100644 index 0000000000..379bfccffa --- /dev/null +++ b/changelog.d/14810.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.75.0rc1 where device lists could be miscalculated with some sync filters. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 2b5af264b4..4cf8f0cc8e 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -283,9 +283,6 @@ class FilterCollection: await self._room_filter.filter(events) ) - def blocks_all_rooms(self) -> bool: - return self._room_filter.filters_all_rooms() - def blocks_all_presence(self) -> bool: return ( self._presence_filter.filters_all_types() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6942e06c77..20ee2f203a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1793,10 +1793,6 @@ class SyncHandler: - newly_left_users """ - # If the request doesn't care about rooms then nothing to do! - if sync_result_builder.sync_config.filter_collection.blocks_all_rooms(): - return set(), set(), set(), set() - since_token = sync_result_builder.since_token # 1. Start by fetching all ephemeral events in rooms we've joined (if required). -- cgit 1.5.1 From d6bda5adddd863409961dbafcd018356c213610e Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 11 Jan 2023 12:29:13 +0000 Subject: Add index to improve performance of the `/timestamp_to_event` endpoint used for jumping to a specific date in the timeline of a room. (#14799) --- changelog.d/14799.bugfix | 1 + synapse/storage/databases/main/events_bg_updates.py | 12 ++++++++++++ .../main/delta/73/24_events_jump_to_date_index.sql | 17 +++++++++++++++++ 3 files changed, 30 insertions(+) create mode 100644 changelog.d/14799.bugfix create mode 100644 synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql (limited to 'synapse') diff --git a/changelog.d/14799.bugfix b/changelog.d/14799.bugfix new file mode 100644 index 0000000000..dc867bd93a --- /dev/null +++ b/changelog.d/14799.bugfix @@ -0,0 +1 @@ +Add index to improve performance of the `/timestamp_to_event` endpoint used for jumping to a specific date in the timeline of a room. \ No newline at end of file diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 9e31798ab1..b9d3c36d60 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -69,6 +69,8 @@ class _BackgroundUpdates: EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections" + EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index" + @attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: @@ -260,6 +262,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._background_events_populate_state_key_rejections, ) + # Add an index that would be useful for jumping to date using + # get_event_id_for_timestamp. + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.EVENTS_JUMP_TO_DATE_INDEX, + index_name="events_jump_to_date_idx", + table="events", + columns=["room_id", "origin_server_ts"], + where_clause="NOT outlier", + ) + async def _background_reindex_fields_sender( self, progress: JsonDict, batch_size: int ) -> int: diff --git a/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql b/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql new file mode 100644 index 0000000000..67059909a1 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/24_events_jump_to_date_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7324, 'events_jump_to_date_index', '{}'); -- cgit 1.5.1 From 5172c8c403d94ea5f184abc8b3c37dbd19a849bc Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 11 Jan 2023 13:21:53 +0000 Subject: Faster remote room joins (worker mode): do not populate external hosts-in-room cache when sending events as this requires blocking for full state. (#14749) Signed-off-by: Olivier Wilkinson (reivilibre) Co-authored-by: Sean Quah --- changelog.d/14749.misc | 1 + synapse/handlers/message.py | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 changelog.d/14749.misc (limited to 'synapse') diff --git a/changelog.d/14749.misc b/changelog.d/14749.misc new file mode 100644 index 0000000000..ff81325225 --- /dev/null +++ b/changelog.d/14749.misc @@ -0,0 +1 @@ +Faster remote room joins (worker mode): do not populate external hosts-in-room cache when sending events as this requires blocking for full state. \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 88fc51a4c9..3278a695ed 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1531,12 +1531,23 @@ class EventCreationHandler: external federation senders don't have to recalculate it themselves. """ - for event, _ in events_and_context: - if not self._external_cache.is_enabled(): - return + if not self._external_cache.is_enabled(): + return - # If external cache is enabled we should always have this. - assert self._external_cache_joined_hosts_updates is not None + # If external cache is enabled we should always have this. + assert self._external_cache_joined_hosts_updates is not None + + for event, event_context in events_and_context: + if event_context.partial_state: + # To populate the cache for a partial-state event, we either have to + # block until full state, which the code below does, or change the + # meaning of cache values to be the list of hosts to which we plan to + # send events and calculate that instead. + # + # The federation senders don't use the external cache when sending + # events in partial-state rooms anyway, so let's not bother populating + # the cache. + continue # We actually store two mappings, event ID -> prev state group, # state group -> joined hosts, which is much more space efficient -- cgit 1.5.1 From dd9e71dc7fa91b81adfaaf8669aaf7ee976ffcd7 Mon Sep 17 00:00:00 2001 From: Emelie Graven Date: Wed, 11 Jan 2023 19:41:52 +0100 Subject: Add `set_displayname` to the module API (#14629) --- changelog.d/14629.feature | 1 + synapse/module_api/__init__.py | 27 +++++++++++++++++++++++++++ tests/module_api/test_api.py | 18 ++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 changelog.d/14629.feature (limited to 'synapse') diff --git a/changelog.d/14629.feature b/changelog.d/14629.feature new file mode 100644 index 0000000000..78f5fc2403 --- /dev/null +++ b/changelog.d/14629.feature @@ -0,0 +1 @@ +Adds a `set_displayname()` method to the module API for setting a user's display name. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6f4a934b05..6153a48257 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1585,6 +1585,33 @@ class ModuleApi: return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None) + async def set_displayname( + self, + user_id: UserID, + new_displayname: str, + deactivation: bool = False, + ) -> None: + """Sets a user's display name. + + Added in Synapse v1.76.0. + + Args: + user_id: + The user whose display name is to be changed. + new_displayname: + The new display name to give the user. + deactivation: + Whether this change was made while deactivating the user. + """ + requester = create_requester(user_id) + await self._hs.get_profile_handler().set_displayname( + target_user=user_id, + requester=requester, + new_displayname=new_displayname, + by_admin=True, + deactivation=deactivation, + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index b0f3f4374d..9919938e80 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -110,6 +110,24 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertEqual(found_user.user_id.to_string(), user_id) self.assertIdentical(found_user.is_admin, True) + def test_can_set_displayname(self): + localpart = "alice_wants_a_new_displayname" + user_id = self.register_user( + localpart, "1234", displayname="Alice", admin=False + ) + found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + + self.get_success( + self.module_api.set_displayname( + found_userinfo.user_id, "Bob", deactivation=False + ) + ) + found_profile = self.get_success( + self.module_api.get_profile_for_user(localpart) + ) + + self.assertEqual(found_profile.display_name, "Bob") + def test_get_userinfo_by_id(self): user_id = self.register_user("alice", "1234") found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) -- cgit 1.5.1 From 84ce93c12f921063bb6c59400fcf95649a1b7f45 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Jan 2023 10:29:09 +0000 Subject: Fix race calling `/members?at=` (#14817) Fixes #14814 --- changelog.d/14817.bugfix | 1 + synapse/storage/databases/main/stream.py | 65 +++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14817.bugfix (limited to 'synapse') diff --git a/changelog.d/14817.bugfix b/changelog.d/14817.bugfix new file mode 100644 index 0000000000..bb5da79268 --- /dev/null +++ b/changelog.d/14817.bugfix @@ -0,0 +1 @@ +Fix race where calling `/members` or `/state` with an `at` parameter could fail for newly created rooms, when using multiple workers. diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index cc27ec3804..63d8350530 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -801,13 +801,66 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): before this stream ordering. """ - last_row = await self.get_room_event_before_stream_ordering( - room_id=room_id, - stream_ordering=end_token.stream, + def get_last_event_in_room_before_stream_ordering_txn( + txn: LoggingTransaction, + ) -> Optional[str]: + # We need to handle the fact that the stream tokens can be vector + # clocks. We do this by getting all rows between the minimum and + # maximum stream ordering in the token, plus one row less than the + # minimum stream ordering. We then filter the results against the + # token and return the first row that matches. + + sql = """ + SELECT * FROM ( + SELECT instance_name, stream_ordering, topological_ordering, event_id + FROM events + LEFT JOIN rejections USING (event_id) + WHERE room_id = ? + AND ? < stream_ordering AND stream_ordering <= ? + AND NOT outlier + AND rejections.event_id IS NULL + ORDER BY stream_ordering DESC + ) AS a + UNION + SELECT * FROM ( + SELECT instance_name, stream_ordering, topological_ordering, event_id + FROM events + LEFT JOIN rejections USING (event_id) + WHERE room_id = ? + AND stream_ordering <= ? + AND NOT outlier + AND rejections.event_id IS NULL + ORDER BY stream_ordering DESC + LIMIT 1 + ) AS b + """ + txn.execute( + sql, + ( + room_id, + end_token.stream, + end_token.get_max_stream_pos(), + room_id, + end_token.stream, + ), + ) + + for instance_name, stream_ordering, topological_ordering, event_id in txn: + if _filter_results( + lower_token=None, + upper_token=end_token, + instance_name=instance_name, + topological_ordering=topological_ordering, + stream_ordering=stream_ordering, + ): + return event_id + + return None + + return await self.db_pool.runInteraction( + "get_last_event_in_room_before_stream_ordering", + get_last_event_in_room_before_stream_ordering_txn, ) - if last_row: - return last_row[2] - return None async def get_current_room_stream_token_for_room_id( self, room_id: str -- cgit 1.5.1 From b50c008453001aee8dd7dbd6f36ec32039e6ce76 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Jan 2023 10:52:07 +0000 Subject: Re-enable some linting (#14821) * Re-enable some linting * Newsfile * Remove comment --- changelog.d/14821.misc | 1 + pyproject.toml | 8 -------- stubs/sortedcontainers/sortedlist.pyi | 1 - stubs/sortedcontainers/sortedset.pyi | 2 -- stubs/synapse/synapse_rust/push.pyi | 2 +- synapse/config/_base.pyi | 10 ++++------ tests/storage/test_event_push_actions.py | 6 +++--- 7 files changed, 9 insertions(+), 21 deletions(-) create mode 100644 changelog.d/14821.misc (limited to 'synapse') diff --git a/changelog.d/14821.misc b/changelog.d/14821.misc new file mode 100644 index 0000000000..99e4e5e8a1 --- /dev/null +++ b/changelog.d/14821.misc @@ -0,0 +1 @@ +Re-enable some linting that was disabled when we switched to ruff. diff --git a/pyproject.toml b/pyproject.toml index 740d33066e..10d50ddb45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,11 +48,6 @@ line-length = 88 # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) # -# See https://github.com/charliermarsh/ruff/#pyflakes -# F401: unused import -# F811: Redefinition of unused -# F821: Undefined name -# # flake8-bugbear compatible checks. Its error codes are described at # https://github.com/charliermarsh/ruff/#flake8-bugbear # B019: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks @@ -64,9 +59,6 @@ ignore = [ "B024", "E501", "E731", - "F401", - "F811", - "F821", ] select = [ # pycodestyle checks. diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi index cd4c969849..1fe1a136f1 100644 --- a/stubs/sortedcontainers/sortedlist.pyi +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -7,7 +7,6 @@ from __future__ import annotations from typing import ( Any, Callable, - Generic, Iterable, Iterator, List, diff --git a/stubs/sortedcontainers/sortedset.pyi b/stubs/sortedcontainers/sortedset.pyi index d761c438f7..6db11eacbe 100644 --- a/stubs/sortedcontainers/sortedset.pyi +++ b/stubs/sortedcontainers/sortedset.pyi @@ -5,10 +5,8 @@ from __future__ import annotations from typing import ( - AbstractSet, Any, Callable, - Generic, Hashable, Iterable, Iterator, diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index b91f2edd7b..373b40740b 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, Union from synapse.types import JsonDict diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index bd265de536..b5cec132b4 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,5 +1,3 @@ -from __future__ import annotations - import argparse from typing import ( Any, @@ -20,7 +18,7 @@ from typing import ( import jinja2 -from synapse.config import ( +from synapse.config import ( # noqa: F401 account_validity, api, appservice, @@ -169,7 +167,7 @@ class RootConfig: self, section_name: Literal["caches"] ) -> cache.CacheConfig: ... @overload - def reload_config_section(self, section_name: str) -> Config: ... + def reload_config_section(self, section_name: str) -> "Config": ... class Config: root: RootConfig @@ -202,9 +200,9 @@ def find_config_files(search_paths: List[str]) -> List[str]: ... class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... - def should_handle(self, instance_name: str, key: str) -> bool: ... + def should_handle(self, instance_name: str, key: str) -> bool: ... # noqa: F811 class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): - def get_instance(self, key: str) -> str: ... + def get_instance(self, key: str) -> str: ... # noqa: F811 def read_file(file_path: Any, config_path: Iterable[str]) -> str: ... diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 5fa8bd2d98..76c06a9d1e 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -154,7 +154,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): # Create a user to receive notifications and send receipts. user_id, token, _, other_token, room_id = self._create_users_and_room() - last_event_id: str + last_event_id = "" def _assert_counts(notif_count: int, highlight_count: int) -> None: counts = self.get_success( @@ -289,7 +289,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): user_id, token, _, other_token, room_id = self._create_users_and_room() thread_id: str - last_event_id: str + last_event_id = "" def _assert_counts( notif_count: int, @@ -471,7 +471,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): user_id, token, _, other_token, room_id = self._create_users_and_room() thread_id: str - last_event_id: str + last_event_id = "" def _assert_counts( notif_count: int, -- cgit 1.5.1 From 772e8c23856e27960caba4dd87af42401b6c0cac Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 13 Jan 2023 00:16:21 +0000 Subject: Fix stack overflow in `_PerHostRatelimiter` due to synchronous requests (#14812) When there are many synchronous requests waiting on a `_PerHostRatelimiter`, each request will be started recursively just after the previous request has completed. Under the right conditions, this leads to stack exhaustion. A common way for requests to become synchronous is when the remote client disconnects early, because the homeserver is overloaded and slow to respond. Avoid stack exhaustion under these conditions by deferring subsequent requests until the next reactor tick. Fixes #14480. Signed-off-by: Sean Quah --- changelog.d/14812.bugfix | 1 + synapse/rest/client/register.py | 1 + synapse/server.py | 1 + synapse/util/ratelimitutils.py | 34 +++++++++++++++++++++-------- tests/util/test_ratelimitutils.py | 45 ++++++++++++++++++++++++++++++++++++--- 5 files changed, 70 insertions(+), 12 deletions(-) create mode 100644 changelog.d/14812.bugfix (limited to 'synapse') diff --git a/changelog.d/14812.bugfix b/changelog.d/14812.bugfix new file mode 100644 index 0000000000..94e0d70cbc --- /dev/null +++ b/changelog.d/14812.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would exhaust the stack when processing many federation requests where the remote homeserver has disconencted early. diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 3cb1e7e375..be696c304b 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -310,6 +310,7 @@ class UsernameAvailabilityRestServlet(RestServlet): self.hs = hs self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( + hs.get_reactor(), hs.get_clock(), FederationRatelimitSettings( # Time window of 2s diff --git a/synapse/server.py b/synapse/server.py index f4ab94c4f3..c8752baa5a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -768,6 +768,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: return FederationRateLimiter( + self.get_reactor(), self.get_clock(), config=self.config.ratelimiting.rc_federation, metrics_name="federation_servlets", diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 2aceb1a47f..bd72947bfe 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -34,6 +34,7 @@ from prometheus_client.core import Counter from typing_extensions import ContextManager from twisted.internet import defer +from twisted.internet.interfaces import IReactorTime from synapse.api.errors import LimitExceededError from synapse.config.ratelimiting import FederationRatelimitSettings @@ -146,12 +147,14 @@ class FederationRateLimiter: def __init__( self, + reactor: IReactorTime, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, ): """ Args: + reactor clock config metrics_name: The name of the rate limiter so we can differentiate it @@ -163,7 +166,7 @@ class FederationRateLimiter: def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter( - clock=clock, config=config, metrics_name=metrics_name + reactor=reactor, clock=clock, config=config, metrics_name=metrics_name ) self.ratelimiters: DefaultDict[ @@ -194,12 +197,14 @@ class FederationRateLimiter: class _PerHostRatelimiter: def __init__( self, + reactor: IReactorTime, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, ): """ Args: + reactor clock config metrics_name: The name of the rate limiter so we can differentiate it @@ -207,6 +212,7 @@ class _PerHostRatelimiter: for this rate limiter. from the rest in the metrics """ + self.reactor = reactor self.clock = clock self.metrics_name = metrics_name @@ -364,12 +370,22 @@ class _PerHostRatelimiter: def _on_exit(self, request_id: object) -> None: logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id)) - self.current_processing.discard(request_id) - try: - # start processing the next item on the queue. - _, deferred = self.ready_request_queue.popitem(last=False) - with PreserveLoggingContext(): - deferred.callback(None) - except KeyError: - pass + # When requests complete synchronously, we will recursively start the next + # request in the queue. To avoid stack exhaustion, we defer starting the next + # request until the next reactor tick. + + def start_next_request() -> None: + # We only remove the completed request from the list when we're about to + # start the next one, otherwise we can allow extra requests through. + self.current_processing.discard(request_id) + try: + # start processing the next item on the queue. + _, deferred = self.ready_request_queue.popitem(last=False) + + with PreserveLoggingContext(): + deferred.callback(None) + except KeyError: + pass + + self.reactor.callLater(0.0, start_next_request) diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 5b327b390e..2f3ea15b96 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional +from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.config.homeserver import HomeServerConfig @@ -29,7 +30,7 @@ class FederationRateLimiterTestCase(TestCase): """A simple test with the default values""" reactor, clock = get_clock() rc_config = build_rc_config() - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter(reactor, clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -39,7 +40,7 @@ class FederationRateLimiterTestCase(TestCase): """Test what happens when we hit the concurrent limit""" reactor, clock = get_clock() rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter(reactor, clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -57,6 +58,7 @@ class FederationRateLimiterTestCase(TestCase): # ... until we complete an earlier request cm2.__exit__(None, None, None) + reactor.advance(0.0) self.successResultOf(d3) def test_sleep_limit(self) -> None: @@ -65,7 +67,7 @@ class FederationRateLimiterTestCase(TestCase): rc_config = build_rc_config( {"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}} ) - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter(reactor, clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -81,6 +83,43 @@ class FederationRateLimiterTestCase(TestCase): sleep_time = _await_resolution(reactor, d3) self.assertAlmostEqual(sleep_time, 500, places=3) + def test_lots_of_queued_things(self) -> None: + """Tests lots of synchronous things queued up behind a slow thing. + + The stack should *not* explode when the slow thing completes. + """ + reactor, clock = get_clock() + rc_config = build_rc_config( + { + "rc_federation": { + "sleep_limit": 1000000000, # never sleep + "reject_limit": 1000000000, # never reject requests + "concurrent": 1, + } + } + ) + ratelimiter = FederationRateLimiter(reactor, clock, rc_config) + + with ratelimiter.ratelimit("testhost") as d: + # shouldn't block + self.successResultOf(d) + + async def task() -> None: + with ratelimiter.ratelimit("testhost") as d: + await d + + for _ in range(1, 100): + defer.ensureDeferred(task()) + + last_task = defer.ensureDeferred(task()) + + # Upon exiting the context manager, all the synchronous things will resume. + # If a stack overflow occurs, the final task will not complete. + + # Wait for all the things to complete. + reactor.advance(0.0) + self.successResultOf(last_task) + def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float: """advance the clock until the deferred completes. -- cgit 1.5.1 From 3a125625e70634075cc4d965a01309af56748eb2 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 13 Jan 2023 12:37:28 +0000 Subject: Add some clarifying comments and refactor a portion of the `Keyring` class for readability (#14804) --- changelog.d/14804.misc | 1 + synapse/crypto/keyring.py | 61 +++++++++++++++++++++++++++++++++-------------- 2 files changed, 44 insertions(+), 18 deletions(-) create mode 100644 changelog.d/14804.misc (limited to 'synapse') diff --git a/changelog.d/14804.misc b/changelog.d/14804.misc new file mode 100644 index 0000000000..24302332bd --- /dev/null +++ b/changelog.d/14804.misc @@ -0,0 +1 @@ +Add some clarifying comments and refactor a portion of the `Keyring` class for readability. \ No newline at end of file diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 69310d9035..86cd4af9bd 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -154,17 +154,21 @@ class Keyring: if key_fetchers is None: key_fetchers = ( + # Fetch keys from the database. StoreKeyFetcher(hs), + # Fetch keys from a configured Perspectives server. PerspectivesKeyFetcher(hs), + # Fetch keys from the origin server directly. ServerKeyFetcher(hs), ) self._key_fetchers = key_fetchers - self._server_queue: BatchingQueue[ + self._fetch_keys_queue: BatchingQueue[ _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]] ] = BatchingQueue( "keyring_server", clock=hs.get_clock(), + # The method called to fetch each key process_batch_callback=self._inner_fetch_key_requests, ) @@ -287,7 +291,7 @@ class Keyring: minimum_valid_until_ts=verify_request.minimum_valid_until_ts, key_ids=list(key_ids_to_find), ) - found_keys_by_server = await self._server_queue.add_to_queue( + found_keys_by_server = await self._fetch_keys_queue.add_to_queue( key_request, key=verify_request.server_name ) @@ -352,7 +356,17 @@ class Keyring: async def _inner_fetch_key_requests( self, requests: List[_FetchKeyRequest] ) -> Dict[str, Dict[str, FetchKeyResult]]: - """Processing function for the queue of `_FetchKeyRequest`.""" + """Processing function for the queue of `_FetchKeyRequest`. + + Takes a list of key fetch requests, de-duplicates them and then carries out + each request by invoking self._inner_fetch_key_request. + + Args: + requests: A list of requests for homeserver verify keys. + + Returns: + {server name: {key id: fetch key result}} + """ logger.debug("Starting fetch for %s", requests) @@ -397,8 +411,23 @@ class Keyring: async def _inner_fetch_key_request( self, verify_request: _FetchKeyRequest ) -> Dict[str, FetchKeyResult]: - """Attempt to fetch the given key by calling each key fetcher one by - one. + """Attempt to fetch the given key by calling each key fetcher one by one. + + If a key is found, check whether its `valid_until_ts` attribute satisfies the + `minimum_valid_until_ts` attribute of the `verify_request`. If it does, we + refrain from asking subsequent fetchers for that key. + + Even if the above check fails, we still return the found key - the caller may + still find the invalid key result useful. In this case, we continue to ask + subsequent fetchers for the invalid key, in case they return a valid result + for it. This can happen when fetching a stale key result from the database, + before querying the origin server for an up-to-date result. + + Args: + verify_request: The request for a verify key. Can include multiple key IDs. + + Returns: + A map of {key_id: the key fetch result}. """ logger.debug("Starting fetch for %s", verify_request) @@ -420,26 +449,22 @@ class Keyring: if not key: continue - # If we already have a result for the given key ID we keep the + # If we already have a result for the given key ID, we keep the # one with the highest `valid_until_ts`. existing_key = found_keys.get(key_id) - if existing_key: - if key.valid_until_ts <= existing_key.valid_until_ts: - continue + if existing_key and existing_key.valid_until_ts > key.valid_until_ts: + continue + + # Check if this key's expiry timestamp is valid for the verify request. + if key.valid_until_ts >= verify_request.minimum_valid_until_ts: + # Stop looking for this key from subsequent fetchers. + missing_key_ids.discard(key_id) - # We always store the returned key even if it doesn't the + # We always store the returned key even if it doesn't meet the # `minimum_valid_until_ts` requirement, as some verification # requests may still be able to be satisfied by it. - # - # We still keep looking for the key from other fetchers in that - # case though. found_keys[key_id] = key - if key.valid_until_ts < verify_request.minimum_valid_until_ts: - continue - - missing_key_ids.discard(key_id) - return found_keys -- cgit 1.5.1 From 8d5325ec0c04c3b0f08e0c5b4a26c5939d9db7f1 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 13 Jan 2023 15:17:03 +0100 Subject: Drop unused table `presence` (#14825) --- changelog.d/14825.misc | 1 + scripts-dev/database-save.sh | 1 - .../storage/schema/main/delta/73/25drop_presence.sql | 17 +++++++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14825.misc create mode 100644 synapse/storage/schema/main/delta/73/25drop_presence.sql (limited to 'synapse') diff --git a/changelog.d/14825.misc b/changelog.d/14825.misc new file mode 100644 index 0000000000..64312ac09e --- /dev/null +++ b/changelog.d/14825.misc @@ -0,0 +1 @@ +Drop unused table `presence`. \ No newline at end of file diff --git a/scripts-dev/database-save.sh b/scripts-dev/database-save.sh index 040c8a4943..91674027ae 100755 --- a/scripts-dev/database-save.sh +++ b/scripts-dev/database-save.sh @@ -11,6 +11,5 @@ sqlite3 "$1" <<'EOF' >table-save.sql .dump users .dump access_tokens -.dump presence .dump profiles EOF diff --git a/synapse/storage/schema/main/delta/73/25drop_presence.sql b/synapse/storage/schema/main/delta/73/25drop_presence.sql new file mode 100644 index 0000000000..9f6ffa20b6 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/25drop_presence.sql @@ -0,0 +1,17 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this table is unused +DROP TABLE presence; -- cgit 1.5.1 From 73ff493dfba63541a09eaf08587eb8bbd3330967 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2023 14:57:43 +0000 Subject: Merge account data streams (#14826) --- changelog.d/14826.misc | 1 + docs/upgrade.md | 12 ++++++ synapse/api/constants.py | 1 + synapse/handlers/account_data.py | 7 +++- synapse/handlers/initial_sync.py | 8 ++-- synapse/handlers/sync.py | 11 +++++- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/handler.py | 3 +- synapse/replication/tcp/streams/__init__.py | 3 -- synapse/replication/tcp/streams/_base.py | 49 ++++++++++++----------- synapse/storage/databases/main/account_data.py | 6 +-- synapse/storage/databases/main/tags.py | 54 +++++++------------------- 12 files changed, 75 insertions(+), 83 deletions(-) create mode 100644 changelog.d/14826.misc (limited to 'synapse') diff --git a/changelog.d/14826.misc b/changelog.d/14826.misc new file mode 100644 index 0000000000..9ebedcf51e --- /dev/null +++ b/changelog.d/14826.misc @@ -0,0 +1 @@ +Merge tag and normal account data replication streams. diff --git a/docs/upgrade.md b/docs/upgrade.md index c4bc5889a9..8a76172e43 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,18 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.76.0 + +## Changes to the account data replication streams + +Synapse has changed the format of the account data replication streams (between +workers). This is a forwards- and backwards-incompatible change: v1.75 workers +cannot process account data replicated by v1.76 workers, and vice versa. + +Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data +replication will resume as normal. + + # Upgrading to v1.74.0 ## Unicode support in user search diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6a5e7171da..6432d32d83 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -249,6 +249,7 @@ class RoomEncryptionAlgorithms: class AccountDataTypes: DIRECT: Final = "m.direct" IGNORED_USER_LIST: Final = "m.ignored_user_list" + TAG: Final = "m.tag" class HistoryVisibility: diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index aba7315cf7..834006356a 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -16,6 +16,7 @@ import logging import random from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.replication.http.account_data import ( ReplicationAddRoomAccountDataRestServlet, ReplicationAddTagRestServlet, @@ -335,7 +336,11 @@ class AccountDataEventSource(EventSource[int, JsonDict]): for room_id, room_tags in tags.items(): results.append( - {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} + { + "type": AccountDataTypes.TAG, + "content": {"tags": room_tags}, + "room_id": room_id, + } ) ( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9c335e6863..8c2260ad7d 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig @@ -239,7 +239,7 @@ class InitialSyncHandler: tags = tags_by_room.get(event.room_id) if tags: account_data_events.append( - {"type": "m.tag", "content": {"tags": tags}} + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} ) account_data = account_data_by_room.get(event.room_id, {}) @@ -326,7 +326,9 @@ class InitialSyncHandler: account_data_events = [] tags = await self.store.get_tags_for_room(user_id, room_id) if tags: - account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) + account_data_events.append( + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} + ) account_data = await self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 20ee2f203a..78d488f2b1 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,7 +31,12 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -2331,7 +2336,9 @@ class SyncHandler: account_data_events = [] if tags is not None: - account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) + account_data_events.append( + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} + ) for account_data_type, content in account_data.items(): account_data_events.append( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b5e40da533..7263bb2796 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -33,7 +33,6 @@ from synapse.replication.tcp.streams import ( PushersStream, PushRulesStream, ReceiptsStream, - TagAccountDataStream, ToDeviceStream, TypingStream, UnPartialStatedEventStream, @@ -168,7 +167,7 @@ class ReplicationDataHandler: self.notifier.on_new_event( StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows] ) - elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): + elif stream_name in AccountDataStream.NAME: self.notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0f166d16aa..d03a53d764 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import ( PresenceStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -145,7 +144,7 @@ class ReplicationCommandHandler: continue - if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + if isinstance(stream, AccountDataStream): # Only add AccountDataStream and TagAccountDataStream as a source on the # instance in charge of account_data persistence. if hs.get_instance_name() in hs.config.worker.writers.account_data: diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 110f10aab9..a7eadfa3c9 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -35,7 +35,6 @@ from synapse.replication.tcp.streams._base import ( PushRulesStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, UserSignatureStream, @@ -62,7 +61,6 @@ STREAMS_MAP = { DeviceListsStream, ToDeviceStream, FederationStream, - TagAccountDataStream, AccountDataStream, UserSignatureStream, UnPartialStatedRoomStream, @@ -83,7 +81,6 @@ __all__ = [ "CachesStream", "DeviceListsStream", "ToDeviceStream", - "TagAccountDataStream", "AccountDataStream", "UserSignatureStream", "UnPartialStatedRoomStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index e01155ad59..fbf78da9c2 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -28,8 +28,8 @@ from typing import ( import attr +from synapse.api.constants import AccountDataTypes from synapse.replication.http.streams import ReplicationGetStreamUpdates -from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -495,27 +495,6 @@ class ToDeviceStream(Stream): ) -class TagAccountDataStream(Stream): - """Someone added/removed a tag for a room""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class TagAccountDataStreamRow: - user_id: str - room_id: str - data: JsonDict - - NAME = "tag_account_data" - ROW_TYPE = TagAccountDataStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_max_account_data_stream_id), - store.get_all_updated_tags, - ) - - class AccountDataStream(Stream): """Global or per room account data was changed""" @@ -560,6 +539,19 @@ class AccountDataStream(Stream): to_token = room_results[-1][0] limited = True + tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags( + instance_name, + from_token, + to_token, + limit, + ) + + # again, if the tag results hit the limit, limit the global results to + # the same stream token. + if tags_limited: + to_token = tag_to_token + limited = True + # convert the global results to the right format, and limit them to the to_token # at the same time global_rows = ( @@ -568,11 +560,16 @@ class AccountDataStream(Stream): if stream_id <= to_token ) - # we know that the room_results are already limited to `to_token` so no need - # for a check on `stream_id` here. room_rows = ( (stream_id, (user_id, room_id, account_data_type)) for stream_id, user_id, room_id, account_data_type in room_results + if stream_id <= to_token + ) + + tag_rows = ( + (stream_id, (user_id, room_id, AccountDataTypes.TAG)) + for stream_id, user_id, room_id in tags + if stream_id <= to_token ) # We need to return a sorted list, so merge them together. @@ -582,7 +579,9 @@ class AccountDataStream(Stream): # leading to a comparison between the data tuples. The comparison could # fail due to attempting to compare the `room_id` which results in a # `TypeError` from comparing a `str` vs `None`. - updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) + updates = list( + heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) + ) return updates, to_token, limited diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 86032897f5..881d7089db 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,7 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream +from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( DatabasePool, @@ -454,9 +454,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def process_replication_position( self, stream_name: str, instance_name: str, token: int ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: + if stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index e23c927e02..d5500cdd47 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -17,7 +17,8 @@ import logging from typing import Any, Dict, Iterable, List, Tuple, cast -from synapse.replication.tcp.streams import TagAccountDataStream +from synapse.api.constants import AccountDataTypes +from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore @@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]: + ) -> Tuple[List[Tuple[int, str, str]], int, bool]: """Get updates for tags replication stream. Args: @@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore): The token returned can be used in a subsequent call to this function to get further updatees. - The updates are a list of 2-tuples of stream ID and the row data + The updates are a list of tuples of stream ID, user ID and room ID """ if last_id == current_id: @@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore): "get_all_updated_tags", get_all_updated_tags_txn ) - def get_tag_content( - txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]] - ) -> List[Tuple[int, Tuple[str, str, str]]]: - sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" - results = [] - for stream_id, user_id, room_id in tag_ids: - txn.execute(sql, (user_id, room_id)) - tags = [] - for tag, content in txn: - tags.append(json_encoder.encode(tag) + ":" + content) - tag_json = "{" + ",".join(tags) + "}" - results.append((stream_id, (user_id, room_id, tag_json))) - - return results - - batch_size = 50 - results = [] - for i in range(0, len(tag_ids), batch_size): - tags = await self.db_pool.runInteraction( - "get_all_updated_tag_content", - get_tag_content, - tag_ids[i : i + batch_size], - ) - results.extend(tags) - limited = False upto_token = current_id - if len(results) >= limit: - upto_token = results[-1][0] + if len(tag_ids) >= limit: + upto_token = tag_ids[-1][0] limited = True - return results, upto_token, limited + return tag_ids, upto_token, limited async def get_updated_tags( self, user_id: str, stream_id: int @@ -299,20 +275,16 @@ class TagsWorkerStore(AccountDataWorkerStore): token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: + if stream_name == AccountDataStream.NAME: for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) + if row.data_type == AccountDataTypes.TAG: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed( + row.user_id, token + ) super().process_replication_rows(stream_name, instance_name, token, rows) - def process_replication_position( - self, stream_name: str, instance_name: str, token: int - ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - super().process_replication_position(stream_name, instance_name, token) - class TagsStore(TagsWorkerStore): pass -- cgit 1.5.1 From 52ae80dd1afd9bb5b4cf2bb79297e1590f92cacb Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 13 Jan 2023 17:58:53 +0000 Subject: Use stable identifiers for faster joins (#14832) * Use new query param when requesting a partial join * Read new query param when serving partial join * Provide new field names when serving partial joins * Read new field names from partial join response * Changelog --- changelog.d/14832.misc | 1 + synapse/federation/federation_server.py | 2 + synapse/federation/transport/client.py | 18 ++++++ synapse/federation/transport/server/federation.py | 13 +++- tests/federation/test_federation_server.py | 2 +- tests/federation/transport/test_client.py | 77 +++++++++++++++++------ 6 files changed, 89 insertions(+), 24 deletions(-) create mode 100644 changelog.d/14832.misc (limited to 'synapse') diff --git a/changelog.d/14832.misc b/changelog.d/14832.misc new file mode 100644 index 0000000000..61e7401e43 --- /dev/null +++ b/changelog.d/14832.misc @@ -0,0 +1 @@ +Faster joins: use stable identifiers from [MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706). diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index bb20af6e91..c65dbf87fb 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -725,10 +725,12 @@ class FederationServer(FederationBase): "state": [p.get_pdu_json(time_now) for p in state_events], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events], "org.matrix.msc3706.partial_state": caller_supports_partial_state, + "members_omitted": caller_supports_partial_state, } if servers_in_room is not None: resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room) + resp["servers_in_room"] = list(servers_in_room) return resp diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 77f1f39cac..c8471d4cf7 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -357,6 +357,7 @@ class TransportLayerClient: if self._faster_joins_enabled: # lazy-load state on join query_params["org.matrix.msc3706.partial_state"] = "true" + query_params["omit_members"] = "true" return await self.client.put_json( destination=destination, @@ -909,6 +910,14 @@ class SendJoinParser(ByteParser[SendJoinResponse]): use_float="True", ) ) + # The stable field name comes last, so it "wins" if the fields disagree + self._coros.append( + ijson.items_coro( + _partial_state_parser(self._response), + "members_omitted", + use_float="True", + ) + ) self._coros.append( ijson.items_coro( @@ -918,6 +927,15 @@ class SendJoinParser(ByteParser[SendJoinResponse]): ) ) + # Again, stable field name comes last + self._coros.append( + ijson.items_coro( + _servers_in_room_parser(self._response), + "servers_in_room", + use_float="True", + ) + ) + def write(self, data: bytes) -> int: for c in self._coros: c.send(data) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 53e77b4bb6..c0a700905b 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -437,9 +437,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): partial_state = False if self._msc3706_enabled: - partial_state = parse_boolean_from_args( - query, "org.matrix.msc3706.partial_state", default=False - ) + # The stable query parameter wins, if it disagrees with the unstable + # parameter for some reason. + stable_param = parse_boolean_from_args(query, "omit_members", default=None) + if stable_param is not None: + partial_state = stable_param + else: + partial_state = parse_boolean_from_args( + query, "org.matrix.msc3706.partial_state", default=False + ) + result = await self.handler.on_send_join_request( origin, content, room_id, caller_supports_partial_state=partial_state ) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 177e5b5afc..27770304be 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -224,7 +224,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ) channel = self.make_signed_federation_request( "PUT", - f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", + f"/_matrix/federation/v2/send_join/{self._room_id}/x?omit_members=true", content=join_event_dict, ) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index b84c74fc0e..c90635e0a0 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -13,12 +13,14 @@ # limitations under the License. import json +from typing import List, Optional from unittest.mock import Mock import ijson.common from synapse.api.room_versions import RoomVersions from synapse.federation.transport.client import SendJoinParser +from synapse.types import JsonDict from synapse.util import ExceptionBundle from tests.unittest import TestCase @@ -71,33 +73,68 @@ class SendJoinParserTestCase(TestCase): def test_partial_state(self) -> None: """Check that the partial_state flag is correctly parsed""" - parser = SendJoinParser(RoomVersions.V1, False) - response = { - "org.matrix.msc3706.partial_state": True, - } - serialised_response = json.dumps(response).encode() + def parse(response: JsonDict) -> bool: + parser = SendJoinParser(RoomVersions.V1, False) + serialised_response = json.dumps(response).encode() - # Send data to the parser - parser.write(serialised_response) + # Send data to the parser + parser.write(serialised_response) - # Retrieve and check the parsed SendJoinResponse - parsed_response = parser.finish() - self.assertTrue(parsed_response.partial_state) + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + return parsed_response.partial_state - def test_servers_in_room(self) -> None: - """Check that the servers_in_room field is correctly parsed""" - parser = SendJoinParser(RoomVersions.V1, False) - response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]} + self.assertTrue(parse({"members_omitted": True})) + self.assertTrue(parse({"org.matrix.msc3706.partial_state": True})) - serialised_response = json.dumps(response).encode() + self.assertFalse(parse({"members_omitted": False})) + self.assertFalse(parse({"org.matrix.msc3706.partial_state": False})) - # Send data to the parser - parser.write(serialised_response) + # If there's a conflict, the stable field wins. + self.assertTrue( + parse({"members_omitted": True, "org.matrix.msc3706.partial_state": False}) + ) + self.assertFalse( + parse({"members_omitted": False, "org.matrix.msc3706.partial_state": True}) + ) - # Retrieve and check the parsed SendJoinResponse - parsed_response = parser.finish() - self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) + def test_servers_in_room(self) -> None: + """Check that the servers_in_room field is correctly parsed""" + + def parse(response: JsonDict) -> Optional[List[str]]: + parser = SendJoinParser(RoomVersions.V1, False) + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + return parsed_response.servers_in_room + + self.assertEqual( + parse({"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}), + ["hs1", "hs2"], + ) + self.assertEqual(parse({"servers_in_room": ["example.com"]}), ["example.com"]) + + # If both are provided, the stable identifier should win + self.assertEqual( + parse( + { + "org.matrix.msc3706.servers_in_room": ["old"], + "servers_in_room": ["new"], + } + ), + ["new"], + ) + + # And lastly, we should be able to tell if neither field was present. + self.assertEqual( + parse({}), + None, + ) def test_errors_closing_coroutines(self) -> None: """Check we close all coroutines, even if closing the first raises an Exception. -- cgit 1.5.1 From 54cd90ea60610a6dc24a291dd0cad4ce9bea8728 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 13 Jan 2023 19:32:10 +0000 Subject: Implement MSC3890: Remotely silence local notifications (#14775) --- changelog.d/14775.feature | 1 + docker/complement/conf/workers-shared-extra.yaml.j2 | 2 ++ scripts-dev/complement.sh | 2 +- synapse/config/experimental.py | 15 +++++++++++++++ synapse/handlers/device.py | 11 ++++++++++- 5 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14775.feature (limited to 'synapse') diff --git a/changelog.d/14775.feature b/changelog.d/14775.feature new file mode 100644 index 0000000000..7b7ee42cac --- /dev/null +++ b/changelog.d/14775.feature @@ -0,0 +1 @@ +Implement support for MSC3890: Remotely silence local notifications. \ No newline at end of file diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index cb839fed07..1170694df5 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -102,6 +102,8 @@ experimental_features: {% endif %} # Filtering /messages by relation type. msc3874_enabled: true + # Enable deleting device-specific notification settings stored in account data + msc3890_enabled: true # Enable removing account data support msc3391_enabled: true diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 51d1bac618..7c48d8bccb 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -190,7 +190,7 @@ fi extra_test_args=() -test_tags="synapse_blacklist,msc3787,msc3874,msc3391" +test_tags="synapse_blacklist,msc3787,msc3874,msc3890,msc3391" # All environment variables starting with PASS_ will be shared. # (The prefix is stripped off before reaching the container.) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index a8b2db372d..72a17e0616 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -17,6 +17,7 @@ from typing import Any, Optional import attr from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions +from synapse.config import ConfigError from synapse.config._base import Config from synapse.types import JsonDict @@ -93,6 +94,9 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) + # MSC3391: Removing account data. + self.msc3391_enabled = experimental.get("msc3391_enabled", False) + # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) @@ -127,6 +131,17 @@ class ExperimentalConfig(Config): "msc3886_endpoint", None ) + # MSC3890: Remotely silence local notifications + # Note: This option requires "experimental_features.msc3391_enabled" to be + # set to "true", in order to communicate account data deletions to clients. + self.msc3890_enabled: bool = experimental.get("msc3890_enabled", False) + if self.msc3890_enabled and not self.msc3391_enabled: + raise ConfigError( + "Option 'experimental_features.msc3391' must be set to 'true' to " + "enable 'experimental_features.msc3890'. MSC3391 functionality is " + "required to communicate account data deletions to clients." + ) + # MSC3912: Relation-based redactions. self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 89864e1119..0640ea79a0 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -346,6 +346,7 @@ class DeviceHandler(DeviceWorkerHandler): super().__init__(hs) self.federation_sender = hs.get_federation_sender() + self._account_data_handler = hs.get_account_data_handler() self._storage_controllers = hs.get_storage_controllers() self.device_list_updater = DeviceListUpdater(hs, self) @@ -502,7 +503,7 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - # Delete access tokens and e2e keys for each device. Not optimised as it is not + # Delete data specific to each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: await self._auth_handler.delete_access_tokens_for_user( @@ -512,6 +513,14 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id ) + if self.hs.config.experimental.msc3890_enabled: + # Remove any local notification settings for this device in accordance + # with MSC3890. + await self._account_data_handler.remove_account_data_for_user( + user_id, + f"org.matrix.msc3890.local_notification_settings.{device_id}", + ) + await self.notify_device_update(user_id, device_ids) async def update_device(self, user_id: str, device_id: str, content: dict) -> None: -- cgit 1.5.1 From 85a7a201fa460c227562111fba4d3d6aef681e23 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 16 Jan 2023 12:40:25 +0000 Subject: Also use stable name in SendJoinResponse struct (#14841) * Also use stable name in SendJoinResponse struct follow-up to #14832 * Changelog * Fix a rename I missed * Run black * Update synapse/federation/federation_client.py Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14841.misc | 1 + synapse/federation/federation_client.py | 6 +++--- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/client.py | 16 +++++++++------- tests/federation/transport/test_client.py | 6 +++--- 5 files changed, 17 insertions(+), 14 deletions(-) create mode 100644 changelog.d/14841.misc (limited to 'synapse') diff --git a/changelog.d/14841.misc b/changelog.d/14841.misc new file mode 100644 index 0000000000..61e7401e43 --- /dev/null +++ b/changelog.d/14841.misc @@ -0,0 +1 @@ +Faster joins: use stable identifiers from [MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706). diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 137cfb3346..b7002e8a6c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1142,9 +1142,9 @@ class FederationClient(FederationBase): % (auth_chain_create_events,) ) - if response.partial_state and not response.servers_in_room: + if response.members_omitted and not response.servers_in_room: raise InvalidResponseError( - "partial_state was set, but no servers were listed in the room" + "members_omitted was set, but no servers were listed in the room" ) return SendJoinResult( @@ -1152,7 +1152,7 @@ class FederationClient(FederationBase): state=signed_state, auth_chain=signed_auth, origin=destination, - partial_state=response.partial_state, + partial_state=response.members_omitted, servers_in_room=response.servers_in_room or [], ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index c65dbf87fb..3197939a36 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1502,7 +1502,7 @@ def _get_event_ids_for_partial_state_join( prev_state_ids: StateMap[str], summary: Dict[str, MemberSummary], ) -> Collection[str]: - """Calculate state to be retuned in a partial_state send_join + """Calculate state to be returned in a partial_state send_join Args: join_event: the join event being send_joined diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index c8471d4cf7..5ec651400a 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -795,7 +795,7 @@ class SendJoinResponse: event: Optional[EventBase] = None # The room state is incomplete - partial_state: bool = False + members_omitted: bool = False # List of servers in the room servers_in_room: Optional[List[str]] = None @@ -835,16 +835,18 @@ def _event_list_parser( @ijson.coroutine -def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]: +def _members_omitted_parser(response: SendJoinResponse) -> Generator[None, Any, None]: """Helper function for use with `ijson.items_coro` - Parses the partial_state field in send_join responses + Parses the members_omitted field in send_join responses """ while True: val = yield if not isinstance(val, bool): - raise TypeError("partial_state must be a boolean") - response.partial_state = val + raise TypeError( + "members_omitted (formerly org.matrix.msc370c.partial_state) must be a boolean" + ) + response.members_omitted = val @ijson.coroutine @@ -905,7 +907,7 @@ class SendJoinParser(ByteParser[SendJoinResponse]): if not v1_api: self._coros.append( ijson.items_coro( - _partial_state_parser(self._response), + _members_omitted_parser(self._response), "org.matrix.msc3706.partial_state", use_float="True", ) @@ -913,7 +915,7 @@ class SendJoinParser(ByteParser[SendJoinResponse]): # The stable field name comes last, so it "wins" if the fields disagree self._coros.append( ijson.items_coro( - _partial_state_parser(self._response), + _members_omitted_parser(self._response), "members_omitted", use_float="True", ) diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index c90635e0a0..3d61b1e8a9 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -68,11 +68,11 @@ class SendJoinParserTestCase(TestCase): self.assertEqual(len(parsed_response.state), 1, parsed_response) self.assertEqual(parsed_response.event_dict, {}, parsed_response) self.assertIsNone(parsed_response.event, parsed_response) - self.assertFalse(parsed_response.partial_state, parsed_response) + self.assertFalse(parsed_response.members_omitted, parsed_response) self.assertEqual(parsed_response.servers_in_room, None, parsed_response) def test_partial_state(self) -> None: - """Check that the partial_state flag is correctly parsed""" + """Check that the members_omitted flag is correctly parsed""" def parse(response: JsonDict) -> bool: parser = SendJoinParser(RoomVersions.V1, False) @@ -83,7 +83,7 @@ class SendJoinParserTestCase(TestCase): # Retrieve and check the parsed SendJoinResponse parsed_response = parser.finish() - return parsed_response.partial_state + return parsed_response.members_omitted self.assertTrue(parse({"members_omitted": True})) self.assertTrue(parse({"org.matrix.msc3706.partial_state": True})) -- cgit 1.5.1 From a302d3ecf75493f84fc5be616fee7d199ed12394 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 16 Jan 2023 13:16:19 +0000 Subject: Remove unnecessary reactor reference from `_PerHostRatelimiter` (#14842) Fix up #14812 to avoid introducing a reference to the reactor. Signed-off-by: Sean Quah --- changelog.d/14842.bugfix | 1 + synapse/rest/client/register.py | 1 - synapse/server.py | 1 - synapse/util/ratelimitutils.py | 10 ++-------- tests/util/test_ratelimitutils.py | 8 ++++---- 5 files changed, 7 insertions(+), 14 deletions(-) create mode 100644 changelog.d/14842.bugfix (limited to 'synapse') diff --git a/changelog.d/14842.bugfix b/changelog.d/14842.bugfix new file mode 100644 index 0000000000..94e0d70cbc --- /dev/null +++ b/changelog.d/14842.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would exhaust the stack when processing many federation requests where the remote homeserver has disconencted early. diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index be696c304b..3cb1e7e375 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -310,7 +310,6 @@ class UsernameAvailabilityRestServlet(RestServlet): self.hs = hs self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( - hs.get_reactor(), hs.get_clock(), FederationRatelimitSettings( # Time window of 2s diff --git a/synapse/server.py b/synapse/server.py index c8752baa5a..f4ab94c4f3 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -768,7 +768,6 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: return FederationRateLimiter( - self.get_reactor(), self.get_clock(), config=self.config.ratelimiting.rc_federation, metrics_name="federation_servlets", diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index bd72947bfe..f262bf95a0 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -34,7 +34,6 @@ from prometheus_client.core import Counter from typing_extensions import ContextManager from twisted.internet import defer -from twisted.internet.interfaces import IReactorTime from synapse.api.errors import LimitExceededError from synapse.config.ratelimiting import FederationRatelimitSettings @@ -147,14 +146,12 @@ class FederationRateLimiter: def __init__( self, - reactor: IReactorTime, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, ): """ Args: - reactor clock config metrics_name: The name of the rate limiter so we can differentiate it @@ -166,7 +163,7 @@ class FederationRateLimiter: def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter( - reactor=reactor, clock=clock, config=config, metrics_name=metrics_name + clock=clock, config=config, metrics_name=metrics_name ) self.ratelimiters: DefaultDict[ @@ -197,14 +194,12 @@ class FederationRateLimiter: class _PerHostRatelimiter: def __init__( self, - reactor: IReactorTime, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, ): """ Args: - reactor clock config metrics_name: The name of the rate limiter so we can differentiate it @@ -212,7 +207,6 @@ class _PerHostRatelimiter: for this rate limiter. from the rest in the metrics """ - self.reactor = reactor self.clock = clock self.metrics_name = metrics_name @@ -388,4 +382,4 @@ class _PerHostRatelimiter: except KeyError: pass - self.reactor.callLater(0.0, start_next_request) + self.clock.call_later(0.0, start_next_request) diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 2f3ea15b96..fe4961dcf3 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -30,7 +30,7 @@ class FederationRateLimiterTestCase(TestCase): """A simple test with the default values""" reactor, clock = get_clock() rc_config = build_rc_config() - ratelimiter = FederationRateLimiter(reactor, clock, rc_config) + ratelimiter = FederationRateLimiter(clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -40,7 +40,7 @@ class FederationRateLimiterTestCase(TestCase): """Test what happens when we hit the concurrent limit""" reactor, clock = get_clock() rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) - ratelimiter = FederationRateLimiter(reactor, clock, rc_config) + ratelimiter = FederationRateLimiter(clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -67,7 +67,7 @@ class FederationRateLimiterTestCase(TestCase): rc_config = build_rc_config( {"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}} ) - ratelimiter = FederationRateLimiter(reactor, clock, rc_config) + ratelimiter = FederationRateLimiter(clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -98,7 +98,7 @@ class FederationRateLimiterTestCase(TestCase): } } ) - ratelimiter = FederationRateLimiter(reactor, clock, rc_config) + ratelimiter = FederationRateLimiter(clock, rc_config) with ratelimiter.ratelimit("testhost") as d: # shouldn't block -- cgit 1.5.1 From 4db3331bb95a655bb56ab8333be49ee183f71715 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2023 14:20:12 +0000 Subject: Add an early return when handling no-op presence updates. (#14855) This stops us from incrementing the presence stream position for no-op updates. --- changelog.d/14855.misc | 1 + synapse/handlers/presence.py | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 changelog.d/14855.misc (limited to 'synapse') diff --git a/changelog.d/14855.misc b/changelog.d/14855.misc new file mode 100644 index 0000000000..f0e292f287 --- /dev/null +++ b/changelog.d/14855.misc @@ -0,0 +1 @@ +Add an early return when handling no-op presence updates. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2af90b25a3..43e4e7b1b4 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -2155,6 +2155,11 @@ class PresenceFederationQueue: # This should only be called on a presence writer. assert self._presence_writer + if not states or not destinations: + # Ignore calls which either don't have any new states or don't need + # to be sent anywhere. + return + if self._federation: self._federation.send_presence_to_destinations( states=states, -- cgit 1.5.1 From db5145a31d8ed76ac637f933f4facc195d557f75 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 16 Jan 2023 23:15:17 +0000 Subject: Add parameter to control whether we do a partial state join (#14843) When the local homeserver is already joined to a room and wants to perform another remote join, we may find it useful to do a non-partial state join if we already have the full state for the room. Signed-off-by: Sean Quah --- changelog.d/14843.misc | 1 + synapse/federation/federation_client.py | 21 ++++++++++++++++++--- synapse/federation/transport/client.py | 7 +++++-- 3 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 changelog.d/14843.misc (limited to 'synapse') diff --git a/changelog.d/14843.misc b/changelog.d/14843.misc new file mode 100644 index 0000000000..bec3c216bc --- /dev/null +++ b/changelog.d/14843.misc @@ -0,0 +1 @@ +Add a parameter to control whether the federation client performs a partial state join. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b7002e8a6c..15a9a88302 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1014,7 +1014,11 @@ class FederationClient(FederationBase): ) async def send_join( - self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion + self, + destinations: Iterable[str], + pdu: EventBase, + room_version: RoomVersion, + partial_state: bool = True, ) -> SendJoinResult: """Sends a join event to one of a list of homeservers. @@ -1027,6 +1031,10 @@ class FederationClient(FederationBase): pdu: event to be sent room_version: the version of the room (according to the server that did the make_join) + partial_state: whether to ask the remote server to omit membership state + events from the response. If the remote server complies, + `partial_state` in the send join result will be set. Defaults to + `True`. Returns: The result of the send join request. @@ -1037,7 +1045,9 @@ class FederationClient(FederationBase): """ async def send_request(destination: str) -> SendJoinResult: - response = await self._do_send_join(room_version, destination, pdu) + response = await self._do_send_join( + room_version, destination, pdu, omit_members=partial_state + ) # If an event was returned (and expected to be returned): # @@ -1177,7 +1187,11 @@ class FederationClient(FederationBase): ) async def _do_send_join( - self, room_version: RoomVersion, destination: str, pdu: EventBase + self, + room_version: RoomVersion, + destination: str, + pdu: EventBase, + omit_members: bool, ) -> SendJoinResponse: time_now = self._clock.time_msec() @@ -1188,6 +1202,7 @@ class FederationClient(FederationBase): room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), + omit_members=omit_members, ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 5ec651400a..556883f079 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -351,13 +351,16 @@ class TransportLayerClient: room_id: str, event_id: str, content: JsonDict, + omit_members: bool, ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) query_params: Dict[str, str] = {} if self._faster_joins_enabled: # lazy-load state on join - query_params["org.matrix.msc3706.partial_state"] = "true" - query_params["omit_members"] = "true" + query_params["org.matrix.msc3706.partial_state"] = ( + "true" if omit_members else "false" + ) + query_params["omit_members"] = "true" if omit_members else "false" return await self.client.put_json( destination=destination, -- cgit 1.5.1 From 2b084c5b710d9630178484e6ade597ca7fa814b6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2023 09:29:58 +0000 Subject: Merge device list replication streams (#14833) --- changelog.d/14826.misc | 2 +- changelog.d/14833.misc | 1 + docs/upgrade.md | 9 ++-- synapse/replication/tcp/client.py | 8 +++- synapse/replication/tcp/streams/__init__.py | 3 -- synapse/replication/tcp/streams/_base.py | 74 ++++++++++++++++++++--------- synapse/storage/databases/main/devices.py | 13 ++--- 7 files changed, 72 insertions(+), 38 deletions(-) create mode 100644 changelog.d/14833.misc (limited to 'synapse') diff --git a/changelog.d/14826.misc b/changelog.d/14826.misc index 9ebedcf51e..e80673a721 100644 --- a/changelog.d/14826.misc +++ b/changelog.d/14826.misc @@ -1 +1 @@ -Merge tag and normal account data replication streams. +Merge the two account data and the two device list replication streams. diff --git a/changelog.d/14833.misc b/changelog.d/14833.misc new file mode 100644 index 0000000000..e80673a721 --- /dev/null +++ b/changelog.d/14833.misc @@ -0,0 +1 @@ +Merge the two account data and the two device list replication streams. diff --git a/docs/upgrade.md b/docs/upgrade.md index 8a76172e43..270c33b656 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -92,12 +92,13 @@ process, for example: ## Changes to the account data replication streams -Synapse has changed the format of the account data replication streams (between -workers). This is a forwards- and backwards-incompatible change: v1.75 workers -cannot process account data replicated by v1.76 workers, and vice versa. +Synapse has changed the format of the account data and devices replication +streams (between workers). This is a forwards- and backwards-incompatible +change: v1.75 workers cannot process account data replicated by v1.76 workers, +and vice versa. Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data -replication will resume as normal. +and device replication will resume as normal. # Upgrading to v1.74.0 diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7263bb2796..31022ce5fb 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -187,7 +187,7 @@ class ReplicationDataHandler: elif stream_name == DeviceListsStream.NAME: all_room_ids: Set[str] = set() for row in rows: - if row.entity.startswith("@"): + if row.entity.startswith("@") and not row.is_signature: room_ids = await self.store.get_rooms_for_user(row.entity) all_room_ids.update(room_ids) self.notifier.on_new_event( @@ -422,7 +422,11 @@ class FederationSenderHandler: # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. - hosts = {row.entity for row in rows if not row.entity.startswith("@")} + hosts = { + row.entity + for row in rows + if not row.entity.startswith("@") and not row.is_signature + } for host in hosts: self.federation_sender.send_device_messages(host, immediate=False) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index a7eadfa3c9..9c67f661a3 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -37,7 +37,6 @@ from synapse.replication.tcp.streams._base import ( Stream, ToDeviceStream, TypingStream, - UserSignatureStream, ) from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.federation import FederationStream @@ -62,7 +61,6 @@ STREAMS_MAP = { ToDeviceStream, FederationStream, AccountDataStream, - UserSignatureStream, UnPartialStatedRoomStream, UnPartialStatedEventStream, ) @@ -82,7 +80,6 @@ __all__ = [ "DeviceListsStream", "ToDeviceStream", "AccountDataStream", - "UserSignatureStream", "UnPartialStatedRoomStream", "UnPartialStatedEventStream", ] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index fbf78da9c2..a4bdb48c0c 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -463,18 +463,67 @@ class DeviceListsStream(Stream): @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: entity: str + # Indicates that a user has signed their own device with their user-signing key + is_signature: bool NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_device_stream_token), - store.get_all_device_list_changes_for_remotes, + current_token_without_instance(self.store.get_device_stream_token), + self._update_function, + ) + + async def _update_function( + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, + ) -> StreamUpdateResult: + ( + device_updates, + devices_to_token, + devices_limited, + ) = await self.store.get_all_device_list_changes_for_remotes( + instance_name, from_token, current_token, target_row_count ) + ( + signatures_updates, + signatures_to_token, + signatures_limited, + ) = await self.store.get_all_user_signature_changes_for_remotes( + instance_name, from_token, current_token, target_row_count + ) + + upper_limit_token = current_token + if devices_limited: + upper_limit_token = min(upper_limit_token, devices_to_token) + if signatures_limited: + upper_limit_token = min(upper_limit_token, signatures_to_token) + + device_updates = [ + (stream_id, (entity, False)) + for stream_id, (entity,) in device_updates + if stream_id <= upper_limit_token + ] + + signatures_updates = [ + (stream_id, (entity, True)) + for stream_id, (entity,) in signatures_updates + if stream_id <= upper_limit_token + ] + + updates = list( + heapq.merge(device_updates, signatures_updates, key=lambda row: row[0]) + ) + + return updates, upper_limit_token, devices_limited or signatures_limited + class ToDeviceStream(Stream): """New to_device messages for a client""" @@ -583,22 +632,3 @@ class AccountDataStream(Stream): heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) ) return updates, to_token, limited - - -class UserSignatureStream(Stream): - """A user has signed their own device with their user-signing key""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class UserSignatureStreamRow: - user_id: str - - NAME = "user_signature" - ROW_TYPE = UserSignatureStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_device_stream_token), - store.get_all_user_signature_changes_for_remotes, - ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index b067664473..cd186c8472 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,7 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream +from synapse.replication.tcp.streams._base import DeviceListsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -163,9 +163,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) -> None: if stream_name == DeviceListsStream.NAME: self._invalidate_caches_for_devices(token, rows) - elif stream_name == UserSignatureStream.NAME: - for row in rows: - self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) def process_replication_position( @@ -173,14 +171,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) -> None: if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) - elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: for row in rows: + if row.is_signature: + self._user_signature_stream_cache.entity_has_changed(row.entity, token) + continue + # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. -- cgit 1.5.1 From 316590d1ea273115a9e7925236e02d577a231de4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2023 09:58:22 +0000 Subject: Fix bug in `wait_for_stream_position` (#14856) We were incorrectly checking if the *local* token had been advanced, rather than the token for the remote instance. In practice, I don't think this has caused any bugs due to where we use `wait_for_stream_position`, as critically we don't use it on instances that also write to the given streams (and so the local token will lag behind all remote tokens). --- changelog.d/14856.misc | 1 + synapse/replication/tcp/client.py | 2 +- tests/replication/tcp/test_handler.py | 78 +++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14856.misc (limited to 'synapse') diff --git a/changelog.d/14856.misc b/changelog.d/14856.misc new file mode 100644 index 0000000000..3731d6cbf1 --- /dev/null +++ b/changelog.d/14856.misc @@ -0,0 +1 @@ +Fix `wait_for_stream_position` to correctly wait for the right instance to advance its token. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 31022ce5fb..322d695bc7 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -325,7 +325,7 @@ class ReplicationDataHandler: # anyway in that case we don't need to wait. return - current_position = self._streams[stream_name].current_token(self._instance_name) + current_position = self._streams[stream_name].current_token(instance_name) if position <= current_position: # We're already past the position return diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index 1e299d2d67..555922409d 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + +from synapse.replication.tcp.commands import PositionCommand, RdataCommand + from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -71,3 +75,77 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual( len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1 ) + + def test_wait_for_stream_position(self) -> None: + """Check that wait for stream position correctly waits for an update from the + correct instance. + """ + store = self.hs.get_datastores().main + cmd_handler = self.hs.get_replication_command_handler() + data_handler = self.hs.get_replication_data_handler() + + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + + cache_id_gen = worker1.get_datastores().main._cache_id_gen + assert cache_id_gen is not None + + self.replicate() + + # First, make sure the master knows that `worker1` exists. + initial_token = cache_id_gen.get_current_token() + cmd_handler.send_command( + PositionCommand("caches", "worker1", initial_token, initial_token) + ) + self.replicate() + + # Next send out a normal RDATA, and check that waiting for that stream + # ID returns immediately. + ctx = cache_id_gen.get_next() + next_token = self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + cmd_handler.send_command( + RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) + ) + self.replicate() + + self.get_success( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + + # `wait_for_stream_position` should only return once master receives an + # RDATA from the worker + ctx = cache_id_gen.get_next() + next_token = self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + self.assertFalse(d.called) + + # ... updating the cache ID gen on the master still shouldn't cause the + # deferred to wake up. + ctx = store._cache_id_gen.get_next() + self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + self.assertFalse(d.called) + + # ... but receiving the RDATA should + cmd_handler.send_command( + RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) + ) + self.replicate() + + self.assertTrue(d.called) -- cgit 1.5.1 From 5b3af1c7d0c5a8901fada7648136f186726fd135 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 17 Jan 2023 12:44:15 +0000 Subject: Stabilise serving partial join responses (#14839) Serving partial join responses is no longer experimental. They will only be served under the stable identifier if the the undocumented config flag experimental.msc3706_enabled is set to true. Synapse continues to request a partial join only if the undocumented config flag experimental.faster_joins is set to true; this setting remains present and unaffected. --- changelog.d/14839.feature | 1 + docker/complement/conf/workers-shared-extra.yaml.j2 | 2 -- synapse/config/experimental.py | 6 +++++- synapse/federation/transport/server/federation.py | 21 ++++++++++----------- tests/federation/test_federation_server.py | 3 +-- 5 files changed, 17 insertions(+), 16 deletions(-) create mode 100644 changelog.d/14839.feature (limited to 'synapse') diff --git a/changelog.d/14839.feature b/changelog.d/14839.feature new file mode 100644 index 0000000000..a4206be007 --- /dev/null +++ b/changelog.d/14839.feature @@ -0,0 +1 @@ +Faster joins: always serve a partial join response to servers that request it with the stable query param. diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 1170694df5..7e9ec23808 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -94,8 +94,6 @@ allow_device_name_lookup_over_federation: true experimental_features: # Enable history backfilling support msc2716_enabled: true - # server-side support for partial state in /send_join responses - msc3706_enabled: true {% if not workers_in_use %} # client-side support for partial state in /send_join responses faster_joins: true diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 72a17e0616..0444ef8244 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -75,11 +75,15 @@ class ExperimentalConfig(Config): ) # MSC3706 (server-side support for partial state in /send_join responses) + # Synapse will always serve partial state responses to requests using the stable + # query parameter `omit_members`. If this flag is set, Synapse will also serve + # partial state responses to requests using the unstable query parameter + # `org.matrix.msc3706.partial_state`. self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) # experimental support for faster joins over federation # (MSC2775, MSC3706, MSC3895) - # requires a target server with msc3706_enabled enabled. + # requires a target server that can provide a partial join response (MSC3706) self.faster_joins_enabled: bool = experimental.get("faster_joins", False) # MSC3720 (Account status endpoint) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index c0a700905b..17c427387e 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -422,7 +422,7 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self._msc3706_enabled = hs.config.experimental.msc3706_enabled + self._read_msc3706_query_param = hs.config.experimental.msc3706_enabled async def on_PUT( self, @@ -436,16 +436,15 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): # match those given in content partial_state = False - if self._msc3706_enabled: - # The stable query parameter wins, if it disagrees with the unstable - # parameter for some reason. - stable_param = parse_boolean_from_args(query, "omit_members", default=None) - if stable_param is not None: - partial_state = stable_param - else: - partial_state = parse_boolean_from_args( - query, "org.matrix.msc3706.partial_state", default=False - ) + # The stable query parameter wins, if it disagrees with the unstable + # parameter for some reason. + stable_param = parse_boolean_from_args(query, "omit_members", default=None) + if stable_param is not None: + partial_state = stable_param + elif self._read_msc3706_query_param: + partial_state = parse_boolean_from_args( + query, "org.matrix.msc3706.partial_state", default=False + ) result = await self.handler.on_send_join_request( origin, content, room_id, caller_supports_partial_state=partial_state diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 27770304be..be719e49c0 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -211,9 +211,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") - @override_config({"experimental_features": {"msc3706_enabled": True}}) def test_send_join_partial_state(self) -> None: - """When MSC3706 support is enabled, /send_join should return partial state""" + """/send_join should return partial state, if requested""" joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME join_result = self._make_join(joining_user) -- cgit 1.5.1 From 4d6b1d3c47387466d34abb98613ca0d240057e24 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 18 Jan 2023 09:27:57 -0500 Subject: Properly check for frozendicts in event auth code. (#14864) Check for for an instance of a mapping instead of a dict. This only affects room version 10 when frozen events are enabled. --- changelog.d/14864.bugfix | 1 + synapse/event_auth.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14864.bugfix (limited to 'synapse') diff --git a/changelog.d/14864.bugfix b/changelog.d/14864.bugfix new file mode 100644 index 0000000000..12c0c74ab3 --- /dev/null +++ b/changelog.d/14864.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.64.0 when using room version 10 with frozen events enabled. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index d437b7e5d1..c4a7b16413 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections.abc import logging import typing from typing import ( @@ -877,7 +878,7 @@ def _check_power_levels( if not isinstance(v, int): raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: - if not isinstance(v, dict) or not all( + if not isinstance(v, collections.abc.Mapping) or not all( isinstance(v, int) for v in v.values() ): raise SynapseError( -- cgit 1.5.1 From e8f2bf5c40c27e68e5983ebbd1fc0281bc45bf5f Mon Sep 17 00:00:00 2001 From: Catalan Lover <48515417+FSG-Cat@users.noreply.github.com> Date: Wed, 18 Jan 2023 19:59:48 +0100 Subject: Change default room version to 10. Implements MSC3904 (#14111) * Change Documentation to have v10 as default room version * Change Default Room version to 10 * Add changelog entry for default room version swap * Add changelog entry for v10 default room version in docs * Clarify doc changelog entry Co-authored-by: David Robertson * Improve Documentation changes. Co-authored-by: David Robertson * Update Changelog entry to have correct format Co-authored-by: David Robertson * Update Spec Version to 1.5 * Only need 1 changelog. * Fix test. * Update "Changed in" line Co-authored-by: David Robertson Co-authored-by: Patrick Cloke Co-authored-by: Patrick Cloke --- changelog.d/14111.feature | 1 + docs/usage/configuration/config_documentation.md | 4 +++- synapse/config/server.py | 2 +- tests/rest/client/test_upgrade_room.py | 12 +++++++++--- 4 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 changelog.d/14111.feature (limited to 'synapse') diff --git a/changelog.d/14111.feature b/changelog.d/14111.feature new file mode 100644 index 0000000000..0a794701a7 --- /dev/null +++ b/changelog.d/14111.feature @@ -0,0 +1 @@ +Update the default room version to [v10](https://spec.matrix.org/v1.5/rooms/v10/) ([MSC 3904](https://github.com/matrix-org/matrix-spec-proposals/pull/3904)). Contributed by @FSG-Cat. \ No newline at end of file diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 3481e866f7..2883f76a26 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -295,7 +295,9 @@ Known room versions are listed [here](https://spec.matrix.org/latest/rooms/#comp For example, for room version 1, `default_room_version` should be set to "1". -Currently defaults to "9". +Currently defaults to ["10"](https://spec.matrix.org/v1.5/rooms/v10/). + +_Changed in Synapse 1.76:_ the default version room version was increased from [9](https://spec.matrix.org/v1.5/rooms/v9/) to [10](https://spec.matrix.org/v1.5/rooms/v10/). Example configuration: ```yaml diff --git a/synapse/config/server.py b/synapse/config/server.py index ec46ca63ad..80bcfa4080 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -151,7 +151,7 @@ DEFAULT_IP_RANGE_BLACKLIST = [ "fec0::/10", ] -DEFAULT_ROOM_VERSION = "9" +DEFAULT_ROOM_VERSION = "10" ROOM_COMPLEXITY_TOO_GREAT = ( "Your homeserver is unable to join rooms this large or complex. " diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index 5e7bf97482..5ec343dd7f 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -199,9 +199,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): def test_stringy_power_levels(self) -> None: """The room upgrade converts stringy power levels to proper integers.""" + # Create a room on room version < 10. + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_token, room_version="9" + ) + self.helper.join(room_id, self.other, tok=self.other_token) + # Retrieve the room's current power levels. power_levels = self.helper.get_state( - self.room_id, + room_id, "m.room.power_levels", tok=self.creator_token, ) @@ -217,14 +223,14 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # conscience, we ought to ensure it's upgrading from a sufficiently old # version of room. self.helper.send_state( - self.room_id, + room_id, "m.room.power_levels", body=power_levels, tok=self.creator_token, ) # Upgrade the room. Check the homeserver reports success. - channel = self._upgrade_room() + channel = self._upgrade_room(room_id=room_id) self.assertEqual(200, channel.code, channel.result) # Extract the new room ID. -- cgit 1.5.1 From 9187fd940e2b2bbfd4df7204053cc26b2707aad4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Jan 2023 19:35:29 +0000 Subject: Wait for streams to catch up when processing HTTP replication. (#14820) This should hopefully mitigate a class of races where data gets out of sync due a HTTP replication request racing with the replication streams. --- changelog.d/14820.bugfix | 1 + synapse/handlers/federation_event.py | 4 ++ synapse/replication/http/_base.py | 97 +++++++++++++++++++++++++++++--- synapse/replication/http/account_data.py | 29 +++++----- synapse/replication/http/devices.py | 10 +--- synapse/replication/http/federation.py | 28 +++------ synapse/replication/http/login.py | 5 +- synapse/replication/http/membership.py | 22 ++++---- synapse/replication/http/presence.py | 7 +-- synapse/replication/http/push.py | 5 +- synapse/replication/http/register.py | 9 +-- synapse/replication/http/send_event.py | 5 +- synapse/replication/http/send_events.py | 4 +- synapse/replication/http/state.py | 2 +- synapse/replication/http/streams.py | 6 +- synapse/replication/tcp/client.py | 25 +++++++- synapse/replication/tcp/resource.py | 43 +++++++------- synapse/storage/util/id_generators.py | 34 ++++++----- synapse/types/__init__.py | 6 ++ tests/replication/http/test__base.py | 9 +-- tests/storage/test_id_generators.py | 20 +++---- 21 files changed, 226 insertions(+), 145 deletions(-) create mode 100644 changelog.d/14820.bugfix (limited to 'synapse') diff --git a/changelog.d/14820.bugfix b/changelog.d/14820.bugfix new file mode 100644 index 0000000000..36e94f2b9b --- /dev/null +++ b/changelog.d/14820.bugfix @@ -0,0 +1 @@ +Fix rare races when using workers. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 6df000faaf..904a721483 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2259,6 +2259,10 @@ class FederationEventHandler: event_and_contexts, backfilled=backfilled ) + # After persistence we always need to notify replication there may + # be new data. + self._notifier.notify_replication() + if self._ephemeral_messages_enabled: for event in events: # If there's an expiry timestamp on the event, schedule its expiry. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 3f4d3fc51a..709327b97f 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -17,7 +17,7 @@ import logging import re import urllib.parse from inspect import signature -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -27,6 +27,7 @@ from twisted.web.server import Request from synapse.api.errors import HttpResponseException, SynapseError from synapse.http import RequestTimedOutError from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging import opentracing from synapse.logging.opentracing import trace_with_opname @@ -53,6 +54,9 @@ _outgoing_request_counter = Counter( ) +_STREAM_POSITION_KEY = "_INT_STREAM_POS" + + class ReplicationEndpoint(metaclass=abc.ABCMeta): """Helper base class for defining new replication HTTP endpoints. @@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): a connection error is received. RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when receiving connection errors, each will backoff exponentially longer. + WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to + catch up before processing the request and/or response. Defaults to + True. """ NAME: str = abc.abstractproperty() # type: ignore @@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): RETRY_ON_CONNECT_ERROR = True RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1) + WAIT_FOR_STREAMS: ClassVar[bool] = True + def __init__(self, hs: "HomeServer"): if self.CACHE: self.response_cache: ResponseCache[str] = ResponseCache( @@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if hs.config.worker.worker_replication_secret: self._replication_secret = hs.config.worker.worker_replication_secret + self._streams = hs.get_replication_command_handler().get_streams_to_replicate() + self._replication = hs.get_replication_data_handler() + self._instance_name = hs.get_instance_name() + def _check_auth(self, request: Request) -> None: # Get the authorization header. auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") @@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): @abc.abstractmethod async def _handle_request( - self, request: Request, **kwargs: Any + self, request: Request, content: JsonDict, **kwargs: Any ) -> Tuple[int, JsonDict]: """Handle incoming request. @@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): @trace_with_opname("outgoing_replication_request") async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: + # We have to pull these out here to avoid circular dependencies... + streams = hs.get_replication_command_handler().get_streams_to_replicate() + replication = hs.get_replication_data_handler() + with outgoing_gauge.track_inprogress(): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") @@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): data = await cls._serialize_payload(**kwargs) + if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS: + # Include the current stream positions that we write to. We + # don't do this for GETs as they don't have a body, and we + # generally assume that a GET won't rely on data we have + # written. + if _STREAM_POSITION_KEY in data: + raise Exception( + "data to send contains %r key", _STREAM_POSITION_KEY + ) + + data[_STREAM_POSITION_KEY] = { + "streams": { + stream.NAME: stream.current_token(local_instance_name) + for stream in streams + }, + "instance_name": local_instance_name, + } + url_args = [ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS ] @@ -308,6 +343,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): ) from e _outgoing_request_counter.labels(cls.NAME, 200).inc() + + # Wait on any streams that the remote may have written to. + for stream_name, position in result.get( + _STREAM_POSITION_KEY, {} + ).items(): + await replication.wait_for_stream_position( + instance_name=instance_name, + stream_name=stream_name, + position=position, + raise_on_timeout=False, + ) + return result return send_request @@ -353,6 +400,23 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if self._replication_secret: self._check_auth(request) + if self.METHOD == "GET": + # GET APIs always have an empty body. + content = {} + else: + content = parse_json_object_from_request(request) + + # Wait on any streams that the remote may have written to. + for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[ + "streams" + ].items(): + await self._replication.wait_for_stream_position( + instance_name=content[_STREAM_POSITION_KEY]["instance_name"], + stream_name=stream_name, + position=position, + raise_on_timeout=False, + ) + if self.CACHE: txn_id = kwargs.pop("txn_id") @@ -361,13 +425,28 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # correctly yet. In particular, there may be issues to do with logging # context lifetimes. - return await self.response_cache.wrap( - txn_id, self._handle_request, request, **kwargs + code, response = await self.response_cache.wrap( + txn_id, self._handle_request, request, content, **kwargs ) + else: + # The `@cancellable` decorator may be applied to `_handle_request`. But we + # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, + # so we have to set up the cancellable flag ourselves. + request.is_render_cancellable = is_function_cancellable( + self._handle_request + ) + + code, response = await self._handle_request(request, content, **kwargs) + + # Return streams we may have written to in the course of processing this + # request. + if _STREAM_POSITION_KEY in response: + raise Exception("data to send contains %r key", _STREAM_POSITION_KEY) - # The `@cancellable` decorator may be applied to `_handle_request`. But we - # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, - # so we have to set up the cancellable flag ourselves. - request.is_render_cancellable = is_function_cancellable(self._handle_request) + if self.WAIT_FOR_STREAMS: + response[_STREAM_POSITION_KEY] = { + stream.NAME: stream.current_token(self._instance_name) + for stream in self._streams + } - return await self._handle_request(request, **kwargs) + return code, response diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 0edc95977b..2374f810c9 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -61,10 +60,8 @@ class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint): return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, account_data_type: str + self, request: Request, content: JsonDict, user_id: str, account_data_type: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_account_data_for_user( user_id, account_data_type, content["content"] ) @@ -101,7 +98,7 @@ class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, account_data_type: str + self, request: Request, content: JsonDict, user_id: str, account_data_type: str ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_account_data_for_user( user_id, account_data_type @@ -143,10 +140,13 @@ class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint): return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, account_data_type: str + self, + request: Request, + content: JsonDict, + user_id: str, + room_id: str, + account_data_type: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, content["content"] ) @@ -183,7 +183,12 @@ class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, account_data_type: str + self, + request: Request, + content: JsonDict, + user_id: str, + room_id: str, + account_data_type: str, ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_account_data_for_room( user_id, room_id, account_data_type @@ -225,10 +230,8 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, tag: str + self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - max_stream_id = await self.handler.add_tag_to_room( user_id, room_id, tag, content["content"] ) @@ -266,7 +269,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str, room_id: str, tag: str + self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_tag_from_room( user_id, diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index ea5c08e6cf..ecea6fc915 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.logging.opentracing import active_span from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -78,7 +77,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) @@ -138,9 +137,8 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): return {"user_ids": user_ids} async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, Dict[str, Optional[JsonDict]]]: - content = parse_json_object_from_request(request) user_ids: List[str] = content["user_ids"] logger.info("Resync for %r", user_ids) @@ -205,10 +203,8 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - user_id = content["user_id"] device_id = content["device_id"] keys = content["keys"] diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index d3abafed28..53ad327030 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict from synapse.util.metrics import Measure @@ -114,10 +113,8 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override] + async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override] with Measure(self.clock, "repl_fed_send_events_parse"): - content = parse_json_object_from_request(request) - room_id = content["room_id"] backfilled = content["backfilled"] @@ -181,13 +178,10 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): return {"origin": origin, "content": content} async def _handle_request( # type: ignore[override] - self, request: Request, edu_type: str + self, request: Request, content: JsonDict, edu_type: str ) -> Tuple[int, JsonDict]: - with Measure(self.clock, "repl_fed_send_edu_parse"): - content = parse_json_object_from_request(request) - - origin = content["origin"] - edu_content = content["content"] + origin = content["origin"] + edu_content = content["content"] logger.info("Got %r edu from %s", edu_type, origin) @@ -231,13 +225,10 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): return {"args": args} async def _handle_request( # type: ignore[override] - self, request: Request, query_type: str + self, request: Request, content: JsonDict, query_type: str ) -> Tuple[int, JsonDict]: - with Measure(self.clock, "repl_fed_query_parse"): - content = parse_json_object_from_request(request) - - args = content["args"] - args["origin"] = content["origin"] + args = content["args"] + args["origin"] = content["origin"] logger.info("Got %r query from %s", query_type, args["origin"]) @@ -274,7 +265,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> Tuple[int, JsonDict]: await self.store.clean_room_for_join(room_id) @@ -307,9 +298,8 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): return {"room_version": room_version.identifier} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) return 200, {} diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index c68e18da12..6ad6cb1bfe 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -73,10 +72,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 663bff5738..9fa1060d48 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID @@ -79,10 +78,8 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, request: SynapseRequest, room_id: str, user_id: str + self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - remote_room_hosts = content["remote_room_hosts"] event_content = content["content"] @@ -147,11 +144,10 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: SynapseRequest, + content: JsonDict, room_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - remote_room_hosts = content["remote_room_hosts"] event_content = content["content"] @@ -217,10 +213,8 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, request: SynapseRequest, invite_event_id: str + self, request: SynapseRequest, content: JsonDict, invite_event_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - txn_id = content["txn_id"] event_content = content["content"] @@ -285,10 +279,9 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: SynapseRequest, + content: JsonDict, knock_event_id: str, ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - txn_id = content["txn_id"] event_content = content["content"] @@ -347,7 +340,12 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str, user_id: str, change: str + self, + request: Request, + content: JsonDict, + room_id: str, + user_id: str, + change: str, ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index 4a5b08f56f..db16aac9c2 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, UserID @@ -56,7 +55,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: await self._presence_handler.bump_presence_active_time( UserID.from_string(user_id) @@ -107,10 +106,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - await self._presence_handler.set_state( UserID.from_string(user_id), content["state"], diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py index af5c2f66a7..297e8ad564 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -61,10 +60,8 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): return payload async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - app_id = content["app_id"] pushkey = content["pushkey"] diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 976c283360..265e601b96 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -96,10 +95,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - await self.registration_handler.check_registration_ratelimit(content["address"]) # Always default admin users to approved (since it means they were created by @@ -150,10 +147,8 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): return {"auth_result": auth_result, "access_token": access_token} async def _handle_request( # type: ignore[override] - self, request: Request, user_id: str + self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - auth_result = content["auth_result"] access_token = content["access_token"] diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 4215a1c1bc..27ad914075 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID from synapse.util.metrics import Measure @@ -114,11 +113,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): return payload async def _handle_request( # type: ignore[override] - self, request: Request, event_id: str + self, request: Request, content: JsonDict, event_id: str ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_send_event_parse"): - content = parse_json_object_from_request(request) - event_dict = content["event"] room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]] internal_metadata = content["internal_metadata"] diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py index 8889bbb644..4f82c9f96d 100644 --- a/synapse/replication/http/send_events.py +++ b/synapse/replication/http/send_events.py @@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.server import HttpServer -from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict, Requester, UserID from synapse.util.metrics import Measure @@ -114,10 +113,9 @@ class ReplicationSendEventsRestServlet(ReplicationEndpoint): return payload async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, payload: JsonDict ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_send_events_parse"): - payload = parse_json_object_from_request(request) events_and_context = [] events = payload["events"] diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py index 838b7584e5..0c524e7de3 100644 --- a/synapse/replication/http/state.py +++ b/synapse/replication/http/state.py @@ -57,7 +57,7 @@ class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request, room_id: str + self, request: Request, content: JsonDict, room_id: str ) -> Tuple[int, JsonDict]: writer_instance = self._events_shard_config.get_instance(room_id) if writer_instance != self._instance_name: diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index c065225362..3c7b5b18ea 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -54,6 +54,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): PATH_ARGS = ("stream_name",) METHOD = "GET" + # We don't want to wait for replication streams to catch up, as this gets + # called in the process of catching replication streams up. + WAIT_FOR_STREAMS = False + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -67,7 +71,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): return {"from_token": from_token, "upto_token": upto_token} async def _handle_request( # type: ignore[override] - self, request: Request, stream_name: str + self, request: Request, content: JsonDict, stream_name: str ) -> Tuple[int, JsonDict]: stream = self.streams.get(stream_name) if stream is None: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 322d695bc7..5c2482e40c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory @@ -314,10 +315,21 @@ class ReplicationDataHandler: self.send_handler.wake_destination(server) async def wait_for_stream_position( - self, instance_name: str, stream_name: str, position: int + self, + instance_name: str, + stream_name: str, + position: int, + raise_on_timeout: bool = True, ) -> None: """Wait until this instance has received updates up to and including the given stream position. + + Args: + instance_name + stream_name + position + raise_on_timeout: Whether to raise an exception if we time out + waiting for the updates, or if we log an error and return. """ if instance_name == self._instance_name: @@ -345,7 +357,16 @@ class ReplicationDataHandler: # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): logger.info("Waiting for repl stream %r to reach %s", stream_name, position) - await make_deferred_yieldable(deferred) + try: + await make_deferred_yieldable(deferred) + except defer.TimeoutError: + logger.error("Timed out waiting for stream %s", stream_name) + + if raise_on_timeout: + raise + + return + logger.info( "Finished waiting for repl stream %r to reach %s", stream_name, position ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 99f09669f0..9d17eff714 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -199,33 +199,28 @@ class ReplicationStreamer: # The token has advanced but there is no data to # send, so we send a `POSITION` to inform other # workers of the updated position. - if stream.NAME == EventsStream.NAME: - # XXX: We only do this for the EventStream as it - # turns out that e.g. account data streams share - # their "current token" with each other, meaning - # that it is *not* safe to send a POSITION. - - # Note: `last_token` may not *actually* be the - # last token we sent out in a RDATA or POSITION. - # This can happen if we sent out an RDATA for - # position X when our current token was say X+1. - # Other workers will see RDATA for X and then a - # POSITION with last token of X+1, which will - # cause them to check if there were any missing - # updates between X and X+1. - logger.info( - "Sending position: %s -> %s", + + # Note: `last_token` may not *actually* be the + # last token we sent out in a RDATA or POSITION. + # This can happen if we sent out an RDATA for + # position X when our current token was say X+1. + # Other workers will see RDATA for X and then a + # POSITION with last token of X+1, which will + # cause them to check if there were any missing + # updates between X and X+1. + logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( stream.NAME, + self._instance_name, + last_token, current_token, ) - self.command_handler.send_command( - PositionCommand( - stream.NAME, - self._instance_name, - last_token, - current_token, - ) - ) + ) continue # Some streams return multiple rows with the same stream IDs, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0d7108f01b..8670ffbfa3 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -378,6 +378,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self._current_positions.values(), default=1 ) + if not writers: + # If there have been no explicit writers given then any instance can + # write to the stream. In which case, let's pre-seed our own + # position with the current minimum. + self._current_positions[self._instance_name] = self._persisted_upto_position + def _load_current_ids( self, db_conn: LoggingDatabaseConnection, @@ -695,24 +701,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): heapq.heappush(self._known_persisted_positions, new_id) - # If we're a writer and we don't have any active writes we update our - # current position to the latest position seen. This allows the instance - # to report a recent position when asked, rather than a potentially old - # one (if this instance hasn't written anything for a while). - our_current_position = self._current_positions.get(self._instance_name) - if ( - our_current_position - and not self._unfinished_ids - and not self._in_flight_fetches - ): - self._current_positions[self._instance_name] = max( - our_current_position, new_id - ) - # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). - min_curr = min(self._current_positions.values(), default=0) + our_current_position = self._current_positions.get(self._instance_name, 0) + min_curr = min( + ( + token + for name, token in self._current_positions.items() + if name != self._instance_name + ), + default=our_current_position, + ) + + if our_current_position and (self._unfinished_ids or self._in_flight_fetches): + min_curr = min(min_curr, our_current_position) + self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 0c725eb967..c59eca2430 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -604,6 +604,12 @@ class RoomStreamToken: elif self.instance_map: entries = [] for name, pos in self.instance_map.items(): + if pos <= self.stream: + # Ignore instances who are below the minimum stream position + # (we might know they've advanced without seeing a recent + # write from them). + continue + instance_id = await store.get_id_for_instance(name) entries.append(f"{instance_id}.{pos}") diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py index 936ab4504a..e03d9b4cc0 100644 --- a/tests/replication/http/test__base.py +++ b/tests/replication/http/test__base.py @@ -44,7 +44,7 @@ class CancellableReplicationEndpoint(ReplicationEndpoint): @cancellable async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: await self.clock.sleep(1.0) return HTTPStatus.OK, {"result": True} @@ -54,6 +54,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint): NAME = "uncancellable_sleep" PATH_ARGS = () CACHE = False + WAIT_FOR_STREAMS = False def __init__(self, hs: HomeServer): super().__init__(hs) @@ -64,7 +65,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint): return {} async def _handle_request( # type: ignore[override] - self, request: Request + self, request: Request, content: JsonDict ) -> Tuple[int, JsonDict]: await self.clock.sleep(1.0) return HTTPStatus.OK, {"result": True} @@ -85,7 +86,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase): def test_cancellable_disconnect(self) -> None: """Test that handlers with the `@cancellable` flag can be cancelled.""" path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" - channel = self.make_request("POST", path, await_result=False) + channel = self.make_request("POST", path, await_result=False, content={}) test_disconnect( self.reactor, channel, @@ -96,7 +97,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase): def test_uncancellable_disconnect(self) -> None: """Test that handlers without the `@cancellable` flag cannot be cancelled.""" path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" - channel = self.make_request("POST", path, await_result=False) + channel = self.make_request("POST", path, await_result=False, content={}) test_disconnect( self.reactor, channel, diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index d6a2b8d274..ff9691c518 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -349,8 +349,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # The first ID gen will notice that it can advance its token to 7 as it # has no in progress writes... - self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) # ... but the second ID gen doesn't know that. @@ -366,8 +366,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual( - first_id_gen.get_positions(), {"first": 7, "second": 7} + first_id_gen.get_positions(), {"first": 3, "second": 7} ) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.get_success(_get_next_async()) @@ -473,7 +474,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator("first", writers=["first", "second"]) - self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5}) + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 5) @@ -720,7 +721,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(_get_next_async2()) - self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2}) + self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) @@ -816,15 +817,12 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): first_id_gen = self._create_id_generator("first", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - # The first ID gen will notice that it can advance its token to 7 as it - # has no in progress writes... - self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) - # ... but the second ID gen doesn't know that. self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) + self.assertEqual(second_id_gen.get_persisted_upto_position(), 7) -- cgit 1.5.1 From a7b54ca8d84e9371244d792c30fc9084579470e1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 19 Jan 2023 12:47:10 +0000 Subject: Implement MSC3930: polls push rules (#14787) --- changelog.d/14787.feature | 1 + .../complement/conf/workers-shared-extra.yaml.j2 | 6 +- rust/benches/evaluator.rs | 9 ++- rust/src/push/base_rules.rs | 78 +++++++++++++++++++++- rust/src/push/evaluator.rs | 2 +- rust/src/push/mod.rs | 16 +++-- scripts-dev/complement.sh | 2 +- stubs/synapse/synapse_rust/push.pyi | 3 +- synapse/config/experimental.py | 7 ++ synapse/storage/databases/main/push_rule.py | 3 +- 10 files changed, 114 insertions(+), 13 deletions(-) create mode 100644 changelog.d/14787.feature (limited to 'synapse') diff --git a/changelog.d/14787.feature b/changelog.d/14787.feature new file mode 100644 index 0000000000..6a34035047 --- /dev/null +++ b/changelog.d/14787.feature @@ -0,0 +1 @@ +Implement experimental support for MSC3930: Push rules for (MSC3381) Polls. \ No newline at end of file diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 7e9ec23808..281157846a 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -98,12 +98,14 @@ experimental_features: # client-side support for partial state in /send_join responses faster_joins: true {% endif %} - # Filtering /messages by relation type. - msc3874_enabled: true + # Enable support for polls + msc3381_polls_enabled: true # Enable deleting device-specific notification settings stored in account data msc3890_enabled: true # Enable removing account data support msc3391_enabled: true + # Filtering /messages by relation type. + msc3874_enabled: true server_notices: system_mxid_localpart: _server diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 442a79348f..8c28bb0af3 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -150,8 +150,13 @@ fn bench_eval_message(b: &mut Bencher) { ) .unwrap(); - let rules = - FilteredPushRules::py_new(PushRules::new(Vec::new()), Default::default(), false, false); + let rules = FilteredPushRules::py_new( + PushRules::new(Vec::new()), + Default::default(), + false, + false, + false, + ); b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); } diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 35129691ca..9140a69bb6 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -208,6 +208,20 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/override/.org.matrix.msc3930.rule.poll_response"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.response")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[]), + default: true, + default_enabled: true, + }, ]; pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule { @@ -596,6 +610,68 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3930.rule.poll_start_one_to_one"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.start")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3930.rule.poll_start"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.start")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3930.rule.poll_end_one_to_one"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.end")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3930.rule.poll_end"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.end")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify]), + default: true, + default_enabled: true, + }, ]; lazy_static! { diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index c901c0fbcc..0242ee1c5f 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -483,7 +483,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, true), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 2e9d3e38a1..842b13c88b 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -411,8 +411,9 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, - msc3664_enabled: bool, msc1767_enabled: bool, + msc3381_polls_enabled: bool, + msc3664_enabled: bool, } #[pymethods] @@ -421,14 +422,16 @@ impl FilteredPushRules { pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, - msc3664_enabled: bool, msc1767_enabled: bool, + msc3381_polls_enabled: bool, + msc3664_enabled: bool, ) -> Self { Self { push_rules, enabled_map, - msc3664_enabled, msc1767_enabled, + msc3381_polls_enabled, + msc3664_enabled, } } @@ -447,13 +450,18 @@ impl FilteredPushRules { .iter() .filter(|rule| { // Ignore disabled experimental push rules + + if !self.msc1767_enabled && rule.rule_id.contains("org.matrix.msc1767") { + return false; + } + if !self.msc3664_enabled && rule.rule_id == "global/override/.im.nheko.msc3664.reply" { return false; } - if !self.msc1767_enabled && rule.rule_id.contains("org.matrix.msc1767") { + if !self.msc3381_polls_enabled && rule.rule_id.contains("org.matrix.msc3930") { return false; } diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 7c48d8bccb..a183653d52 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -190,7 +190,7 @@ fi extra_test_args=() -test_tags="synapse_blacklist,msc3787,msc3874,msc3890,msc3391" +test_tags="synapse_blacklist,msc3787,msc3874,msc3890,msc3391,msc3930" # All environment variables starting with PASS_ will be shared. # (The prefix is stripped off before reaching the container.) diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 373b40740b..304ed7111c 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -43,8 +43,9 @@ class FilteredPushRules: self, push_rules: PushRules, enabled_map: Dict[str, bool], - msc3664_enabled: bool, msc1767_enabled: bool, + msc3381_polls_enabled: bool, + msc3664_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 0444ef8244..89586db763 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -146,6 +146,13 @@ class ExperimentalConfig(Config): "required to communicate account data deletions to clients." ) + # MSC3381: Polls. + # In practice, supporting polls in Synapse only requires an implementation of + # MSC3930: Push rules for MSC3391 polls; which is what this option enables. + self.msc3381_polls_enabled: bool = experimental.get( + "msc3381_polls_enabled", False + ) + # MSC3912: Relation-based redactions. self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d4e4b777da..03182887d1 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -86,8 +86,9 @@ def _load_rules( filtered_rules = FilteredPushRules( push_rules, enabled_map, - msc3664_enabled=experimental_config.msc3664_enabled, msc1767_enabled=experimental_config.msc1767_enabled, + msc3664_enabled=experimental_config.msc3664_enabled, + msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, ) return filtered_rules -- cgit 1.5.1 From cdf2707678dc9f08e965eb0f0c1f39e71552fe3e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 19 Jan 2023 22:19:56 +0000 Subject: Fix bug in wait for stream position (#14872) This caused some requests to fail. This caused some requests to fail. This really only started causing issues due to #14856 --- changelog.d/14872.misc | 1 + synapse/replication/tcp/client.py | 29 +++++++++++++++++++---------- 2 files changed, 20 insertions(+), 10 deletions(-) create mode 100644 changelog.d/14872.misc (limited to 'synapse') diff --git a/changelog.d/14872.misc b/changelog.d/14872.misc new file mode 100644 index 0000000000..3731d6cbf1 --- /dev/null +++ b/changelog.d/14872.misc @@ -0,0 +1 @@ +Fix `wait_for_stream_position` to correctly wait for the right instance to advance its token. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 5c2482e40c..6e242c5749 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -133,9 +133,9 @@ class ReplicationDataHandler: if hs.should_send_federation(): self.send_handler = FederationSenderHandler(hs) - # Map from stream to list of deferreds waiting for the stream to + # Map from stream and instance to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. - self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {} + self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {} async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -270,7 +270,7 @@ class ReplicationDataHandler: # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is # greater than the received row position. - waiting_list = self._streams_to_waiters.get(stream_name, []) + waiting_list = self._streams_to_waiters.get((stream_name, instance_name), []) # Index of first item with a position after the current token, i.e we # have called all deferreds before this index. If not overwritten by @@ -279,14 +279,13 @@ class ReplicationDataHandler: # `len(list)` works for both cases. index_of_first_deferred_not_called = len(waiting_list) + # We don't fire the deferreds until after we finish iterating over the + # list, to avoid the list changing when we fire the deferreds. + deferreds_to_callback = [] + for idx, (position, deferred) in enumerate(waiting_list): if position <= token: - try: - with PreserveLoggingContext(): - deferred.callback(None) - except Exception: - # The deferred has been cancelled or timed out. - pass + deferreds_to_callback.append(deferred) else: # The list is sorted by position so we don't need to continue # checking any further entries in the list. @@ -297,6 +296,14 @@ class ReplicationDataHandler: # loop. (This maintains the order so no need to resort) waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + for deferred in deferreds_to_callback: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + async def on_position( self, stream_name: str, instance_name: str, token: int ) -> None: @@ -349,7 +356,9 @@ class ReplicationDataHandler: deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor ) - waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + waiting_list = self._streams_to_waiters.setdefault( + (stream_name, instance_name), [] + ) waiting_list.append((position, deferred)) waiting_list.sort(key=lambda t: t[0]) -- cgit 1.5.1 From cdea7c11d082e73606bea5d0462f7971e90d836c Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 20 Jan 2023 12:06:19 +0000 Subject: Faster joins: Avoid starting duplicate partial state syncs (#14844) Currently, we will try to start a new partial state sync every time we perform a remote join, which is undesirable if there is already one running for a given room. We intend to perform remote joins whenever additional local users wish to join a partial state room, so let's ensure that we do not start more than one concurrent partial state sync for any given room. ------------------------------------------------------------------------ There is a race condition where the homeserver leaves a room and later rejoins while the partial state sync from the previous membership is still running. There is no guarantee that the previous partial state sync will process the latest join, so we restart it if needed. Signed-off-by: Sean Quah --- changelog.d/14844.misc | 1 + synapse/handlers/federation.py | 106 +++++++++++++++++++++++++++++++++--- tests/handlers/test_federation.py | 112 +++++++++++++++++++++++++++++++++++++- 3 files changed, 210 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14844.misc (limited to 'synapse') diff --git a/changelog.d/14844.misc b/changelog.d/14844.misc new file mode 100644 index 0000000000..30ce866304 --- /dev/null +++ b/changelog.d/14844.misc @@ -0,0 +1 @@ +Add check to avoid starting duplicate partial state syncs. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eca75f1108..e386f77de6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -27,6 +27,7 @@ from typing import ( Iterable, List, Optional, + Set, Tuple, Union, ) @@ -171,12 +172,23 @@ class FederationHandler: self.third_party_event_rules = hs.get_third_party_event_rules() + # Tracks running partial state syncs by room ID. + # Partial state syncs currently only run on the main process, so it's okay to + # track them in-memory for now. + self._active_partial_state_syncs: Set[str] = set() + # Tracks partial state syncs we may want to restart. + # A dictionary mapping room IDs to (initial destination, other destinations) + # tuples. + self._partial_state_syncs_maybe_needing_restart: Dict[ + str, Tuple[Optional[str], Collection[str]] + ] = {} + # if this is the main process, fire off a background process to resume # any partial-state-resync operations which were in flight when we # were shut down. if not hs.config.worker.worker_app: run_as_background_process( - "resume_sync_partial_state_room", self._resume_sync_partial_state_room + "resume_sync_partial_state_room", self._resume_partial_state_room_sync ) @trace @@ -679,9 +691,7 @@ class FederationHandler: if ret.partial_state: # Kick off the process of asynchronously fetching the state for this # room. - run_as_background_process( - desc="sync_partial_state_room", - func=self._sync_partial_state_room, + self._start_partial_state_room_sync( initial_destination=origin, other_destinations=ret.servers_in_room, room_id=room_id, @@ -1660,20 +1670,100 @@ class FederationHandler: # well. return None - async def _resume_sync_partial_state_room(self) -> None: + async def _resume_partial_state_room_sync(self) -> None: """Resumes resyncing of all partial-state rooms after a restart.""" assert not self.config.worker.worker_app partial_state_rooms = await self.store.get_partial_state_room_resync_info() for room_id, resync_info in partial_state_rooms.items(): - run_as_background_process( - desc="sync_partial_state_room", - func=self._sync_partial_state_room, + self._start_partial_state_room_sync( initial_destination=resync_info.joined_via, other_destinations=resync_info.servers_in_room, room_id=room_id, ) + def _start_partial_state_room_sync( + self, + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + """Starts the background process to resync the state of a partial state room, + if it is not already running. + + Args: + initial_destination: the initial homeserver to pull the state from + other_destinations: other homeservers to try to pull the state from, if + `initial_destination` is unavailable + room_id: room to be resynced + """ + + async def _sync_partial_state_room_wrapper() -> None: + if room_id in self._active_partial_state_syncs: + # Another local user has joined the room while there is already a + # partial state sync running. This implies that there is a new join + # event to un-partial state. We might find ourselves in one of a few + # scenarios: + # 1. There is an existing partial state sync. The partial state sync + # un-partial states the new join event before completing and all is + # well. + # 2. Before the latest join, the homeserver was no longer in the room + # and there is an existing partial state sync from our previous + # membership of the room. The partial state sync may have: + # a) succeeded, but not yet terminated. The room will not be + # un-partial stated again unless we restart the partial state + # sync. + # b) failed, because we were no longer in the room and remote + # homeservers were refusing our requests, but not yet + # terminated. After the latest join, remote homeservers may + # start answering our requests again, so we should restart the + # partial state sync. + # In the cases where we would want to restart the partial state sync, + # the room would have the partial state flag when the partial state sync + # terminates. + self._partial_state_syncs_maybe_needing_restart[room_id] = ( + initial_destination, + other_destinations, + ) + return + + self._active_partial_state_syncs.add(room_id) + + try: + await self._sync_partial_state_room( + initial_destination=initial_destination, + other_destinations=other_destinations, + room_id=room_id, + ) + finally: + # Read the room's partial state flag while we still hold the claim to + # being the active partial state sync (so that another partial state + # sync can't come along and mess with it under us). + # Normally, the partial state flag will be gone. If it isn't, then we + # may find ourselves in scenario 2a or 2b as described in the comment + # above, where we want to restart the partial state sync. + is_still_partial_state_room = await self.store.is_partial_state_room( + room_id + ) + self._active_partial_state_syncs.remove(room_id) + + if room_id in self._partial_state_syncs_maybe_needing_restart: + ( + restart_initial_destination, + restart_other_destinations, + ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id) + + if is_still_partial_state_room: + self._start_partial_state_room_sync( + initial_destination=restart_initial_destination, + other_destinations=restart_other_destinations, + room_id=room_id, + ) + + run_as_background_process( + desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper + ) + async def _sync_partial_state_room( self, initial_destination: Optional[str], diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index cedbb9fafc..c1558c40c3 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import cast +from typing import Collection, Optional, cast from unittest import TestCase from unittest.mock import Mock, patch +from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes @@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): f"Stale partial-stated room flag left over for {room_id} after a" f" failed do_invite_join!", ) + + def test_duplicate_partial_state_room_syncs(self) -> None: + """ + Tests that concurrent partial state syncs are not started for the same room. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Try to start another partial state sync. + # Nothing should happen. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # End the partial state sync + is_partial_state = False + end_sync.callback(None) + + # The partial state sync should not be restarted. + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # The next attempt to start the partial state sync should work. + is_partial_state = True + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + def test_partial_state_room_sync_restart(self) -> None: + """ + Tests that partial state syncs are restarted when a second partial state sync + was deduplicated and the first partial state sync fails. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Fail the partial state sync. + # The partial state sync should not be restarted. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Start the partial state sync again. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Deduplicate another partial state sync. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Fail the partial state sync. + # It should restart with the latest parameters. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 3) + mock_sync_partial_state_room.assert_called_with( + initial_destination="hs3", + other_destinations=["hs2"], + room_id="room_id", + ) -- cgit 1.5.1 From 65d03866936adb144631d263a8539a2cb060fd43 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2023 18:02:18 +0000 Subject: Always notify replication when a stream advances (#14877) This ensures that all other workers are told about stream updates in a timely manner, without having to remember to manually poke replication. --- changelog.d/14877.misc | 1 + synapse/_scripts/synapse_port_db.py | 4 +++ synapse/notifier.py | 31 +++++++++++++++++++---- synapse/server.py | 6 ++++- synapse/storage/databases/main/account_data.py | 2 ++ synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/deviceinbox.py | 3 ++- synapse/storage/databases/main/devices.py | 1 + synapse/storage/databases/main/end_to_end_keys.py | 5 +++- synapse/storage/databases/main/events_worker.py | 10 +++++++- synapse/storage/databases/main/presence.py | 3 ++- synapse/storage/databases/main/push_rule.py | 1 + synapse/storage/databases/main/pusher.py | 1 + synapse/storage/databases/main/receipts.py | 2 ++ synapse/storage/databases/main/room.py | 6 ++++- synapse/storage/util/id_generators.py | 26 +++++++++++++++++-- tests/module_api/test_api.py | 3 +++ tests/replication/tcp/test_handler.py | 23 +++++------------ tests/storage/test_id_generators.py | 4 +++ 19 files changed, 104 insertions(+), 29 deletions(-) create mode 100644 changelog.d/14877.misc (limited to 'synapse') diff --git a/changelog.d/14877.misc b/changelog.d/14877.misc new file mode 100644 index 0000000000..4e9c3fa33f --- /dev/null +++ b/changelog.d/14877.misc @@ -0,0 +1 @@ +Always notify replication when a stream advances automatically. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index c463b60b26..5e137dbbf7 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -51,6 +51,7 @@ from synapse.logging.context import ( make_deferred_yieldable, run_in_background, ) +from synapse.notifier import ReplicationNotifier from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn from synapse.storage.databases.main import PushRuleStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore @@ -260,6 +261,9 @@ class MockHomeserver: def should_send_federation(self) -> bool: return False + def get_replication_notifier(self) -> ReplicationNotifier: + return ReplicationNotifier() + class Porter: def __init__( diff --git a/synapse/notifier.py b/synapse/notifier.py index 26b97cf766..28f0d4a25a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -226,8 +226,7 @@ class Notifier: self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] - # Called when there are new things to stream over replication - self.replication_callbacks: List[Callable[[], None]] = [] + self._replication_notifier = hs.get_replication_notifier() self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = [] self._federation_client = hs.get_federation_http_client() @@ -279,7 +278,7 @@ class Notifier: it needs to do any asynchronous work, a background thread should be started and wrapped with run_as_background_process. """ - self.replication_callbacks.append(cb) + self._replication_notifier.add_replication_callback(cb) def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None: """Add a callback that will be called when a user joins a room. @@ -741,8 +740,7 @@ class Notifier: def notify_replication(self) -> None: """Notify the any replication listeners that there's a new event""" - for cb in self.replication_callbacks: - cb() + self._replication_notifier.notify_replication() def notify_user_joined_room(self, event_id: str, room_id: str) -> None: for cb in self._new_join_in_room_callbacks: @@ -759,3 +757,26 @@ class Notifier: # Tell the federation client about the fact the server is back up, so # that any in flight requests can be immediately retried. self._federation_client.wake_destination(server) + + +@attr.s(auto_attribs=True) +class ReplicationNotifier: + """Tracks callbacks for things that need to know about stream changes. + + This is separate from the notifier to avoid circular dependencies. + """ + + _replication_callbacks: List[Callable[[], None]] = attr.Factory(list) + + def add_replication_callback(self, cb: Callable[[], None]) -> None: + """Add a callback that will be called when some new data is available. + Callback is not given any arguments. It should *not* return a Deferred - if + it needs to do any asynchronous work, a background thread should be started and + wrapped with run_as_background_process. + """ + self._replication_callbacks.append(cb) + + def notify_replication(self) -> None: + """Notify the any replication listeners that there's a new event""" + for cb in self._replication_callbacks: + cb() diff --git a/synapse/server.py b/synapse/server.py index f4ab94c4f3..9d6d268f49 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -107,7 +107,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi -from synapse.notifier import Notifier +from synapse.notifier import Notifier, ReplicationNotifier from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler @@ -389,6 +389,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_notifier(self) -> Notifier: return Notifier(self) + @cache_in_self + def get_replication_notifier(self) -> ReplicationNotifier: + return ReplicationNotifier() + @cache_in_self def get_auth(self) -> Auth: return Auth(self) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 881d7089db..8a359d7eb8 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="account_data", instance_name=self._instance_name, tables=[ @@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # SQLite). self._account_data_id_gen = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "room_account_data", "stream_id", extra_tables=[("room_tags_revisions", "stream_id")], diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2179a8bf59..5b66431691 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, + notifier=hs.get_replication_notifier(), stream_name="caches", instance_name=hs.get_instance_name(), tables=[ diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 713be91c5d..8e61aba454 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -91,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="to_device", instance_name=self._instance_name, tables=[("device_inbox", "instance_name", "stream_id")], @@ -101,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): else: self._can_write_to_device = True self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_inbox", "stream_id" + db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id" ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index cd186c8472..903606fb46 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -92,6 +92,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): # class below that is used on the main process. self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "device_lists_stream", "stream_id", extra_tables=[ diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4c691642e2..c4ac6c33ba 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): super().__init__(database, db_conn, hs) self._cross_signing_id_gen = StreamIdGenerator( - db_conn, "e2e_cross_signing_keys", "stream_id" + db_conn, + hs.get_replication_notifier(), + "e2e_cross_signing_keys", + "stream_id", ) async def set_e2e_device_keys( diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d150fa8a94..d8a8bcafb6 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -191,6 +191,7 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="events", instance_name=hs.get_instance_name(), tables=[("events", "instance_name", "stream_ordering")], @@ -200,6 +201,7 @@ class EventsWorkerStore(SQLBaseStore): self._backfill_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="backfill", instance_name=hs.get_instance_name(), tables=[("events", "instance_name", "stream_ordering")], @@ -217,12 +219,14 @@ class EventsWorkerStore(SQLBaseStore): # SQLite). self._stream_id_gen = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "events", "stream_ordering", is_writer=hs.get_instance_name() in hs.config.worker.writers.events, ) self._backfill_id_gen = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "events", "stream_ordering", step=-1, @@ -300,6 +304,7 @@ class EventsWorkerStore(SQLBaseStore): self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="un_partial_stated_event_stream", instance_name=hs.get_instance_name(), tables=[ @@ -311,7 +316,10 @@ class EventsWorkerStore(SQLBaseStore): ) else: self._un_partial_stated_events_stream_id_gen = StreamIdGenerator( - db_conn, "un_partial_stated_event_stream", "stream_id" + db_conn, + hs.get_replication_notifier(), + "un_partial_stated_event_stream", + "stream_id", ) def get_un_partial_stated_events_token(self) -> int: diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 7b60815043..beb210f8ee 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) self._presence_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="presence_stream", instance_name=self._instance_name, tables=[("presence_stream", "instance_name", "stream_id")], @@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) ) else: self._presence_id_gen = StreamIdGenerator( - db_conn, "presence_stream", "stream_id" + db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id" ) self.hs = hs diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 03182887d1..14ca167b34 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -118,6 +118,7 @@ class PushRulesWorkerStore( # class below that is used on the main process. self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "push_rules_stream", "stream_id", is_writer=hs.config.worker.worker_app is None, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 7f24a3b6ec..df53e726e6 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore): # class below that is used on the main process. self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 86f5bce5f0..3468f354e6 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore): self._receipts_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="receipts", instance_name=self._instance_name, tables=[("receipts_linearized", "instance_name", "stream_id")], @@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # SQLite). self._receipts_id_gen = StreamIdGenerator( db_conn, + hs.get_replication_notifier(), "receipts_linearized", "stream_id", is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 78906a5e1d..7264a33cd4 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -126,6 +126,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + notifier=hs.get_replication_notifier(), stream_name="un_partial_stated_room_stream", instance_name=self._instance_name, tables=[ @@ -137,7 +138,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) else: self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator( - db_conn, "un_partial_stated_room_stream", "stream_id" + db_conn, + hs.get_replication_notifier(), + "un_partial_stated_room_stream", + "stream_id", ) async def store_room( diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 8670ffbfa3..9adff3f4f5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -20,6 +20,7 @@ from collections import OrderedDict from contextlib import contextmanager from types import TracebackType from typing import ( + TYPE_CHECKING, AsyncContextManager, ContextManager, Dict, @@ -49,6 +50,9 @@ from synapse.storage.database import ( from synapse.storage.types import Cursor from synapse.storage.util.sequence import PostgresSequenceGenerator +if TYPE_CHECKING: + from synapse.notifier import ReplicationNotifier + logger = logging.getLogger(__name__) @@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator): def __init__( self, db_conn: LoggingDatabaseConnection, + notifier: "ReplicationNotifier", table: str, column: str, extra_tables: Iterable[Tuple[str, str]] = (), @@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator): # The key and values are the same, but we never look at the values. self._unfinished_ids: OrderedDict[int, int] = OrderedDict() + self._notifier = notifier + def advance(self, instance_name: str, new_id: int) -> None: # Advance should never be called on a writer instance, only over replication if self._is_writer: @@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator): with self._lock: self._unfinished_ids.pop(next_id) + self._notifier.notify_replication() + return _AsyncCtxManagerWrapper(manager()) def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: @@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator): for next_id in next_ids: self._unfinished_ids.pop(next_id) + self._notifier.notify_replication() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: @@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self, db_conn: LoggingDatabaseConnection, db: DatabasePool, + notifier: "ReplicationNotifier", stream_name: str, instance_name: str, tables: List[Tuple[str, str, str]], @@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): positive: bool = True, ) -> None: self._db = db + self._notifier = notifier self._stream_name = stream_name self._instance_name = instance_name self._positive = positive @@ -535,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids, # controls the return type. If `None` or omitted, the context manager yields # a single integer stream_id; otherwise it yields a list of stream_ids. - return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) + return cast( + AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier) + ) def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: # If we have a list of instances that are allowed to write to this @@ -544,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): raise Exception("Tried to allocate stream ID on non-writer") # Cast safety: see get_next. - return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n)) + return cast( + AsyncContextManager[List[int]], + _MultiWriterCtxManager(self, self._notifier, n), + ) def get_next_txn(self, txn: LoggingTransaction) -> int: """ @@ -563,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + txn.call_after(self._notifier.notify_replication) # Update the `stream_positions` table with newly updated stream # ID (unless self._writers is not set in which case we don't @@ -787,6 +806,7 @@ class _MultiWriterCtxManager: """Async context manager returned by MultiWriterIdGenerator""" id_gen: MultiWriterIdGenerator + notifier: "ReplicationNotifier" multiple_ids: Optional[int] = None stream_ids: List[int] = attr.Factory(list) @@ -814,6 +834,8 @@ class _MultiWriterCtxManager: for i in self.stream_ids: self.id_gen._mark_id_as_finished(i) + self.notifier.notify_replication() + if exc_type is not None: return False diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9919938e80..8f88c0117d 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -404,6 +404,9 @@ class ModuleApiTestCase(HomeserverTestCase): self.module_api.send_local_online_presence_to([remote_user_id]) ) + # We don't always send out federation immediately, so we advance the clock. + self.reactor.advance(1000) + # Check that a presence update was sent as part of a federation transaction found_update = False calls = ( diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index 555922409d..6e4055cc21 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -14,7 +14,7 @@ from twisted.internet import defer -from synapse.replication.tcp.commands import PositionCommand, RdataCommand +from synapse.replication.tcp.commands import PositionCommand from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -111,20 +111,14 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): next_token = self.get_success(ctx.__aenter__()) self.get_success(ctx.__aexit__(None, None, None)) - cmd_handler.send_command( - RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) - ) - self.replicate() - self.get_success( data_handler.wait_for_stream_position("worker1", "caches", next_token) ) - # `wait_for_stream_position` should only return once master receives an - # RDATA from the worker - ctx = cache_id_gen.get_next() - next_token = self.get_success(ctx.__aenter__()) - self.get_success(ctx.__aexit__(None, None, None)) + # `wait_for_stream_position` should only return once master receives a + # notification that `next_token` has persisted. + ctx_worker1 = cache_id_gen.get_next() + next_token = self.get_success(ctx_worker1.__aenter__()) d = defer.ensureDeferred( data_handler.wait_for_stream_position("worker1", "caches", next_token) @@ -142,10 +136,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): ) self.assertFalse(d.called) - # ... but receiving the RDATA should - cmd_handler.send_command( - RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) - ) - self.replicate() + # ... but worker1 finishing (and so sending an update) should. + self.get_success(ctx_worker1.__aexit__(None, None, None)) self.assertTrue(d.called) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index ff9691c518..9174fb0964 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -52,6 +52,7 @@ class StreamIdGeneratorTestCase(HomeserverTestCase): def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: return StreamIdGenerator( db_conn=conn, + notifier=self.hs.get_replication_notifier(), table="foobar", column="stream_id", ) @@ -196,6 +197,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], @@ -630,6 +632,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], @@ -766,6 +769,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[ -- cgit 1.5.1 From 0ec12a37538d0df07d96cfc9cf5f5208f7453607 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2023 21:04:33 +0000 Subject: Reduce max time we wait for stream positions (#14881) Now that we wait for stream positions whenever we do a HTTP replication hit, we need to be less brutal in the case where we do timeout (as we have bugs around this). --- changelog.d/14881.misc | 1 + synapse/replication/http/_base.py | 2 -- synapse/replication/tcp/client.py | 21 +++++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) create mode 100644 changelog.d/14881.misc (limited to 'synapse') diff --git a/changelog.d/14881.misc b/changelog.d/14881.misc new file mode 100644 index 0000000000..be89d092b6 --- /dev/null +++ b/changelog.d/14881.misc @@ -0,0 +1 @@ +Reduce max time we wait for stream positions. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 709327b97f..908f3f1db7 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -352,7 +352,6 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): instance_name=instance_name, stream_name=stream_name, position=position, - raise_on_timeout=False, ) return result @@ -414,7 +413,6 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): instance_name=content[_STREAM_POSITION_KEY]["instance_name"], stream_name=stream_name, position=position, - raise_on_timeout=False, ) if self.CACHE: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 6e242c5749..493f616679 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -59,7 +59,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) # How long we allow callers to wait for replication updates before timing out. -_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 5 class DirectTcpReplicationClientFactory(ReconnectingClientFactory): @@ -326,7 +326,6 @@ class ReplicationDataHandler: instance_name: str, stream_name: str, position: int, - raise_on_timeout: bool = True, ) -> None: """Wait until this instance has received updates up to and including the given stream position. @@ -335,8 +334,6 @@ class ReplicationDataHandler: instance_name stream_name position - raise_on_timeout: Whether to raise an exception if we time out - waiting for the updates, or if we log an error and return. """ if instance_name == self._instance_name: @@ -365,19 +362,23 @@ class ReplicationDataHandler: # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): - logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + logger.info( + "Waiting for repl stream %r to reach %s (%s)", + stream_name, + position, + instance_name, + ) try: await make_deferred_yieldable(deferred) except defer.TimeoutError: logger.error("Timed out waiting for stream %s", stream_name) - - if raise_on_timeout: - raise - return logger.info( - "Finished waiting for repl stream %r to reach %s", stream_name, position + "Finished waiting for repl stream %r to reach %s (%s)", + stream_name, + position, + instance_name, ) def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: -- cgit 1.5.1 From d329a566df6ff2b635a375bf1b2c8ed3b2c9815d Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Sun, 22 Jan 2023 19:19:31 +0000 Subject: Faster joins: Fix incompatibility with restricted joins (#14882) * Avoid clearing out forward extremities when doing a second remote join When joining a restricted room where the local homeserver does not have a user able to issue invites, we perform a second remote join. We want to avoid clearing out forward extremities in this case because the forward extremities we have are up to date and clearing out forward extremities creates a window in which the room can get bricked if Synapse crashes. Signed-off-by: Sean Quah * Do a full join when doing a second remote join into a full state room We cannot persist a partial state join event into a joined full state room, so we perform a full state join for such rooms instead. As a future optimization, we could always perform a partial state join and compute or retrieve the full state ourselves if necessary. Signed-off-by: Sean Quah * Add lock around partial state flag for rooms Signed-off-by: Sean Quah * Preserve partial state info when doing a second partial state join Signed-off-by: Sean Quah * Add newsfile * Add a TODO(faster_joins) marker Signed-off-by: Sean Quah --- changelog.d/14882.bugfix | 1 + synapse/federation/federation_client.py | 5 + synapse/handlers/federation.py | 215 ++++++++++++++++++++------------ 3 files changed, 140 insertions(+), 81 deletions(-) create mode 100644 changelog.d/14882.bugfix (limited to 'synapse') diff --git a/changelog.d/14882.bugfix b/changelog.d/14882.bugfix new file mode 100644 index 0000000000..1fda344361 --- /dev/null +++ b/changelog.d/14882.bugfix @@ -0,0 +1 @@ +Faster joins: Fix incompatibility with joins into restricted rooms where no local users have the ability to invite. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 15a9a88302..f185b6c1f9 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1157,6 +1157,11 @@ class FederationClient(FederationBase): "members_omitted was set, but no servers were listed in the room" ) + if response.members_omitted and not partial_state: + raise InvalidResponseError( + "members_omitted was set, but we asked for full state" + ) + return SendJoinResult( event=event, state=signed_state, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e386f77de6..2123ace8a6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -48,7 +48,6 @@ from synapse.api.errors import ( FederationError, FederationPullAttemptBackoffError, HttpResponseException, - LimitExceededError, NotFoundError, RequestSendFailed, SynapseError, @@ -182,6 +181,12 @@ class FederationHandler: self._partial_state_syncs_maybe_needing_restart: Dict[ str, Tuple[Optional[str], Collection[str]] ] = {} + # A lock guarding the partial state flag for rooms. + # When the lock is held for a given room, no other concurrent code may + # partial state or un-partial state the room. + self._is_partial_state_room_linearizer = Linearizer( + name="_is_partial_state_room_linearizer" + ) # if this is the main process, fire off a background process to resume # any partial-state-resync operations which were in flight when we @@ -599,7 +604,23 @@ class FederationHandler: self._federation_event_handler.room_queues[room_id] = [] - await self._clean_room_for_join(room_id) + is_host_joined = await self.store.is_host_joined(room_id, self.server_name) + + if not is_host_joined: + # We may have old forward extremities lying around if the homeserver left + # the room completely in the past. Clear them out. + # + # Note that this check-then-clear is subject to races where + # * the homeserver is in the room and stops being in the room just after + # the check. We won't reset the forward extremities, but that's okay, + # since they will be almost up to date. + # * the homeserver is not in the room and starts being in the room just + # after the check. This can't happen, since `RoomMemberHandler` has a + # linearizer lock which prevents concurrent remote joins into the same + # room. + # In short, the races either have an acceptable outcome or should be + # impossible. + await self._clean_room_for_join(room_id) try: # Try the host we successfully got a response to /make_join/ @@ -611,91 +632,115 @@ class FederationHandler: except ValueError: pass - ret = await self.federation_client.send_join( - host_list, event, room_version_obj - ) - - event = ret.event - origin = ret.origin - state = ret.state - auth_chain = ret.auth_chain - auth_chain.sort(key=lambda e: e.depth) - - logger.debug("do_invite_join auth_chain: %s", auth_chain) - logger.debug("do_invite_join state: %s", state) - - logger.debug("do_invite_join event: %s", event) + async with self._is_partial_state_room_linearizer.queue(room_id): + already_partial_state_room = await self.store.is_partial_state_room( + room_id + ) - # if this is the first time we've joined this room, it's time to add - # a row to `rooms` with the correct room version. If there's already a - # row there, we should override it, since it may have been populated - # based on an invite request which lied about the room version. - # - # federation_client.send_join has already checked that the room - # version in the received create event is the same as room_version_obj, - # so we can rely on it now. - # - await self.store.upsert_room_on_join( - room_id=room_id, - room_version=room_version_obj, - state_events=state, - ) + ret = await self.federation_client.send_join( + host_list, + event, + room_version_obj, + # Perform a full join when we are already in the room and it is a + # full state room, since we are not allowed to persist a partial + # state join event in a full state room. In the future, we could + # optimize this by always performing a partial state join and + # computing the state ourselves or retrieving it from the remote + # homeserver if necessary. + # + # There's a race where we leave the room, then perform a full join + # anyway. This should end up being fast anyway, since we would + # already have the full room state and auth chain persisted. + partial_state=not is_host_joined or already_partial_state_room, + ) - if ret.partial_state: - # Mark the room as having partial state. - # The background process is responsible for unmarking this flag, - # even if the join fails. - await self.store.store_partial_state_room( + event = ret.event + origin = ret.origin + state = ret.state + auth_chain = ret.auth_chain + auth_chain.sort(key=lambda e: e.depth) + + logger.debug("do_invite_join auth_chain: %s", auth_chain) + logger.debug("do_invite_join state: %s", state) + + logger.debug("do_invite_join event: %s", event) + + # if this is the first time we've joined this room, it's time to add + # a row to `rooms` with the correct room version. If there's already a + # row there, we should override it, since it may have been populated + # based on an invite request which lied about the room version. + # + # federation_client.send_join has already checked that the room + # version in the received create event is the same as room_version_obj, + # so we can rely on it now. + # + await self.store.upsert_room_on_join( room_id=room_id, - servers=ret.servers_in_room, - device_lists_stream_id=self.store.get_device_stream_token(), - joined_via=origin, + room_version=room_version_obj, + state_events=state, ) - try: - max_stream_id = ( - await self._federation_event_handler.process_remote_join( - origin, - room_id, - auth_chain, - state, - event, - room_version_obj, - partial_state=ret.partial_state, + if ret.partial_state and not already_partial_state_room: + # Mark the room as having partial state. + # The background process is responsible for unmarking this flag, + # even if the join fails. + # TODO(faster_joins): + # We may want to reset the partial state info if it's from an + # old, failed partial state join. + # https://github.com/matrix-org/synapse/issues/13000 + await self.store.store_partial_state_room( + room_id=room_id, + servers=ret.servers_in_room, + device_lists_stream_id=self.store.get_device_stream_token(), + joined_via=origin, ) - ) - except PartialStateConflictError as e: - # The homeserver was already in the room and it is no longer partial - # stated. We ought to be doing a local join instead. Turn the error into - # a 429, as a hint to the client to try again. - # TODO(faster_joins): `_should_perform_remote_join` suggests that we may - # do a remote join for restricted rooms even if we have full state. - logger.error( - "Room %s was un-partial stated while processing remote join.", - room_id, - ) - raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0) - else: - # Record the join event id for future use (when we finish the full - # join). We have to do this after persisting the event to keep foreign - # key constraints intact. - if ret.partial_state: - await self.store.write_partial_state_rooms_join_event_id( - room_id, event.event_id + + try: + max_stream_id = ( + await self._federation_event_handler.process_remote_join( + origin, + room_id, + auth_chain, + state, + event, + room_version_obj, + partial_state=ret.partial_state, + ) ) - finally: - # Always kick off the background process that asynchronously fetches - # state for the room. - # If the join failed, the background process is responsible for - # cleaning up — including unmarking the room as a partial state room. - if ret.partial_state: - # Kick off the process of asynchronously fetching the state for this - # room. - self._start_partial_state_room_sync( - initial_destination=origin, - other_destinations=ret.servers_in_room, - room_id=room_id, + except PartialStateConflictError: + # This should be impossible, since we hold the lock on the room's + # partial statedness. + logger.error( + "Room %s was un-partial stated while processing remote join.", + room_id, ) + raise + else: + # Record the join event id for future use (when we finish the full + # join). We have to do this after persisting the event to keep + # foreign key constraints intact. + if ret.partial_state and not already_partial_state_room: + # TODO(faster_joins): + # We may want to reset the partial state info if it's from + # an old, failed partial state join. + # https://github.com/matrix-org/synapse/issues/13000 + await self.store.write_partial_state_rooms_join_event_id( + room_id, event.event_id + ) + finally: + # Always kick off the background process that asynchronously fetches + # state for the room. + # If the join failed, the background process is responsible for + # cleaning up — including unmarking the room as a partial state + # room. + if ret.partial_state: + # Kick off the process of asynchronously fetching the state for + # this room. + self._start_partial_state_room_sync( + initial_destination=origin, + other_destinations=ret.servers_in_room, + room_id=room_id, + ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. @@ -1778,6 +1823,12 @@ class FederationHandler: `initial_destination` is unavailable room_id: room to be resynced """ + # Assume that we run on the main process for now. + # TODO(faster_joins,multiple workers) + # When moving the sync to workers, we need to ensure that + # * `_start_partial_state_room_sync` still prevents duplicate resyncs + # * `_is_partial_state_room_linearizer` correctly guards partial state flags + # for rooms between the workers doing remote joins and resync. assert not self.config.worker.worker_app # TODO(faster_joins): do we need to lock to avoid races? What happens if other @@ -1815,8 +1866,10 @@ class FederationHandler: logger.info("Handling any pending device list updates") await self._device_handler.handle_room_un_partial_stated(room_id) - logger.info("Clearing partial-state flag for %s", room_id) - success = await self.store.clear_partial_state_room(room_id) + async with self._is_partial_state_room_linearizer.queue(room_id): + logger.info("Clearing partial-state flag for %s", room_id) + success = await self.store.clear_partial_state_room(room_id) + if success: logger.info("State resync complete for %s", room_id) self._storage_controllers.state.notify_room_un_partial_stated( -- cgit 1.5.1 From 22cc93afe38d34c859d8863a99996e7e72ca1733 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Sun, 22 Jan 2023 21:10:11 +0000 Subject: Enable Faster Remote Room Joins against worker-mode Synapse. (#14752) * Enable Complement tests for Faster Remote Room Joins on worker-mode * (dangerous) Add an override to allow Complement to use FRRJ under workers * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) * Fix race where we didn't send out replication notification * MORE HACKS * Fix get_un_partial_stated_rooms_token to take instance_name * Fix bad merge * Remove warning * Correctly advance un_partial_stated_room_stream * Fix merge * Add another notify_replication * Fixups * Create a separate ReplicationNotifier * Fix test * Fix portdb * Create a separate ReplicationNotifier * Fix test * Fix portdb * Fix presence test * Newsfile * Apply suggestions from code review * Update changelog.d/14752.misc Co-authored-by: Erik Johnston * lint Signed-off-by: Olivier Wilkinson (reivilibre) Co-authored-by: Erik Johnston --- changelog.d/14752.misc | 1 + docker/complement/conf/workers-shared-extra.yaml.j2 | 2 -- scripts-dev/complement.sh | 11 ++++------- synapse/app/generic_worker.py | 7 ------- synapse/handlers/device.py | 2 ++ synapse/handlers/federation.py | 7 ++++--- synapse/replication/tcp/streams/partial_state.py | 7 ++----- synapse/storage/databases/main/events_worker.py | 13 ++++++++----- synapse/storage/databases/main/room.py | 19 ++++++++++++------- synapse/storage/databases/main/state.py | 2 ++ 10 files changed, 35 insertions(+), 36 deletions(-) create mode 100644 changelog.d/14752.misc (limited to 'synapse') diff --git a/changelog.d/14752.misc b/changelog.d/14752.misc new file mode 100644 index 0000000000..1f9675c53b --- /dev/null +++ b/changelog.d/14752.misc @@ -0,0 +1 @@ +Enable Complement tests for Faster Remote Room Joins against worker-mode Synapse. \ No newline at end of file diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 281157846a..63acf86a46 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -94,10 +94,8 @@ allow_device_name_lookup_over_federation: true experimental_features: # Enable history backfilling support msc2716_enabled: true - {% if not workers_in_use %} # client-side support for partial state in /send_join responses faster_joins: true - {% endif %} # Enable support for polls msc3381_polls_enabled: true # Enable deleting device-specific notification settings stored in account data diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index a183653d52..e72d96fd16 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -190,7 +190,7 @@ fi extra_test_args=() -test_tags="synapse_blacklist,msc3787,msc3874,msc3890,msc3391,msc3930" +test_tags="synapse_blacklist,msc3787,msc3874,msc3890,msc3391,msc3930,faster_joins" # All environment variables starting with PASS_ will be shared. # (The prefix is stripped off before reaching the container.) @@ -223,12 +223,9 @@ else export PASS_SYNAPSE_COMPLEMENT_DATABASE=sqlite fi - # We only test faster room joins on monoliths, because they are purposefully - # being developed without worker support to start with. - # - # The tests for importing historical messages (MSC2716) also only pass with monoliths, - # currently. - test_tags="$test_tags,faster_joins,msc2716" + # The tests for importing historical messages (MSC2716) + # only pass with monoliths, currently. + test_tags="$test_tags,msc2716" fi diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 8108b1e98f..946f3a3807 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -282,13 +282,6 @@ def start(config_options: List[str]) -> None: "synapse.app.user_dir", ) - if config.experimental.faster_joins_enabled: - raise ConfigError( - "You have enabled the experimental `faster_joins` config option, but it is " - "not compatible with worker deployments yet. Please disable `faster_joins` " - "or run Synapse as a single process deployment instead." - ) - synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 0640ea79a0..58180ae2fa 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -974,6 +974,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler + self._notifier = hs.get_notifier() self._remote_edu_linearizer = Linearizer(name="remote_device_list") @@ -1054,6 +1055,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): user_id, device_id, ) + self._notifier.notify_replication() room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2123ace8a6..7620245e26 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1870,14 +1870,15 @@ class FederationHandler: logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) + # Poke the notifier so that other workers see the write to + # the un-partial-stated rooms stream. + self._notifier.notify_replication() + if success: logger.info("State resync complete for %s", room_id) self._storage_controllers.state.notify_room_un_partial_stated( room_id ) - # Poke the notifier so that other workers see the write to - # the un-partial-stated rooms stream. - self._notifier.notify_replication() # TODO(faster_joins) update room stats and user directory? # https://github.com/matrix-org/synapse/issues/12814 diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py index b5a2ae74b6..a8ce5ffd72 100644 --- a/synapse/replication/tcp/streams/partial_state.py +++ b/synapse/replication/tcp/streams/partial_state.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING import attr from synapse.replication.tcp.streams import Stream -from synapse.replication.tcp.streams._base import current_token_without_instance if TYPE_CHECKING: from synapse.server import HomeServer @@ -42,8 +41,7 @@ class UnPartialStatedRoomStream(Stream): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - # TODO(faster_joins, multiple writers): we need to account for instance names - current_token_without_instance(store.get_un_partial_stated_rooms_token), + store.get_un_partial_stated_rooms_token, store.get_un_partial_stated_rooms_from_stream, ) @@ -70,7 +68,6 @@ class UnPartialStatedEventStream(Stream): store = hs.get_datastores().main super().__init__( hs.get_instance_name(), - # TODO(faster_joins, multiple writers): we need to account for instance names - current_token_without_instance(store.get_un_partial_stated_events_token), + store.get_un_partial_stated_events_token, store.get_un_partial_stated_events_from_stream, ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d8a8bcafb6..24127d0364 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -322,11 +322,12 @@ class EventsWorkerStore(SQLBaseStore): "stream_id", ) - def get_un_partial_stated_events_token(self) -> int: - # TODO(faster_joins, multiple writers): This is inappropriate if there are multiple - # writers because workers that don't write often will hold all - # readers up. - return self._un_partial_stated_events_stream_id_gen.get_current_token() + def get_un_partial_stated_events_token(self, instance_name: str) -> int: + return ( + self._un_partial_stated_events_stream_id_gen.get_current_token_for_writer( + instance_name + ) + ) async def get_un_partial_stated_events_from_stream( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -416,6 +417,8 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: self._backfill_id_gen.advance(instance_name, -token) + elif stream_name == UnPartialStatedEventStream.NAME: + self._un_partial_stated_events_stream_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) async def have_censored_event(self, event_id: str) -> bool: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7264a33cd4..6a65b2a89b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -43,6 +43,7 @@ from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase +from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -144,6 +145,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): "stream_id", ) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == UnPartialStatedRoomStream.NAME: + self._un_partial_stated_rooms_stream_id_gen.advance(instance_name, token) + return super().process_replication_position(stream_name, instance_name, token) + async def store_room( self, room_id: str, @@ -1281,13 +1289,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) return result["join_event_id"], result["device_lists_stream_id"] - def get_un_partial_stated_rooms_token(self) -> int: - # TODO(faster_joins, multiple writers): This is inappropriate if there - # are multiple writers because workers that don't write often will - # hold all readers up. - # (See `MultiWriterIdGenerator.get_persisted_upto_position` for an - # explanation.) - return self._un_partial_stated_rooms_stream_id_gen.get_current_token() + def get_un_partial_stated_rooms_token(self, instance_name: str) -> int: + return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer( + instance_name + ) async def get_un_partial_stated_rooms_from_stream( self, instance_name: str, last_id: int, current_id: int, limit: int diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index f32cbb2dec..ba325d390b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -95,6 +95,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) self._get_state_group_for_event.invalidate((row.event_id,)) + self.is_partial_state_event.invalidate((row.event_id,)) super().process_replication_rows(stream_name, instance_name, token, rows) @@ -485,6 +486,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "rejection_status_changed": rejection_status_changed, }, ) + txn.call_after(self.hs.get_notifier().on_new_replication_data) class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): -- cgit 1.5.1 From 2ec9c58496e2138cbc4364aba238997c393d5308 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 23 Jan 2023 10:31:36 +0000 Subject: Faster joins: Update room stats and the user directory on workers when finishing join (#14874) * Faster joins: Update room stats and user directory on workers when done When finishing a partial state join to a room, we update the current state of the room without persisting additional events. Workers receive notice of the current state update over replication, but neglect to wake the room stats and user directory updaters, which then get incidentally triggered the next time an event is persisted or an unrelated event persister sends out a stream position update. We wake the room stats and user directory updaters at the appropriate time in this commit. Part of #12814 and #12815. Signed-off-by: Sean Quah * fixup comment Signed-off-by: Sean Quah --- changelog.d/14874.bugfix | 1 + synapse/handlers/federation.py | 7 ++++--- synapse/replication/tcp/client.py | 6 ++++++ synapse/storage/controllers/state.py | 2 -- 4 files changed, 11 insertions(+), 5 deletions(-) create mode 100644 changelog.d/14874.bugfix (limited to 'synapse') diff --git a/changelog.d/14874.bugfix b/changelog.d/14874.bugfix new file mode 100644 index 0000000000..91ae2ea9bd --- /dev/null +++ b/changelog.d/14874.bugfix @@ -0,0 +1 @@ +Faster joins: Fix a bug in worker deployments where the room stats and user directory would not get updated when finishing a fast join until another event is sent or received. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 7620245e26..3217127865 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1880,9 +1880,10 @@ class FederationHandler: room_id ) - # TODO(faster_joins) update room stats and user directory? - # https://github.com/matrix-org/synapse/issues/12814 - # https://github.com/matrix-org/synapse/issues/12815 + # Poke the notifier so that other workers see the write to + # the un-partial-stated rooms stream. + self._notifier.notify_replication() + return # we raced against more events arriving with partial state. Go round diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 493f616679..2a9cb499a4 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -207,6 +207,12 @@ class ReplicationDataHandler: # we don't need to optimise this for multiple rows. for row in rows: if row.type != EventsStreamEventRow.TypeId: + # The row's data is an `EventsStreamCurrentStateRow`. + # When we recompute the current state of a room based on forward + # extremities (see `update_current_state`), no new events are + # persisted, so we must poke the replication callbacks ourselves. + # This functionality is used when finishing up a partial state join. + self.notifier.notify_replication() continue assert isinstance(row, EventsStreamRow) assert isinstance(row.data, EventsStreamEventRow) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 26d79c6e62..2045169b9a 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -493,8 +493,6 @@ class StateStorageController: up to date. """ # FIXME(faster_joins): what do we do here? - # https://github.com/matrix-org/synapse/issues/12814 - # https://github.com/matrix-org/synapse/issues/12815 # https://github.com/matrix-org/synapse/issues/13008 return await self.stores.main.get_partial_current_state_deltas( -- cgit 1.5.1 From 82d3efa3124f771579ba07553904f88625c443b0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Jan 2023 06:36:20 -0500 Subject: Skip processing stats for broken rooms. (#14873) * Skip processing stats for broken rooms. * Newsfragment * Use a custom exception. --- changelog.d/14873.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 6 +- synapse/storage/databases/main/stats.py | 13 +++- tests/storage/databases/main/test_room.py | 88 +++++++++++++++---------- 4 files changed, 72 insertions(+), 36 deletions(-) create mode 100644 changelog.d/14873.bugfix (limited to 'synapse') diff --git a/changelog.d/14873.bugfix b/changelog.d/14873.bugfix new file mode 100644 index 0000000000..9b058576cd --- /dev/null +++ b/changelog.d/14873.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the `populate_room_stats` background job could fail on broken rooms. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 24127d0364..f42af34a2f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -110,6 +110,10 @@ event_fetch_ongoing_gauge = Gauge( ) +class InvalidEventError(Exception): + """The event retrieved from the database is invalid and cannot be used.""" + + @attr.s(slots=True, auto_attribs=True) class EventCacheEntry: event: EventBase @@ -1310,7 +1314,7 @@ class EventsWorkerStore(SQLBaseStore): # invites, so just accept it for all membership events. # if d["type"] != EventTypes.Member: - raise Exception( + raise InvalidEventError( "Room %s for event %s is unknown" % (d["room_id"], event_id) ) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 356d4ca788..0c1cbd540d 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -29,6 +29,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.events_worker import InvalidEventError from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -554,7 +555,17 @@ class StatsStore(StateDeltasStore): "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined] + try: + state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined] + except InvalidEventError as e: + # If an exception occurs fetching events then the room is broken; + # skip process it to avoid being stuck on a room. + logger.warning( + "Failed to fetch events for room %s, skipping stats calculation: %r.", + room_id, + e, + ) + return room_state: Dict[str, Union[None, bool, str]] = { "join_rules": None, diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 7d961fac64..3108ca3444 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -40,9 +40,23 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): self.token = self.login("foo", "pass") def _generate_room(self) -> str: - room_id = self.helper.create_room_as(self.user_id, tok=self.token) + """Create a room and return the room ID.""" + return self.helper.create_room_as(self.user_id, tok=self.token) - return room_id + def run_background_updates(self, update_name: str) -> None: + """Insert and run the background update.""" + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": update_name, "progress_json": "{}"}, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() def test_background_populate_rooms_creator_column(self) -> None: """Test that the background update to populate the rooms creator column @@ -71,22 +85,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) self.assertEqual(room_creator_before, None) - # Insert and run the background update. - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, - "progress_json": "{}", - }, - ) - ) - - # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db_pool.updates._all_done = False - - # Now let's actually drive the updates to completion - self.wait_for_background_updates() + self.run_background_updates(_BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN) # Make sure the background update filled in the room creator room_creator_after = self.get_success( @@ -137,22 +136,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) ) - # Insert and run the background update - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, - "progress_json": "{}", - }, - ) - ) - - # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db_pool.updates._all_done = False - - # Now let's actually drive the updates to completion - self.wait_for_background_updates() + self.run_background_updates(_BackgroundUpdates.ADD_ROOM_TYPE_COLUMN) # Make sure the background update filled in the room type room_type_after = self.get_success( @@ -164,3 +148,39 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) ) self.assertEqual(room_type_after, RoomTypes.SPACE) + + def test_populate_stats_broken_rooms(self) -> None: + """Ensure that re-populating room stats skips broken rooms.""" + + # Create a good room. + good_room_id = self._generate_room() + + # Create a room and then break it by having no room version. + room_id = self._generate_room() + self.get_success( + self.store.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"room_version": None}, + desc="test", + ) + ) + + # Nuke any current stats in the database. + self.get_success( + self.store.db_pool.simple_delete( + table="room_stats_state", keyvalues={"1": 1}, desc="test" + ) + ) + + self.run_background_updates("populate_stats_process_rooms") + + # Only the good room appears in the stats tables. + results = self.get_success( + self.store.db_pool.simple_select_onecol( + table="room_stats_state", + keyvalues={}, + retcol="room_id", + ) + ) + self.assertEqual(results, [good_room_id]) -- cgit 1.5.1 From 80d44060c99e87c84da72fdfcaa9a508d38a26b4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 23 Jan 2023 15:44:39 +0000 Subject: Faster joins: omit partial rooms from eager syncs until the resync completes (#14870) * Allow `AbstractSet` in `StrCollection` Or else frozensets are excluded. This will be useful in an upcoming commit where I plan to change a function that accepts `List[str]` to accept `StrCollection` instead. * `rooms_to_exclude` -> `rooms_to_exclude_globally` I am about to make use of this exclusion mechanism to exclude rooms for a specific user and a specific sync. This rename helps to clarify the distinction between the global config and the rooms to exclude for a specific sync. * Better function names for internal sync methods * Track a list of excluded rooms on SyncResultBuilder I plan to feed a list of partially stated rooms for this sync to ignore * Exclude partial state rooms during eager sync using the mechanism established in the previous commit * Track un-partial-state stream in sync tokens So that we can work out which rooms have become fully-stated during a given sync period. * Fix mutation of `@cached` return value This was fouling up a complement test added alongside this PR. Excluding a room would mean the set of forgotten rooms in the cache would be extended. This means that room could be erroneously considered forgotten in the future. Introduced in #12310, Synapse 1.57.0. I don't think this had any user-visible side effects (until now). * SyncResultBuilder: track rooms to force as newly joined Similar plan as before. We've omitted rooms from certain sync responses; now we establish the mechanism to reintroduce them into future syncs. * Read new field, to present rooms as newly joined * Force un-partial-stated rooms to be newly-joined for eager incremental syncs only, provided they're still fully stated * Notify user stream listeners to wake up long polling syncs * Changelog * Typo fix Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Unnecessary list cast Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Rephrase comment Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Another comment Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Fixup merge(?) * Poke notifier when receiving un-partial-stated msg over replication * Fixup merge whoops Thanks MV :) Co-authored-by: Mathieu Velen Co-authored-by: Mathieu Velten Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14870.feature | 1 + synapse/handlers/federation.py | 15 +++---- synapse/handlers/sync.py | 65 +++++++++++++++++++++++----- synapse/notifier.py | 26 +++++++++++ synapse/replication/tcp/client.py | 1 + synapse/storage/databases/main/relations.py | 1 + synapse/storage/databases/main/room.py | 47 +++++++++++++++++--- synapse/storage/databases/main/roommember.py | 19 +++++--- synapse/streams/events.py | 6 +++ synapse/types/__init__.py | 15 ++++--- tests/rest/admin/test_room.py | 4 +- tests/rest/client/test_rooms.py | 10 ++--- tests/rest/client/test_sync.py | 4 +- 13 files changed, 170 insertions(+), 44 deletions(-) create mode 100644 changelog.d/14870.feature (limited to 'synapse') diff --git a/changelog.d/14870.feature b/changelog.d/14870.feature new file mode 100644 index 0000000000..44f701d1c9 --- /dev/null +++ b/changelog.d/14870.feature @@ -0,0 +1 @@ +Faster joins: allow non-lazy-loading ("eager") syncs to complete after a partial join by omitting partial state rooms until they become fully stated. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3217127865..233f8c113d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1868,22 +1868,17 @@ class FederationHandler: async with self._is_partial_state_room_linearizer.queue(room_id): logger.info("Clearing partial-state flag for %s", room_id) - success = await self.store.clear_partial_state_room(room_id) + new_stream_id = await self.store.clear_partial_state_room(room_id) - # Poke the notifier so that other workers see the write to - # the un-partial-stated rooms stream. - self._notifier.notify_replication() - - if success: + if new_stream_id is not None: logger.info("State resync complete for %s", room_id) self._storage_controllers.state.notify_room_un_partial_stated( room_id ) - # Poke the notifier so that other workers see the write to - # the un-partial-stated rooms stream. - self._notifier.notify_replication() - + await self._notifier.on_un_partial_stated_room( + room_id, new_stream_id + ) return # we raced against more events arriving with partial state. Go round diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 78d488f2b1..ee11764567 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -290,7 +290,7 @@ class SyncHandler: expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) - self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync async def wait_for_sync_for_user( self, @@ -1340,7 +1340,10 @@ class SyncHandler: membership_change_events = [] if since_token: membership_change_events = await self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude + user_id, + since_token.room_key, + now_token.room_key, + self.rooms_to_exclude_globally, ) mem_last_change_by_room_id: Dict[str, EventBase] = {} @@ -1375,12 +1378,39 @@ class SyncHandler: else: mutable_joined_room_ids.discard(room_id) + # Tweak the set of rooms to return to the client for eager (non-lazy) syncs. + mutable_rooms_to_exclude = set(self.rooms_to_exclude_globally) + if not sync_config.filter_collection.lazy_load_members(): + # Non-lazy syncs should never include partially stated rooms. + # Exclude all partially stated rooms from this sync. + for room_id in mutable_joined_room_ids: + if await self.store.is_partial_state_room(room_id): + mutable_rooms_to_exclude.add(room_id) + + # Incremental eager syncs should additionally include rooms that + # - we are joined to + # - are full-stated + # - became fully-stated at some point during the sync period + # (These rooms will have been omitted during a previous eager sync.) + forced_newly_joined_room_ids = set() + if since_token and not sync_config.filter_collection.lazy_load_members(): + un_partial_stated_rooms = ( + await self.store.get_un_partial_stated_rooms_between( + since_token.un_partial_stated_rooms_key, + now_token.un_partial_stated_rooms_key, + mutable_joined_room_ids, + ) + ) + for room_id in un_partial_stated_rooms: + if not await self.store.is_partial_state_room(room_id): + forced_newly_joined_room_ids.add(room_id) + # Now we have our list of joined room IDs, exclude as configured and freeze joined_room_ids = frozenset( ( room_id for room_id in mutable_joined_room_ids - if room_id not in self.rooms_to_exclude + if room_id not in mutable_rooms_to_exclude ) ) @@ -1397,6 +1427,8 @@ class SyncHandler: since_token=since_token, now_token=now_token, joined_room_ids=joined_room_ids, + excluded_room_ids=frozenset(mutable_rooms_to_exclude), + forced_newly_joined_room_ids=frozenset(forced_newly_joined_room_ids), membership_change_events=membership_change_events, ) @@ -1834,14 +1866,16 @@ class SyncHandler: # 3. Work out which rooms need reporting in the sync response. ignored_users = await self.store.ignored_users(user_id) if since_token: - room_changes = await self._get_rooms_changed( + room_changes = await self._get_room_changes_for_incremental_sync( sync_result_builder, ignored_users ) tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) + room_changes = await self._get_room_changes_for_initial_sync( + sync_result_builder, ignored_users + ) tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1900,7 +1934,7 @@ class SyncHandler: assert since_token - if membership_change_events: + if membership_change_events or sync_result_builder.forced_newly_joined_room_ids: return True stream_id = since_token.room_key.stream @@ -1909,7 +1943,7 @@ class SyncHandler: return True return False - async def _get_rooms_changed( + async def _get_room_changes_for_incremental_sync( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str], @@ -1947,7 +1981,9 @@ class SyncHandler: for event in membership_change_events: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - newly_joined_rooms: List[str] = [] + newly_joined_rooms: List[str] = list( + sync_result_builder.forced_newly_joined_room_ids + ) newly_left_rooms: List[str] = [] room_entries: List[RoomSyncResultBuilder] = [] invited: List[InvitedSyncResult] = [] @@ -2153,7 +2189,7 @@ class SyncHandler: newly_left_rooms, ) - async def _get_all_rooms( + async def _get_room_changes_for_initial_sync( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str], @@ -2178,7 +2214,7 @@ class SyncHandler: room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=Membership.LIST, - excluded_rooms=self.rooms_to_exclude, + excluded_rooms=sync_result_builder.excluded_room_ids, ) room_entries = [] @@ -2549,6 +2585,13 @@ class SyncResultBuilder: since_token: The token supplied by user, or None. now_token: The token to sync up to. joined_room_ids: List of rooms the user is joined to + excluded_room_ids: Set of room ids we should omit from the /sync response. + forced_newly_joined_room_ids: + Rooms that should be presented in the /sync response as if they were + newly joined during the sync period, even if that's not the case. + (This is useful if the room was previously excluded from a /sync response, + and now the client should be made aware of it.) + Only used by incremental syncs. # The following mirror the fields in a sync response presence @@ -2565,6 +2608,8 @@ class SyncResultBuilder: since_token: Optional[StreamToken] now_token: StreamToken joined_room_ids: FrozenSet[str] + excluded_room_ids: FrozenSet[str] + forced_newly_joined_room_ids: FrozenSet[str] membership_change_events: List[EventBase] presence: List[UserPresenceState] = attr.Factory(list) diff --git a/synapse/notifier.py b/synapse/notifier.py index 28f0d4a25a..2b0e52f23c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -314,6 +314,32 @@ class Notifier: event_entries.append((entry, event.event_id)) await self.notify_new_room_events(event_entries, max_room_stream_token) + async def on_un_partial_stated_room( + self, + room_id: str, + new_token: int, + ) -> None: + """Used by the resync background processes to wake up all listeners + of this room when it is un-partial-stated. + + It will also notify replication listeners of the change in stream. + """ + + # Wake up all related user stream notifiers + user_streams = self.room_to_user_streams.get(room_id, set()) + time_now_ms = self.clock.time_msec() + for user_stream in user_streams: + try: + user_stream.notify( + StreamKeyType.UN_PARTIAL_STATED_ROOMS, new_token, time_now_ms + ) + except Exception: + logger.exception("Failed to notify listener") + + # Poke the replication so that other workers also see the write to + # the un-partial-stated rooms stream. + self.notify_replication() + async def notify_new_room_events( self, event_entries: List[Tuple[_PendingRoomEventEntry, str]], diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2a9cb499a4..cc0528bd8e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -260,6 +260,7 @@ class ReplicationDataHandler: self._state_storage_controller.notify_room_un_partial_stated( row.room_id ) + await self.notifier.on_un_partial_stated_room(row.room_id, token) elif stream_name == UnPartialStatedEventStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index aea96e9d24..84f844b79e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -292,6 +292,7 @@ class RelationsWorkerStore(SQLBaseStore): to_device_key=0, device_list_key=0, groups_key=0, + un_partial_stated_rooms_key=0, ) return events[:limit], next_token diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 6a65b2a89b..3aa7b94560 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -26,6 +26,7 @@ from typing import ( Mapping, Optional, Sequence, + Set, Tuple, Union, cast, @@ -1294,10 +1295,44 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): instance_name ) + async def get_un_partial_stated_rooms_between( + self, last_id: int, current_id: int, room_ids: Collection[str] + ) -> Set[str]: + """Get all rooms that got un partial stated between `last_id` exclusive and + `current_id` inclusive. + + Returns: + The list of room ids. + """ + + if last_id == current_id: + return set() + + def _get_un_partial_stated_rooms_between_txn( + txn: LoggingTransaction, + ) -> Set[str]: + sql = """ + SELECT DISTINCT room_id FROM un_partial_stated_room_stream + WHERE ? < stream_id AND stream_id <= ? AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + txn.execute(sql + clause, [last_id, current_id] + args) + + return {r[0] for r in txn} + + return await self.db_pool.runInteraction( + "get_un_partial_stated_rooms_between", + _get_un_partial_stated_rooms_between_txn, + ) + async def get_un_partial_stated_rooms_from_stream( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: - """Get updates for caches replication stream. + """Get updates for un partial stated rooms replication stream. Args: instance_name: The writer we want to fetch updates from. Unused @@ -2304,16 +2339,16 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): (room_id,), ) - async def clear_partial_state_room(self, room_id: str) -> bool: + async def clear_partial_state_room(self, room_id: str) -> Optional[int]: """Clears the partial state flag for a room. Args: room_id: The room whose partial state flag is to be cleared. Returns: - `True` if the partial state flag has been cleared successfully. + The corresponding stream id for the un-partial-stated rooms stream. - `False` if the partial state flag could not be cleared because the room + `None` if the partial state flag could not be cleared because the room still contains events with partial state. """ try: @@ -2324,7 +2359,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id, un_partial_state_room_stream_id, ) - return True + return un_partial_state_room_stream_id except self.db_pool.engine.module.IntegrityError as e: # Assume that any `IntegrityError`s are due to partial state events. logger.info( @@ -2332,7 +2367,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id, e, ) - return False + return None def _clear_partial_state_room_txn( self, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index f02c1d7ea7..8e2ba7b7b4 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,6 +15,7 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Collection, Dict, FrozenSet, @@ -47,7 +48,13 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id +from synapse.types import ( + JsonDict, + PersistedEventPosition, + StateMap, + StrCollection, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -385,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): self, user_id: str, membership_list: Collection[str], - excluded_rooms: Optional[List[str]] = None, + excluded_rooms: StrCollection = (), ) -> List[RoomsForUser]: """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. @@ -412,10 +419,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) # Now we filter out forgotten and excluded rooms - rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id) + rooms_to_exclude = await self.get_forgotten_rooms_for_user(user_id) if excluded_rooms is not None: - rooms_to_exclude.update(set(excluded_rooms)) + # Take a copy to avoid mutating the in-cache set + rooms_to_exclude = set(rooms_to_exclude) + rooms_to_exclude.update(excluded_rooms) return [room for room in rooms if room.room_id not in rooms_to_exclude] @@ -1169,7 +1178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return count == 0 @cached() - async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: + async def get_forgotten_rooms_for_user(self, user_id: str) -> AbstractSet[str]: """Gets all rooms the user has forgotten. Args: diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 619eb7f601..d7084d2358 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -53,11 +53,15 @@ class EventSources: *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) ) self.store = hs.get_datastores().main + self._instance_name = hs.get_instance_name() def get_current_token(self) -> StreamToken: push_rules_key = self.store.get_max_push_rules_stream_id() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() + un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token( + self._instance_name + ) token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -70,6 +74,7 @@ class EventSources: device_list_key=device_list_key, # Groups key is unused. groups_key=0, + un_partial_stated_rooms_key=un_partial_stated_rooms_key, ) return token @@ -107,5 +112,6 @@ class EventSources: to_device_key=0, device_list_key=0, groups_key=0, + un_partial_stated_rooms_key=0, ) return token diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index c59eca2430..f82d1cfc29 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -17,6 +17,7 @@ import re import string from typing import ( TYPE_CHECKING, + AbstractSet, Any, ClassVar, Dict, @@ -79,7 +80,7 @@ JsonSerializable = object # Collection[str] that does not include str itself; str being a Sequence[str] # is very misleading and results in bugs. -StrCollection = Union[Tuple[str, ...], List[str], Set[str]] +StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]] # Note that this seems to require inheriting *directly* from Interface in order @@ -633,6 +634,7 @@ class StreamKeyType: PUSH_RULES: Final = "push_rules_key" TO_DEVICE: Final = "to_device_key" DEVICE_LIST: Final = "device_list_key" + UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -640,7 +642,7 @@ class StreamToken: """A collection of keys joined together by underscores in the following order and which represent the position in their respective streams. - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379` 1. `room_key`: `s2633508` which is a `RoomStreamToken` - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - See the docstring for `RoomStreamToken` for more details. @@ -652,12 +654,13 @@ class StreamToken: 7. `to_device_key`: `274711` 8. `device_list_key`: `265584` 9. `groups_key`: `1` (note that this key is now unused) + 10. `un_partial_stated_rooms_key`: `379` You can see how many of these keys correspond to the various fields in a "/sync" response: ```json { - "next_batch": "s12_4_0_1_1_1_1_4_1", + "next_batch": "s12_4_0_1_1_1_1_4_1_1", "presence": { "events": [] }, @@ -669,7 +672,7 @@ class StreamToken: "!QrZlfIDQLNLdZHqTnt:hs1": { "timeline": { "events": [], - "prev_batch": "s10_4_0_1_1_1_1_4_1", + "prev_batch": "s10_4_0_1_1_1_1_4_1_1", "limited": false }, "state": { @@ -705,6 +708,7 @@ class StreamToken: device_list_key: int # Note that the groups key is no longer used and may have bogus values. groups_key: int + un_partial_stated_rooms_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -743,6 +747,7 @@ class StreamToken: # serialized so that there will not be confusion in the future # if additional tokens are added. str(self.groups_key), + str(self.un_partial_stated_rooms_key), ] ) @@ -775,7 +780,7 @@ class StreamToken: return attr.evolve(self, **{key: new_value}) -StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0) @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index e0f5d54aba..453a6e979c 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1831,7 +1831,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): def test_topo_token_is_accepted(self) -> None: """Test Topo Token is accepted.""" - token = "t1-0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), @@ -1845,7 +1845,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: """Test that stream token is accepted for forward pagination.""" - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b4daace556..9222cab198 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -1987,7 +1987,7 @@ class RoomMessageListTestCase(RoomBase): self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self) -> None: - token = "t1-0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -1998,7 +1998,7 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("end" in channel.json_body) def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -2728,7 +2728,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by a label on a /messages request.""" self._send_labelled_messages_in_room() - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" @@ -2745,7 +2745,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by the absence of a label on a /messages request.""" self._send_labelled_messages_in_room() - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" @@ -2768,7 +2768,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """ self._send_labelled_messages_in_room() - token = "s0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 0af643ecd9..c9afa0f3dd 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -913,7 +913,9 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase): # We need to manually append the room ID, because we can't know the ID before # creating the room, and we can't set the config after starting the homeserver. - self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id) + self.hs.get_sync_handler().rooms_to_exclude_globally.append( + self.excluded_room_id + ) def test_join_leave(self) -> None: """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of -- cgit 1.5.1 From 4607be0b7b2165710dc2e5e68ec4281b593ca8c5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 24 Jan 2023 15:28:20 +0000 Subject: Request partial joins by default (#14905) * Request partial joins by default This is a little sloppy, but we are trying to gain confidence in faster joins in the upcoming RC. Admins can still opt out by adding the following to their Synapse config: ```yaml experimental: faster_joins: false ``` We may revert this change before the release proper, depending on how testing in the wild goes. * Changelog * Try to fix the backfill test failures * Upgrade notes * Postgres compat? --- changelog.d/14905.feature | 1 + docs/upgrade.md | 13 +++++++++++ synapse/config/experimental.py | 2 +- synapse/storage/databases/main/stream.py | 40 +++++++++++++++++++++++++++----- 4 files changed, 49 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14905.feature (limited to 'synapse') diff --git a/changelog.d/14905.feature b/changelog.d/14905.feature new file mode 100644 index 0000000000..f13a4af981 --- /dev/null +++ b/changelog.d/14905.feature @@ -0,0 +1 @@ +Faster joins: request partial joins by default. Admins can opt-out of this for the time being---see the upgrade notes. diff --git a/docs/upgrade.md b/docs/upgrade.md index 0d486a3c82..6316db563b 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -90,6 +90,19 @@ process, for example: # Upgrading to v1.76.0 +## Faster joins are enabled by default + +When joining a room for the first time, Synapse 1.76.0rc1 will request a partial join from the other server by default. Previously, server admins had to opt-in to this using an experimental config flag. + +Server admins can opt out of this feature for the time being by setting + +```yaml +experimental: + faster_joins: false +``` + +in their server config. + ## Changes to the account data replication streams Synapse has changed the format of the account data and devices replication diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 89586db763..2590c88cde 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -84,7 +84,7 @@ class ExperimentalConfig(Config): # experimental support for faster joins over federation # (MSC2775, MSC3706, MSC3895) # requires a target server that can provide a partial join response (MSC3706) - self.faster_joins_enabled: bool = experimental.get("faster_joins", False) + self.faster_joins_enabled: bool = experimental.get("faster_joins", True) # MSC3720 (Account status endpoint) self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 63d8350530..d28fc65df9 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -67,7 +67,7 @@ from synapse.storage.database import ( make_in_list_sql_clause, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import PersistedEventPosition, RoomStreamToken from synapse.util.caches.descriptors import cached @@ -944,12 +944,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id stream_key """ - sql = ( - "SELECT coalesce(MIN(topological_ordering), 0) FROM events" - " WHERE room_id = ? AND stream_ordering >= ?" - ) + if isinstance(self.database_engine, PostgresEngine): + min_function = "LEAST" + elif isinstance(self.database_engine, Sqlite3Engine): + min_function = "MIN" + else: + raise RuntimeError(f"Unknown database engine {self.database_engine}") + + # This query used to be + # SELECT COALESCE(MIN(topological_ordering), 0) FROM events + # WHERE room_id = ? and events.stream_ordering >= {stream_key} + # which returns 0 if the stream_key is newer than any event in + # the room. That's not wrong, but it seems to interact oddly with backfill, + # requiring a second call to /messages to actually backfill from a remote + # homeserver. + # + # Instead, rollback the stream ordering to that after the most recent event in + # this room. + sql = f""" + WITH fallback(max_stream_ordering) AS ( + SELECT MAX(stream_ordering) + FROM events + WHERE room_id = ? + ) + SELECT COALESCE(MIN(topological_ordering), 0) FROM events + WHERE + room_id = ? + AND events.stream_ordering >= {min_function}( + ?, + (SELECT max_stream_ordering FROM fallback) + ) + """ + row = await self.db_pool.execute( - "get_current_topological_token", None, sql, room_id, stream_key + "get_current_topological_token", None, sql, room_id, room_id, stream_key ) return row[0][0] if row else 0 -- cgit 1.5.1 From a63d4cc9e96c1f5bb9c5bb9fc9119fb137de3b1b Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 25 Jan 2023 13:38:53 +0000 Subject: Make sqlite database migrations transactional again (#14910) #13873 introduced a regression which causes sqlite database migrations to no longer run inside a transaction. Wrap them in a transaction again, to avoid database corruption when migrations are interrupted. Fixes #14909. Signed-off-by: Sean Quah --- changelog.d/14910.bugfix | 1 + synapse/storage/engines/_base.py | 3 +++ synapse/storage/engines/sqlite.py | 5 +++-- 3 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14910.bugfix (limited to 'synapse') diff --git a/changelog.d/14910.bugfix b/changelog.d/14910.bugfix new file mode 100644 index 0000000000..f1f34cd6ba --- /dev/null +++ b/changelog.d/14910.bugfix @@ -0,0 +1 @@ +Fix a regression introduced in Synapse 1.69.0 which can result in database corruption when database migrations are interrupted on sqlite. diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 70e594a68f..bc9ca3a53c 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -132,6 +132,9 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM """Execute a chunk of SQL containing multiple semicolon-delimited statements. This is not provided by DBAPI2, and so needs engine-specific support. + + Some database engines may automatically COMMIT the ongoing transaction both + before and after executing the script. """ ... diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 14260442b6..2f7df85ce4 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -135,13 +135,14 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): > than one statement with it, it will raise a Warning. Use executescript() if > you want to execute multiple SQL statements with one call. - Though the docs for `executescript` warn: + The script is wrapped in transaction control statemnets, since the docs for + `executescript` warn: > If there is a pending transaction, an implicit COMMIT statement is executed > first. No other implicit transaction control is performed; any transaction > control must be added to sql_script. """ - cursor.executescript(script) + cursor.executescript(f"BEGIN TRANSACTION;\n{script}\nCOMMIT;") # Following functions taken from: https://github.com/coleifer/peewee -- cgit 1.5.1 From 8e37ece015c8afd97572bdc742981792b02c6700 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 25 Jan 2023 16:11:06 +0000 Subject: Bump the client-side timeout for /state (#14912) * Bump the client-side timeout for /state to allow faster joins resyncs the chance to complete for large rooms. We have seen this fair poorly (~90s for Matrix HQ's /state) in testing, causing the resync to advance to another HS who hasn't seen our join yet. * Changelog * Milliseconds!!!! --- changelog.d/14912.misc | 1 + synapse/federation/transport/client.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/14912.misc (limited to 'synapse') diff --git a/changelog.d/14912.misc b/changelog.d/14912.misc new file mode 100644 index 0000000000..9dbc6b3424 --- /dev/null +++ b/changelog.d/14912.misc @@ -0,0 +1 @@ +Faster joins: allow the resync process more time to fetch `/state` ids. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 556883f079..682666ab36 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -102,6 +102,10 @@ class TransportLayerClient: destination, path=path, args={"event_id": event_id}, + # This can take a looooooong time for large rooms. Give this a generous + # timeout of 10 minutes to avoid the partial state resync timing out early + # and trying a bunch of servers who haven't seen our join yet. + timeout=600_000, parser=_StateParser(room_version), ) -- cgit 1.5.1 From 3c3ba31507cbff27064ea3c6cf1e7add9583556a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 25 Jan 2023 15:14:03 -0500 Subject: Add missing type hints for tests.events. (#14904) --- changelog.d/14904.misc | 1 + mypy.ini | 5 ++- synapse/events/utils.py | 3 +- tests/events/test_presence_router.py | 58 +++++++++++++++++------------ tests/events/test_snapshot.py | 17 ++++++--- tests/events/test_utils.py | 71 +++++++++++++++++++----------------- 6 files changed, 91 insertions(+), 64 deletions(-) create mode 100644 changelog.d/14904.misc (limited to 'synapse') diff --git a/changelog.d/14904.misc b/changelog.d/14904.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14904.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 248402532e..13890ce124 100644 --- a/mypy.ini +++ b/mypy.ini @@ -35,8 +35,6 @@ exclude = (?x) |tests/api/test_auth.py |tests/app/test_openid_listener.py |tests/appservice/test_scheduler.py - |tests/events/test_presence_router.py - |tests/events/test_utils.py |tests/federation/test_federation_catch_up.py |tests/federation/test_federation_sender.py |tests/handlers/test_typing.py @@ -86,6 +84,9 @@ disallow_untyped_defs = True [mypy-tests.crypto.*] disallow_untyped_defs = True +[mypy-tests.events.*] +disallow_untyped_defs = True + [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ae57a4df5e..52e4b467e8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -605,10 +605,11 @@ class EventClientSerializer: _PowerLevel = Union[str, int] +PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] def copy_and_fixup_power_levels_contents( - old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] + old_power_levels: PowerLevelsContent, ) -> Dict[str, Union[int, Dict[str, int]]]: """Copy the content of a power_levels event, unfreezing frozendicts along the way. diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index b703e4472e..a9893def74 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -16,6 +16,8 @@ from unittest.mock import Mock import attr +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EduTypes from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router from synapse.federation.units import Transaction @@ -23,11 +25,13 @@ from synapse.handlers.presence import UserPresenceState from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login, presence, room +from synapse.server import HomeServer from synapse.types import JsonDict, StreamToken, create_requester +from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config from tests.test_utils import simple_async_mock -from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config +from tests.unittest import FederatingHomeserverTestCase, override_config @attr.s @@ -49,9 +53,7 @@ class LegacyPresenceRouterTestModule: } return users_to_state - async def get_interested_users( - self, user_id: str - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -71,9 +73,14 @@ class LegacyPresenceRouterTestModule: # Initialise a typed config object config = PresenceRouterTestConfig() - config.users_who_should_receive_all_presence = config_dict.get( + users_who_should_receive_all_presence = config_dict.get( "users_who_should_receive_all_presence" ) + assert isinstance(users_who_should_receive_all_presence, list) + + config.users_who_should_receive_all_presence = ( + users_who_should_receive_all_presence + ) return config @@ -96,9 +103,7 @@ class PresenceRouterTestModule: } return users_to_state - async def get_interested_users( - self, user_id: str - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -118,9 +123,14 @@ class PresenceRouterTestModule: # Initialise a typed config object config = PresenceRouterTestConfig() - config.users_who_should_receive_all_presence = config_dict.get( + users_who_should_receive_all_presence = config_dict.get( "users_who_should_receive_all_presence" ) + assert isinstance(users_who_should_receive_all_presence, list) + + config.users_who_should_receive_all_presence = ( + users_who_should_receive_all_presence + ) return config @@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. fed_transport_client = Mock(spec=["send_transaction"]) fed_transport_client.send_transaction = simple_async_mock({}) @@ -153,7 +163,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): return hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.sync_handler = self.hs.get_sync_handler() self.module_api = homeserver.get_module_api() @@ -176,7 +188,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } ) - def test_receiving_all_presence_legacy(self): + def test_receiving_all_presence_legacy(self) -> None: self.receiving_all_presence_test_body() @override_config( @@ -193,10 +205,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ], } ) - def test_receiving_all_presence(self): + def test_receiving_all_presence(self) -> None: self.receiving_all_presence_test_body() - def receiving_all_presence_test_body(self): + def receiving_all_presence_test_body(self) -> None: """Test that a user that does not share a room with another other can receive presence for them, due to presence routing. """ @@ -302,7 +314,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } ) - def test_send_local_online_presence_to_with_module_legacy(self): + def test_send_local_online_presence_to_with_module_legacy(self) -> None: self.send_local_online_presence_to_with_module_test_body() @override_config( @@ -321,10 +333,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ], } ) - def test_send_local_online_presence_to_with_module(self): + def test_send_local_online_presence_to_with_module(self) -> None: self.send_local_online_presence_to_with_module_test_body() - def send_local_online_presence_to_with_module_test_body(self): + def send_local_online_presence_to_with_module_test_body(self) -> None: """Tests that send_local_presence_to_users sends local online presence to a set of specified local and remote users, with a custom PresenceRouter module enabled. """ @@ -447,18 +459,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): continue # EDUs can contain multiple presence updates - for presence_update in edu["content"]["push"]: + for presence_edu in edu["content"]["push"]: # Check for presence updates that contain the user IDs we're after - found_users.add(presence_update["user_id"]) + found_users.add(presence_edu["user_id"]) # Ensure that no offline states are being sent out - self.assertNotEqual(presence_update["presence"], "offline") + self.assertNotEqual(presence_edu["presence"], "offline") self.assertEqual(found_users, expected_users) def send_presence_update( - testcase: TestCase, + testcase: FederatingHomeserverTestCase, user_id: str, access_token: str, presence_state: str, @@ -479,7 +491,7 @@ def send_presence_update( def sync_presence( - testcase: TestCase, + testcase: FederatingHomeserverTestCase, user_id: str, since_token: Optional[StreamToken] = None, ) -> Tuple[List[UserPresenceState], StreamToken]: @@ -500,7 +512,7 @@ def sync_presence( requester = create_requester(user_id) sync_config = generate_sync_config(requester.user.to_string()) sync_result = testcase.get_success( - testcase.sync_handler.wait_for_sync_for_user( + testcase.hs.get_sync_handler().wait_for_sync_for_user( requester, sync_config, since_token ) ) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 8ddce83b83..6687c28e8f 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils.event_injection import create_event @@ -27,7 +32,7 @@ class TestEventContext(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -35,7 +40,7 @@ class TestEventContext(unittest.HomeserverTestCase): self.user_tok = self.login("u1", "pass") self.room_id = self.helper.create_room_as(tok=self.user_tok) - def test_serialize_deserialize_msg(self): + def test_serialize_deserialize_msg(self) -> None: """Test that an EventContext for a message event is the same after serialize/deserialize. """ @@ -51,7 +56,7 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def test_serialize_deserialize_state_no_prev(self): + def test_serialize_deserialize_state_no_prev(self) -> None: """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ @@ -67,7 +72,7 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def test_serialize_deserialize_state_prev(self): + def test_serialize_deserialize_state_prev(self) -> None: """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. """ @@ -84,7 +89,9 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def _check_serialize_deserialize(self, event, context): + def _check_serialize_deserialize( + self, event: EventBase, context: EventContext + ) -> None: serialized = self.get_success(context.serialize(event, self.store)) d_context = EventContext.deserialize(self._storage_controllers, serialized) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index a79256846f..ff7b349d75 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,21 +13,24 @@ # limitations under the License. import unittest as stdlib_unittest +from typing import Any, List, Mapping, Optional from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import ( + PowerLevelsContent, SerializeEventConfig, copy_and_fixup_power_levels_contents, maybe_upsert_event_field, prune_event, serialize_event, ) +from synapse.types import JsonDict from synapse.util.frozenutils import freeze -def MockEvent(**kwargs): +def MockEvent(**kwargs: Any) -> EventBase: if "event_id" not in kwargs: kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: @@ -60,7 +63,7 @@ class TestMaybeUpsertEventField(stdlib_unittest.TestCase): class PruneEventTestCase(stdlib_unittest.TestCase): - def run_test(self, evdict, matchdict, **kwargs): + def run_test(self, evdict: JsonDict, matchdict: JsonDict, **kwargs: Any) -> None: """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. @@ -74,7 +77,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict ) - def test_minimal(self): + def test_minimal(self) -> None: self.run_test( {"type": "A", "event_id": "$test:domain"}, { @@ -86,7 +89,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_basic_keys(self): + def test_basic_keys(self) -> None: """Ensure that the keys that should be untouched are kept.""" # Note that some of the values below don't really make sense, but the # pruning of events doesn't worry about the values of any fields (with @@ -138,7 +141,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_unsigned(self): + def test_unsigned(self) -> None: """Ensure that unsigned properties get stripped (except age_ts and replaces_state).""" self.run_test( { @@ -159,7 +162,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_content(self): + def test_content(self) -> None: """The content dictionary should be stripped in most cases.""" self.run_test( {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, @@ -194,7 +197,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_create(self): + def test_create(self) -> None: """Create events are partially redacted until MSC2176.""" self.run_test( { @@ -223,7 +226,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_power_levels(self): + def test_power_levels(self) -> None: """Power level events keep a variety of content keys.""" self.run_test( { @@ -273,7 +276,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_alias_event(self): + def test_alias_event(self) -> None: """Alias events have special behavior up through room version 6.""" self.run_test( { @@ -302,7 +305,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.V6, ) - def test_redacts(self): + def test_redacts(self) -> None: """Redaction events have no special behaviour until MSC2174/MSC2176.""" self.run_test( @@ -328,7 +331,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_join_rules(self): + def test_join_rules(self) -> None: """Join rules events have changed behavior starting with MSC3083.""" self.run_test( { @@ -371,7 +374,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.V8, ) - def test_member(self): + def test_member(self) -> None: """Member events have changed behavior starting with MSC3375.""" self.run_test( { @@ -417,12 +420,12 @@ class PruneEventTestCase(stdlib_unittest.TestCase): class SerializeEventTestCase(stdlib_unittest.TestCase): - def serialize(self, ev, fields): + def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict: return serialize_event( ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) ) - def test_event_fields_works_with_keys(self): + def test_event_fields_works_with_keys(self) -> None: self.assertEqual( self.serialize( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] @@ -430,7 +433,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"room_id": "!foo:bar"}, ) - def test_event_fields_works_with_nested_keys(self): + def test_event_fields_works_with_nested_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -443,7 +446,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"body": "A message"}}, ) - def test_event_fields_works_with_dot_keys(self): + def test_event_fields_works_with_dot_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -456,7 +459,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"key.with.dots": {}}}, ) - def test_event_fields_works_with_nested_dot_keys(self): + def test_event_fields_works_with_nested_dot_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -472,7 +475,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"nested.dot.key": {"leaf.key": 42}}}, ) - def test_event_fields_nops_with_unknown_keys(self): + def test_event_fields_nops_with_unknown_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -485,7 +488,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"foo": "bar"}}, ) - def test_event_fields_nops_with_non_dict_keys(self): + def test_event_fields_nops_with_non_dict_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -498,7 +501,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {}, ) - def test_event_fields_nops_with_array_keys(self): + def test_event_fields_nops_with_array_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -511,7 +514,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {}, ) - def test_event_fields_all_fields_if_empty(self): + def test_event_fields_all_fields_if_empty(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -531,16 +534,16 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): }, ) - def test_event_fields_fail_if_fields_not_str(self): + def test_event_fields_fail_if_fields_not_str(self) -> None: with self.assertRaises(TypeError): self.serialize( - MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] + MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] # type: ignore[list-item] ) class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def setUp(self) -> None: - self.test_content = { + self.test_content: PowerLevelsContent = { "ban": 50, "events": {"m.room.name": 100, "m.room.power_levels": 100}, "events_default": 0, @@ -553,10 +556,11 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): "users_default": 0, } - def _test(self, input): + def _test(self, input: PowerLevelsContent) -> None: a = copy_and_fixup_power_levels_contents(input) self.assertEqual(a["ban"], 50) + assert isinstance(a["events"], Mapping) self.assertEqual(a["events"]["m.room.name"], 100) # make sure that changing the copy changes the copy and not the orig @@ -564,18 +568,19 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): a["events"]["m.room.power_levels"] = 20 self.assertEqual(input["ban"], 50) + assert isinstance(input["events"], Mapping) self.assertEqual(input["events"]["m.room.power_levels"], 100) - def test_unfrozen(self): + def test_unfrozen(self) -> None: self._test(self.test_content) - def test_frozen(self): + def test_frozen(self) -> None: input = freeze(self.test_content) self._test(input) - def test_stringy_integers(self): + def test_stringy_integers(self) -> None: """String representations of decimal integers are converted to integers.""" - input = { + input: PowerLevelsContent = { "a": "100", "b": { "foo": 99, @@ -603,9 +608,9 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def test_invalid_types_raise_type_error(self) -> None: with self.assertRaises(TypeError): - copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[arg-type] - copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[arg-type] + copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[dict-item] + copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[dict-item] def test_invalid_nesting_raises_type_error(self) -> None: with self.assertRaises(TypeError): - copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) + copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item] -- cgit 1.5.1 From 7e8d455280b58dbda3ff24b19dbffad2d6c6c253 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 25 Jan 2023 16:34:37 -0500 Subject: Fix a bug in the send_local_online_presence_to module API (#14880) Destination was being used incorrectly (a single destination instead of a list of destinations was being passed). This also updates some of the types in the area to not use Collection[str], which is a footgun. --- changelog.d/14880.bugfix | 1 + synapse/handlers/presence.py | 18 ++++++++++++------ synapse/module_api/__init__.py | 2 +- synapse/notifier.py | 3 ++- synapse/streams/__init__.py | 6 +++--- 5 files changed, 19 insertions(+), 11 deletions(-) create mode 100644 changelog.d/14880.bugfix (limited to 'synapse') diff --git a/changelog.d/14880.bugfix b/changelog.d/14880.bugfix new file mode 100644 index 0000000000..e56c567082 --- /dev/null +++ b/changelog.d/14880.bugfix @@ -0,0 +1 @@ +Fix a bug when using the `send_local_online_presence_to` module API. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 43e4e7b1b4..87af31aa27 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -64,7 +64,13 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + StrCollection, + StreamKeyType, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -320,7 +326,7 @@ class BasePresenceHandler(abc.ABC): for destination, host_states in hosts_to_states.items(): self._federation.send_presence_to_destinations(host_states, [destination]) - async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None: + async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: """ Adds to the list of users who should receive a full snapshot of presence upon their next sync. Note that this only works for local users. @@ -1601,7 +1607,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # Having a default limit doesn't match the EventSource API, but some # callers do not provide it. It is unused in this class. limit: int = 0, - room_ids: Optional[Collection[str]] = None, + room_ids: Optional[StrCollection] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, @@ -1688,7 +1694,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users: Collection[str] + interested_and_updated_users: StrCollection if from_key is not None: # First get all users that have had a presence update @@ -2120,7 +2126,7 @@ class PresenceFederationQueue: # stream_id, destinations, user_ids)`. We don't store the full states # for efficiency, and remote workers will already have the full states # cached. - self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = [] + self._queue: List[Tuple[int, int, StrCollection, Set[str]]] = [] self._next_id = 1 @@ -2142,7 +2148,7 @@ class PresenceFederationQueue: self._queue = self._queue[index:] def send_presence_to_destinations( - self, states: Collection[UserPresenceState], destinations: Collection[str] + self, states: Collection[UserPresenceState], destinations: StrCollection ) -> None: """Send the presence states to the given destinations. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6153a48257..d22dd19d38 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1158,7 +1158,7 @@ class ModuleApi: # Send to remote destinations. destination = UserID.from_string(user).domain presence_handler.get_federation_queue().send_presence_to_destinations( - presence_events, destination + presence_events, [destination] ) def looping_background_call( diff --git a/synapse/notifier.py b/synapse/notifier.py index 2b0e52f23c..a8832a3f8e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -46,6 +46,7 @@ from synapse.types import ( JsonDict, PersistedEventPosition, RoomStreamToken, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -716,7 +717,7 @@ class Notifier: async def _get_room_ids( self, user: UserID, explicit_room_id: Optional[str] - ) -> Tuple[Collection[str], bool]: + ) -> Tuple[StrCollection, bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 2dcd43d0a2..c6c8a0315c 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Generic, List, Optional, Tuple, TypeVar +from typing import Generic, List, Optional, Tuple, TypeVar -from synapse.types import UserID +from synapse.types import StrCollection, UserID # The key, this is either a stream token or int. K = TypeVar("K") @@ -28,7 +28,7 @@ class EventSource(Generic[K, R]): user: UserID, from_key: K, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[R], K]: -- cgit 1.5.1 From cf66d712c615b96bce19e44118cce1ebda41d0b8 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 26 Jan 2023 10:38:49 +0000 Subject: Fix initialization of `_device_list_id_gen` (#14914) On startup, the `_device_list_id_gen` stream id generator is initialized using the maximum stream id seen in a list of tables. When we started populating the `device_list_remote_pending` table in #13913, we forgot to add it to the aforementioned list of tables, so the stream id generator can hand out old stream ids after a restart. The end result is that Synapse can fail to handle device list update EDUs after a restart when a partial state join is in progress. Add the `device_list_remote_pending` table to the list of tables to consider when initializing the `_device_list_id_gen` stream id generator. Signed-off-by: Sean Quah --- changelog.d/14914.bugfix | 1 + synapse/storage/databases/main/devices.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/14914.bugfix (limited to 'synapse') diff --git a/changelog.d/14914.bugfix b/changelog.d/14914.bugfix new file mode 100644 index 0000000000..af73cca70f --- /dev/null +++ b/changelog.d/14914.bugfix @@ -0,0 +1 @@ +Faster joins: Fix a bug introduced in Synapse 1.69 where device list EDUs could fail to be handled after a restart when a faster join sync is in progress. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 903606fb46..e8b6cc6b80 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -99,6 +99,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), ("device_lists_changes_in_room", "stream_id"), + ("device_lists_remote_pending", "stream_id"), ], is_writer=hs.config.worker.worker_app is None, ) -- cgit 1.5.1 From 8a05d5de21888cdd0b53870fead3a1eae64f0b17 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Jan 2023 12:15:36 -0500 Subject: Batch look-ups to see if rooms are partial stated. (#14917) * Batch look-ups to see if rooms are partial stated. * Fix issues found in linting. * Fix typo. * Apply suggestions from code review Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Clarify comments. Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Also improve the cache size while we're at it * is_partial_state_rooms -> is_partial_state_room_batched * Run `black` * Improve annotation for `simple_select_many_batch` * Fix is_partial_state_room_batched impl * Okay, _actually_ fix impl * Update description. * Update synapse/storage/databases/main/room.py Co-authored-by: Patrick Cloke * Run black. Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> Co-authored-by: David Robertson --- changelog.d/14917.misc | 1 + synapse/handlers/sync.py | 24 +++++++++++++++++------- synapse/storage/database.py | 2 +- synapse/storage/databases/main/room.py | 27 ++++++++++++++++++++++++--- 4 files changed, 43 insertions(+), 11 deletions(-) create mode 100644 changelog.d/14917.misc (limited to 'synapse') diff --git a/changelog.d/14917.misc b/changelog.d/14917.misc new file mode 100644 index 0000000000..4d1dd2639a --- /dev/null +++ b/changelog.d/14917.misc @@ -0,0 +1 @@ +Faster joins: Improve performance of looking up partial-state status of rooms. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ee11764567..5ebd3ea855 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1383,16 +1383,21 @@ class SyncHandler: if not sync_config.filter_collection.lazy_load_members(): # Non-lazy syncs should never include partially stated rooms. # Exclude all partially stated rooms from this sync. - for room_id in mutable_joined_room_ids: - if await self.store.is_partial_state_room(room_id): - mutable_rooms_to_exclude.add(room_id) + results = await self.store.is_partial_state_room_batched( + mutable_joined_room_ids + ) + mutable_rooms_to_exclude.update( + room_id + for room_id, is_partial_state in results.items() + if is_partial_state + ) # Incremental eager syncs should additionally include rooms that # - we are joined to # - are full-stated # - became fully-stated at some point during the sync period # (These rooms will have been omitted during a previous eager sync.) - forced_newly_joined_room_ids = set() + forced_newly_joined_room_ids: Set[str] = set() if since_token and not sync_config.filter_collection.lazy_load_members(): un_partial_stated_rooms = ( await self.store.get_un_partial_stated_rooms_between( @@ -1401,9 +1406,14 @@ class SyncHandler: mutable_joined_room_ids, ) ) - for room_id in un_partial_stated_rooms: - if not await self.store.is_partial_state_room(room_id): - forced_newly_joined_room_ids.add(room_id) + results = await self.store.is_partial_state_room_batched( + un_partial_stated_rooms + ) + forced_newly_joined_room_ids.update( + room_id + for room_id, is_partial_state in results.items() + if not is_partial_state + ) # Now we have our list of joined room IDs, exclude as configured and freeze joined_room_ids = frozenset( diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 88479a16db..e20c5c5302 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1819,7 +1819,7 @@ class DatabasePool: keyvalues: Optional[Dict[str, Any]] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, - ) -> List[Any]: + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3aa7b94560..fbbc018887 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -60,9 +60,9 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID +from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID from synapse.util import json_encoder -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.stringutils import MXC_REGEX if TYPE_CHECKING: @@ -1255,7 +1255,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return room_servers - @cached() + @cached(max_entries=10000) async def is_partial_state_room(self, room_id: str) -> bool: """Checks if this room has partial state. @@ -1274,6 +1274,27 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return entry is not None + @cachedList(cached_method_name="is_partial_state_room", list_name="room_ids") + async def is_partial_state_room_batched( + self, room_ids: StrCollection + ) -> Mapping[str, bool]: + """Checks if the given rooms have partial state. + + Returns true for "partial-state" rooms, which means that the state + at events in the room, and `current_state_events`, may not yet be + complete. + """ + + rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch( + table="partial_state_rooms", + column="room_id", + iterable=room_ids, + retcols=("room_id",), + desc="is_partial_state_room_batched", + ) + partial_state_rooms = {row_dict["room_id"] for row_dict in rows} + return {room_id: room_id in partial_state_rooms for room_id in room_ids} + async def get_join_event_id_and_device_lists_stream_id_for_partial_state( self, room_id: str ) -> Tuple[str, int]: -- cgit 1.5.1 From ba79fb4a61784f4b5613da795a61f430af053ca6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Jan 2023 12:31:58 -0500 Subject: Use StrCollection in place of Collection[str] in (most) handlers code. (#14922) Due to the increased safety of StrCollection over Collection[str] and Sequence[str]. --- changelog.d/14922.misc | 1 + synapse/handlers/account_data.py | 6 +++--- synapse/handlers/device.py | 6 +++--- synapse/handlers/event_auth.py | 8 ++++---- synapse/handlers/federation.py | 26 ++++++++------------------ synapse/handlers/federation_event.py | 5 +++-- synapse/handlers/pagination.py | 6 +++--- synapse/handlers/room.py | 14 +++----------- synapse/handlers/room_summary.py | 4 ++-- synapse/handlers/search.py | 8 ++++---- synapse/handlers/sso.py | 9 +++++---- synapse/handlers/sync.py | 4 ++-- synapse/rest/client/push_rule.py | 4 ++-- 13 files changed, 43 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14922.misc (limited to 'synapse') diff --git a/changelog.d/14922.misc b/changelog.d/14922.misc new file mode 100644 index 0000000000..2cc3614dfd --- /dev/null +++ b/changelog.d/14922.misc @@ -0,0 +1 @@ +Use `StrCollection` to avoid potential bugs with `Collection[str]`. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 834006356a..d500b21809 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple from synapse.api.constants import AccountDataTypes from synapse.replication.http.account_data import ( @@ -26,7 +26,7 @@ from synapse.replication.http.account_data import ( ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -322,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): user: UserID, from_key: int, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[JsonDict], int]: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 58180ae2fa..5c06073901 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -18,7 +18,6 @@ from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, - Collection, Dict, Iterable, List, @@ -45,6 +44,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.types import ( JsonDict, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -146,7 +146,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( - self, user_id: str, room_ids: Collection[str], from_token: StreamToken + self, user_id: str, room_ids: StrCollection, from_token: StreamToken ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. @@ -551,7 +551,7 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") async def notify_device_update( - self, user_id: str, device_ids: Collection[str] + self, user_id: str, device_ids: StrCollection ) -> None: """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index f91dbbecb7..a23a8ce2a1 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,7 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.types import StateMap, get_domain_from_id +from synapse.types import StateMap, StrCollection, get_domain_from_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -290,7 +290,7 @@ class EventAuthHandler: async def get_rooms_that_allow_join( self, state_ids: StateMap[str] - ) -> Collection[str]: + ) -> StrCollection: """ Generate a list of rooms in which membership allows access to a room. @@ -331,7 +331,7 @@ class EventAuthHandler: return result - async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool: + async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool: """ Check whether a user is a member of any of the provided rooms. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 233f8c113d..dc1cbf5c3d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -20,17 +20,7 @@ import itertools import logging from enum import Enum from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union import attr from prometheus_client import Histogram @@ -70,7 +60,7 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -179,7 +169,7 @@ class FederationHandler: # A dictionary mapping room IDs to (initial destination, other destinations) # tuples. self._partial_state_syncs_maybe_needing_restart: Dict[ - str, Tuple[Optional[str], Collection[str]] + str, Tuple[Optional[str], StrCollection] ] = {} # A lock guarding the partial state flag for rooms. # When the lock is held for a given room, no other concurrent code may @@ -437,7 +427,7 @@ class FederationHandler: ) ) - async def try_backfill(domains: Collection[str]) -> bool: + async def try_backfill(domains: StrCollection) -> bool: # TODO: Should we try multiple of these at a time? # Number of contacted remote homeservers that have denied our backfill @@ -1730,7 +1720,7 @@ class FederationHandler: def _start_partial_state_room_sync( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, ) -> None: """Starts the background process to resync the state of a partial state room, @@ -1812,7 +1802,7 @@ class FederationHandler: async def _sync_partial_state_room( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, ) -> None: """Background process to resync the state of a partial-state room @@ -1949,9 +1939,9 @@ class FederationHandler: def _prioritise_destinations_for_partial_state_resync( initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, -) -> Collection[str]: +) -> StrCollection: """Work out the order in which we should ask servers to resync events. If an `initial_destination` is given, it takes top priority. Otherwise diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 904a721483..e037acbca2 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -80,6 +80,7 @@ from synapse.types import ( PersistedEventPosition, RoomStreamToken, StateMap, + StrCollection, UserID, get_domain_from_id, ) @@ -615,7 +616,7 @@ class FederationEventHandler: @trace async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Collection[str] + self, dest: str, room_id: str, limit: int, extremities: StrCollection ) -> None: """Trigger a backfill request to `dest` for the given `room_id` @@ -1565,7 +1566,7 @@ class FederationEventHandler: @trace @tag_args async def _get_events_and_persist( - self, destination: str, room_id: str, event_ids: Collection[str] + self, destination: str, room_id: str, event_ids: StrCollection ) -> None: """Fetch the given events from a server, and persist them as outliers. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8c8ff18a1a..1fe6567185 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set import attr @@ -28,7 +28,7 @@ from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string @@ -391,7 +391,7 @@ class PaginationHandler: """ return self._delete_by_id.get(delete_id) - def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]: """Get all active delete ids by room Args: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 572c7b4db3..60a6d9cf3c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,16 +20,7 @@ import random import string from collections import OrderedDict from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Collection, - Dict, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple import attr from typing_extensions import TypedDict @@ -72,6 +63,7 @@ from synapse.types import ( RoomID, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1644,7 +1636,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): user: UserID, from_key: RoomStreamToken, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index c6b869c6f4..4472019fbc 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,7 +36,7 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict, Requester +from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -870,7 +870,7 @@ class _RoomQueueEntry: # The room ID of this entry. room_id: str # The server to query if the room is not known locally. - via: Sequence[str] + via: StrCollection # The minimum number of hops necessary to get to this room (compared to the # originally requested room). depth: int = 0 diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 40f4635c4e..9bbf83047d 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,7 +14,7 @@ import itertools import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from unpaddedbase64 import decode_base64, encode_base64 @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client @@ -418,7 +418,7 @@ class SearchHandler: async def _search_by_rank( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, @@ -491,7 +491,7 @@ class SearchHandler: async def _search_by_recent( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 44e70fc4b8..4a27c0f051 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -20,7 +20,6 @@ from typing import ( Any, Awaitable, Callable, - Collection, Dict, Iterable, List, @@ -47,6 +46,7 @@ from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest from synapse.types import ( JsonDict, + StrCollection, UserID, contains_invalid_mxid_characters, create_requester, @@ -141,7 +141,8 @@ class UserAttributes: confirm_localpart: bool = False display_name: Optional[str] = None picture: Optional[str] = None - emails: Collection[str] = attr.Factory(list) + # mypy thinks these are incompatible for some reason. + emails: StrCollection = attr.Factory(list) # type: ignore[assignment] @attr.s(slots=True, auto_attribs=True) @@ -159,7 +160,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name: Optional[str] - emails: Collection[str] + emails: StrCollection # An optional dictionary of extra attributes to be provided to the client in the # login response. @@ -174,7 +175,7 @@ class UsernameMappingSession: # choices made by the user chosen_localpart: Optional[str] = None use_display_name: bool = True - emails_to_use: Collection[str] = () + emails_to_use: StrCollection = () terms_accepted_version: Optional[str] = None diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ee11764567..9e9601d423 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, AbstractSet, Any, - Collection, Dict, FrozenSet, List, @@ -62,6 +61,7 @@ from synapse.types import ( Requester, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1179,7 +1179,7 @@ class SyncHandler: async def _find_missing_partial_state_memberships( self, room_id: str, - members_to_fetch: Collection[str], + members_to_fetch: StrCollection, events_with_membership_auth: Mapping[str, EventBase], found_state_ids: StateMap[str], ) -> StateMap[str]: diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8191b4e32c..ad5c10c99d 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple, Union from synapse.api.errors import ( NotFoundError, @@ -169,7 +169,7 @@ class PushRuleRestServlet(RestServlet): raise UnrecognizedRequestError() -def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: +def _rule_spec_from_path(path: List[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec Args: -- cgit 1.5.1 From 345576bc349f2c96b273bea246a5bb44c705c6ec Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Jan 2023 13:24:15 -0500 Subject: Fix paginating /relations with a live token (#14866) The `/relations` endpoint was not properly handle "live tokens" (i.e sync tokens), to do this properly we abstract the code that `/messages` has and re-use it. --- changelog.d/14866.bugfix | 1 + synapse/storage/databases/main/relations.py | 38 +++---- synapse/storage/databases/main/stream.py | 154 +++++++++++++++++++--------- 3 files changed, 123 insertions(+), 70 deletions(-) create mode 100644 changelog.d/14866.bugfix (limited to 'synapse') diff --git a/changelog.d/14866.bugfix b/changelog.d/14866.bugfix new file mode 100644 index 0000000000..540f918cbd --- /dev/null +++ b/changelog.d/14866.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.53.0 where `next_batch` tokens from `/sync` could not be used with the `/relations` endpoint. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 84f844b79e..be2242b6ac 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -40,9 +40,13 @@ from synapse.storage.database import ( LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.databases.main.stream import generate_pagination_where_clause +from synapse.storage.databases.main.stream import ( + generate_next_token, + generate_pagination_bounds, + generate_pagination_where_clause, +) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken +from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -207,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore): where_clause.append("type = ?") where_args.append(event_type) + order, from_bound, to_bound = generate_pagination_bounds( + direction, + from_token.room_key if from_token else None, + to_token.room_key if to_token else None, + ) + pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.room_key.as_historical_tuple() - if from_token - else None, - to_token=to_token.room_key.as_historical_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) if pagination_clause: where_clause.append(pagination_clause) - if direction == "b": - order = "DESC" - else: - order = "ASC" - sql = """ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering FROM event_relations @@ -266,16 +269,9 @@ class RelationsWorkerStore(SQLBaseStore): topo_orderings = topo_orderings[:limit] stream_orderings = stream_orderings[:limit] - topo = topo_orderings[-1] - token = stream_orderings[-1] - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_key = RoomStreamToken(topo, token) + next_key = generate_next_token( + direction, topo_orderings[-1], stream_orderings[-1] + ) if from_token: next_token = from_token.copy_and_replace( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index d28fc65df9..8977bf33e7 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -170,6 +170,104 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) +def generate_pagination_bounds( + direction: str, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], +) -> Tuple[ + str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]] +]: + """ + Generate a start and end point for this page of events. + + Args: + direction: Whether pagination is going forwards or backwards. One of "f" or "b". + from_token: The token to start pagination at, or None to start at the first value. + to_token: The token to end pagination at, or None to not limit the end point. + + Returns: + A three tuple of: + + ASC or DESC for sorting of the query. + + The starting position as a tuple of ints representing + (topological position, stream position) or None if no from_token was + provided. The topological position may be None for live tokens. + + The end position in the same format as the starting position, or None + if no to_token was provided. + """ + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + if direction == "b": + order = "DESC" + else: + order = "ASC" + + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + from_bound: Optional[Tuple[Optional[int], int]] = None + if from_token: + if from_token.topological is not None: + from_bound = from_token.as_historical_tuple() + elif direction == "b": + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound: Optional[Tuple[Optional[int], int]] = None + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == "b": + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + + return order, from_bound, to_bound + + +def generate_next_token( + direction: str, last_topo_ordering: int, last_stream_ordering: int +) -> RoomStreamToken: + """ + Generate the next room stream token based on the currently returned data. + + Args: + direction: Whether pagination is going forwards or backwards. One of "f" or "b". + last_topo_ordering: The last topological ordering being returned. + last_stream_ordering: The last stream ordering being returned. + + Returns: + A new RoomStreamToken to return to the client. + """ + if direction == "b": + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + last_stream_ordering -= 1 + return RoomStreamToken(last_topo_ordering, last_stream_ordering) + + def _make_generic_sql_bound( bound: str, column_names: Tuple[str, str], @@ -1300,47 +1398,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] - if direction == "b": - order = "DESC" - else: - order = "ASC" - - # The bounds for the stream tokens are complicated by the fact - # that we need to handle the instance_map part of the tokens. We do this - # by fetching all events between the min stream token and the maximum - # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and - # then filtering the results. - if from_token.topological is not None: - from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple() - elif direction == "b": - from_bound = ( - None, - from_token.get_max_stream_pos(), - ) - else: - from_bound = ( - None, - from_token.stream, - ) - to_bound: Optional[Tuple[Optional[int], int]] = None - if to_token: - if to_token.topological is not None: - to_bound = to_token.as_historical_tuple() - elif direction == "b": - to_bound = ( - None, - to_token.stream, - ) - else: - to_bound = ( - None, - to_token.get_max_stream_pos(), - ) + order, from_bound, to_bound = generate_pagination_bounds( + direction, from_token, to_token + ) bounds = generate_pagination_where_clause( direction=direction, @@ -1436,16 +1498,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ][:limit] if rows: - topo = rows[-1].topological_ordering - token = rows[-1].stream_ordering - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_token = RoomStreamToken(topo, token) + assert rows[-1].topological_ordering is not None + next_token = generate_next_token( + direction, rows[-1].topological_ordering, rows[-1].stream_ordering + ) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token -- cgit 1.5.1 From fc35e0673f5b46ea0f5e53ef15626b14a452ca82 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Jan 2023 14:45:24 -0500 Subject: Add missing type hints in tests (#14879) * FIx-up type hints in tests.logging. * Add missing type hints to test_transactions. --- changelog.d/14879.misc | 1 + mypy.ini | 6 ++--- synapse/rest/client/transactions.py | 3 ++- tests/logging/__init__.py | 6 +++-- tests/logging/test_opentracing.py | 4 ++-- tests/logging/test_remote_handler.py | 25 +++++++++++++------- tests/logging/test_terse_json.py | 30 ++++++++++++++---------- tests/rest/client/test_transactions.py | 42 ++++++++++++++++++++++------------ 8 files changed, 75 insertions(+), 42 deletions(-) create mode 100644 changelog.d/14879.misc (limited to 'synapse') diff --git a/changelog.d/14879.misc b/changelog.d/14879.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14879.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index e57bc64261..978d92940b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -40,10 +40,7 @@ exclude = (?x) |tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_srv_resolver.py |tests/http/test_proxyagent.py - |tests/logging/__init__.py - |tests/logging/test_terse_json.py |tests/module_api/test_api.py - |tests/rest/client/test_transactions.py |tests/rest/media/v1/test_media_storage.py |tests/server.py |tests/test_state.py @@ -92,6 +89,9 @@ disallow_untyped_defs = True [mypy-tests.handlers.*] disallow_untyped_defs = True +[mypy-tests.logging.*] +disallow_untyped_defs = True + [mypy-tests.metrics.*] disallow_untyped_defs = True diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 61375651bc..3f40f1874a 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple from typing_extensions import ParamSpec +from twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.web.server import Request @@ -90,7 +91,7 @@ class HttpTransactionCache: fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], *args: P.args, **kwargs: P.kwargs, - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> "Deferred[Tuple[int, JsonDict]]": """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py index 1acf5666a8..1c5de95a80 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py @@ -13,9 +13,11 @@ # limitations under the License. import logging +from tests.unittest import TestCase -class LoggerCleanupMixin: - def get_logger(self, handler): + +class LoggerCleanupMixin(TestCase): + def get_logger(self, handler: logging.Handler) -> logging.Logger: """ Attach a handler to a logger and add clean-ups to remove revert this. """ diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index 0917e478a5..e28ba84cc2 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -153,7 +153,7 @@ class LogContextScopeManagerTestCase(TestCase): scopes = [] - async def task(i: int): + async def task(i: int) -> None: scope = start_active_span( f"task{i}", tracer=self._tracer, @@ -165,7 +165,7 @@ class LogContextScopeManagerTestCase(TestCase): self.assertEqual(self._tracer.active_span, scope.span) scope.close() - async def root(): + async def root() -> None: with start_active_span("root span", tracer=self._tracer) as root_scope: self.assertEqual(self._tracer.active_span, root_scope.span) scopes.append(root_scope) diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py index b0d046fe00..c08954d887 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.test.proto_helpers import AccumulatingProtocol +from typing import Tuple + +from twisted.internet.protocol import Protocol +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from synapse.logging import RemoteHandler @@ -20,7 +23,9 @@ from tests.server import FakeTransport, get_clock from tests.unittest import TestCase -def connect_logging_client(reactor, client_id): +def connect_logging_client( + reactor: MemoryReactorClock, client_id: int +) -> Tuple[Protocol, AccumulatingProtocol]: # This is essentially tests.server.connect_client, but disabling autoflush on # the client transport. This is necessary to avoid an infinite loop due to # sending of data via the logging transport causing additional logs to be @@ -35,10 +40,10 @@ def connect_logging_client(reactor, client_id): class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, _ = get_clock() - def test_log_output(self): + def test_log_output(self) -> None: """ The remote handler delivers logs over TCP. """ @@ -51,6 +56,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): client, server = connect_logging_client(self.reactor, 0) # Trigger data being sent + assert isinstance(client.transport, FakeTransport) client.transport.flush() # One log message, with a single trailing newline @@ -61,7 +67,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Ensure the data passed through properly. self.assertEqual(logs[0], "Hello there, wally!") - def test_log_backpressure_debug(self): + def test_log_backpressure_debug(self) -> None: """ When backpressure is hit, DEBUG logs will be shed. """ @@ -83,6 +89,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # Only the 7 infos made it through, the debugs were elided @@ -90,7 +97,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(len(logs), 7) self.assertNotIn(b"debug", server.data) - def test_log_backpressure_info(self): + def test_log_backpressure_info(self) -> None: """ When backpressure is hit, DEBUG and INFO logs will be shed. """ @@ -116,6 +123,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # The 10 warnings made it through, the debugs and infos were elided @@ -124,7 +132,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): self.assertNotIn(b"debug", server.data) self.assertNotIn(b"info", server.data) - def test_log_backpressure_cut_middle(self): + def test_log_backpressure_cut_middle(self) -> None: """ When backpressure is hit, and no more DEBUG and INFOs cannot be culled, it will cut the middle messages out. @@ -140,6 +148,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # The first five and last five warnings made it through, the debugs and @@ -151,7 +160,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): logs, ) - def test_cancel_connection(self): + def test_cancel_connection(self) -> None: """ Gracefully handle the connection being cancelled. """ diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 0b0d8737c1..fa27f1279a 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -14,24 +14,28 @@ import json import logging from io import BytesIO, StringIO +from typing import cast from unittest.mock import Mock, patch +from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.http.site import SynapseRequest from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter from synapse.logging.context import LoggingContext, LoggingContextFilter +from synapse.types import JsonDict from tests.logging import LoggerCleanupMixin -from tests.server import FakeChannel +from tests.server import FakeChannel, get_clock from tests.unittest import TestCase class TerseJsonTestCase(LoggerCleanupMixin, TestCase): - def setUp(self): + def setUp(self) -> None: self.output = StringIO() + self.reactor, _ = get_clock() - def get_log_line(self): + def get_log_line(self) -> JsonDict: # One log message, with a single trailing newline. data = self.output.getvalue() logs = data.splitlines() @@ -39,7 +43,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(data.count("\n"), 1) return json.loads(logs[0]) - def test_terse_json_output(self): + def test_terse_json_output(self) -> None: """ The Terse JSON formatter converts log messages to JSON. """ @@ -61,7 +65,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - def test_extra_data(self): + def test_extra_data(self) -> None: """ Additional information can be included in the structured logging. """ @@ -93,7 +97,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["int"], 3) self.assertIs(log["bool"], True) - def test_json_output(self): + def test_json_output(self) -> None: """ The Terse JSON formatter converts log messages to JSON. """ @@ -114,7 +118,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - def test_with_context(self): + def test_with_context(self) -> None: """ The logging context should be added to the JSON response. """ @@ -139,7 +143,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["request"], "name") - def test_with_request_context(self): + def test_with_request_context(self) -> None: """ Information from the logging context request should be added to the JSON response. """ @@ -154,11 +158,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.server_version_string = "Server v1" site.reactor = Mock() site.experimental_cors_msc3886 = False - request = SynapseRequest(FakeChannel(site, None), site) + request = SynapseRequest( + cast(HTTPChannel, FakeChannel(site, self.reactor)), site + ) # Call requestReceived to finish instantiating the object. request.content = BytesIO() - # Partially skip some of the internal processing of SynapseRequest. - request._started_processing = Mock() + # Partially skip some internal processing of SynapseRequest. + request._started_processing = Mock() # type: ignore[assignment] request.request_metrics = Mock(spec=["name"]) with patch.object(Request, "render"): request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") @@ -200,7 +206,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["protocol"], "1.1") self.assertEqual(log["user_agent"], "") - def test_with_exception(self): + def test_with_exception(self) -> None: """ The logging exception type & value should be added to the JSON response. """ diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 21a1ca2a68..3086e1b565 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -13,18 +13,22 @@ # limitations under the License. from http import HTTPStatus +from typing import Any, Generator, Tuple, cast from unittest.mock import Mock, call -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as _reactor from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache +from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable from tests.utils import MockClock +reactor = cast(ISynapseReactor, _reactor) + class HttpTransactionCacheTestCase(unittest.TestCase): def setUp(self) -> None: @@ -34,11 +38,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.hs.get_auth = Mock() self.cache = HttpTransactionCache(self.hs) - self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!") + self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"}) self.mock_key = "foo" @defer.inlineCallbacks - def test_executes_given_function(self): + def test_executes_given_function( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" @@ -47,7 +53,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertEqual(res, self.mock_http_response) @defer.inlineCallbacks - def test_deduplicates_based_on_key(self): + def test_deduplicates_based_on_key( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( @@ -58,18 +66,20 @@ class HttpTransactionCacheTestCase(unittest.TestCase): cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0) @defer.inlineCallbacks - def test_logcontexts_with_async_result(self): + def test_logcontexts_with_async_result( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: @defer.inlineCallbacks - def cb(): + def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]: yield Clock(reactor).sleep(0) - return "yay" + return 1, {} @defer.inlineCallbacks - def test(): + def test() -> Generator["defer.Deferred[Any]", object, None]: with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertIs(current_context(), c1) - self.assertEqual(res, "yay") + self.assertEqual(res, (1, {})) # run the test twice in parallel d = defer.gatherResults([test(), test()]) @@ -78,13 +88,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks - def test_does_not_cache_exceptions(self): + def test_does_not_cache_exceptions( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: """Checks that, if the callback throws an exception, it is called again for the next request. """ called = [False] - def cb(): + def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": if called[0]: # return a valid result the second time return defer.succeed(self.mock_http_response) @@ -104,13 +116,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), test_context) @defer.inlineCallbacks - def test_does_not_cache_failures(self): + def test_does_not_cache_failures( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: """Checks that, if the callback returns a failure, it is called again for the next request. """ called = [False] - def cb(): + def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": if called[0]: # return a valid result the second time return defer.succeed(self.mock_http_response) @@ -130,7 +144,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), test_context) @defer.inlineCallbacks - def test_cleans_up(self): + def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet -- cgit 1.5.1 From 265735db9d7b0698a511fc9389db4d6f104f1aa8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 Jan 2023 07:27:55 -0500 Subject: Use an enum for direction. (#14927) For better type safety we use an enum instead of strings to configure direction (backwards or forwards). --- changelog.d/14927.misc | 1 + synapse/api/constants.py | 7 ++++ synapse/handlers/admin.py | 4 +- synapse/handlers/initial_sync.py | 16 +++++++- synapse/handlers/pagination.py | 6 +-- synapse/handlers/relations.py | 8 +++- synapse/storage/databases/main/relations.py | 8 ++-- synapse/storage/databases/main/stream.py | 59 +++++++++++++++-------------- synapse/streams/config.py | 11 ++++-- 9 files changed, 76 insertions(+), 44 deletions(-) create mode 100644 changelog.d/14927.misc (limited to 'synapse') diff --git a/changelog.d/14927.misc b/changelog.d/14927.misc new file mode 100644 index 0000000000..9f5384e60e --- /dev/null +++ b/changelog.d/14927.misc @@ -0,0 +1 @@ +Add missing type hints. \ No newline at end of file diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6432d32d83..6f9239d21c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,8 @@ """Contains constants from the specification.""" +import enum + from typing_extensions import Final # the max size of a (canonical-json-encoded) event @@ -290,3 +292,8 @@ class ApprovalNoticeMedium: NONE = "org.matrix.msc3866.none" EMAIL = "org.matrix.msc3866.email" + + +class Direction(enum.Enum): + BACKWARDS = "b" + FORWARDS = "f" diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5bf8e86387..c81ea34758 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -16,7 +16,7 @@ import abc import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set -from synapse.api.constants import Membership +from synapse.api.constants import Direction, Membership from synapse.events import EventBase from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client @@ -197,7 +197,7 @@ class AdminHandler: # efficient method perhaps but it does guarantee we get everything. while True: events, _ = await self.store.paginate_room_events( - room_id, from_key, to_key, limit=100, direction="f" + room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS ) if not events: break diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 8c2260ad7d..191529bd8e 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -15,7 +15,13 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast -from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + Direction, + EduTypes, + EventTypes, + Membership, +) from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig @@ -57,7 +63,13 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool + str, + Optional[StreamToken], + Optional[StreamToken], + Direction, + int, + bool, + bool, ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1fe6567185..ceefa16b49 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -19,7 +19,7 @@ import attr from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig @@ -448,7 +448,7 @@ class PaginationHandler: if pagin_config.from_token: from_token = pagin_config.from_token - elif pagin_config.direction == "f": + elif pagin_config.direction == Direction.FORWARDS: from_token = ( await self.hs.get_event_sources().get_start_token_for_pagination( room_id @@ -476,7 +476,7 @@ class PaginationHandler: room_id, requester, allow_departed_users=True ) - if pagin_config.direction == "b": + if pagin_config.direction == Direction.BACKWARDS: # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index e96f9999a8..0fb15391e0 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O import attr -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -413,7 +413,11 @@ class RelationsHandler: # Attempt to find another event to use as the latest event. potential_events, _ = await self._main_store.get_relations_for_event( - event_id, event, room_id, RelationTypes.THREAD, direction="f" + event_id, + event, + room_id, + RelationTypes.THREAD, + direction=Direction.FORWARDS, ) # Filter out ignored users. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index be2242b6ac..0018d6f7ab 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -30,7 +30,7 @@ from typing import ( import attr -from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -168,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Optional[str] = None, event_type: Optional[str] = None, limit: int = 5, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: @@ -181,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. limit: Only fetch the most recent `limit` events. - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`). + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 8977bf33e7..818c46182e 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -55,6 +55,7 @@ from typing_extensions import Literal from twisted.internet import defer +from synapse.api.constants import Direction from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" - # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: @@ -104,7 +104,7 @@ class _EventsAround: def generate_pagination_where_clause( - direction: str, + direction: Direction, column_names: Tuple[str, str], from_token: Optional[Tuple[Optional[int], int]], to_token: Optional[Tuple[Optional[int], int]], @@ -130,27 +130,26 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction: Whether we're paginating backwards("b") or forwards ("f"). + direction: Whether we're paginating backwards or forwards. column_names: The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. from_token: The start point for the pagination. This is an exclusive - minimum bound if direction is "f", and an inclusive maximum bound if - direction is "b". + minimum bound if direction is forwards, and an inclusive maximum bound if + direction is backwards. to_token: The endpoint point for the pagination. This is an inclusive - maximum bound if direction is "f", and an exclusive minimum bound if - direction is "b". + maximum bound if direction is forwards, and an exclusive minimum bound if + direction is backwards. engine: The database engine to generate the clauses for Returns: The sql expression """ - assert direction in ("b", "f") where_clause = [] if from_token: where_clause.append( _make_generic_sql_bound( - bound=">=" if direction == "b" else "<", + bound=">=" if direction == Direction.BACKWARDS else "<", column_names=column_names, values=from_token, engine=engine, @@ -160,7 +159,7 @@ def generate_pagination_where_clause( if to_token: where_clause.append( _make_generic_sql_bound( - bound="<" if direction == "b" else ">=", + bound="<" if direction == Direction.BACKWARDS else ">=", column_names=column_names, values=to_token, engine=engine, @@ -171,7 +170,7 @@ def generate_pagination_where_clause( def generate_pagination_bounds( - direction: str, + direction: Direction, from_token: Optional[RoomStreamToken], to_token: Optional[RoomStreamToken], ) -> Tuple[ @@ -181,7 +180,7 @@ def generate_pagination_bounds( Generate a start and end point for this page of events. Args: - direction: Whether pagination is going forwards or backwards. One of "f" or "b". + direction: Whether pagination is going forwards or backwards. from_token: The token to start pagination at, or None to start at the first value. to_token: The token to end pagination at, or None to not limit the end point. @@ -201,7 +200,7 @@ def generate_pagination_bounds( # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" @@ -215,7 +214,7 @@ def generate_pagination_bounds( if from_token: if from_token.topological is not None: from_bound = from_token.as_historical_tuple() - elif direction == "b": + elif direction == Direction.BACKWARDS: from_bound = ( None, from_token.get_max_stream_pos(), @@ -230,7 +229,7 @@ def generate_pagination_bounds( if to_token: if to_token.topological is not None: to_bound = to_token.as_historical_tuple() - elif direction == "b": + elif direction == Direction.BACKWARDS: to_bound = ( None, to_token.stream, @@ -245,20 +244,20 @@ def generate_pagination_bounds( def generate_next_token( - direction: str, last_topo_ordering: int, last_stream_ordering: int + direction: Direction, last_topo_ordering: int, last_stream_ordering: int ) -> RoomStreamToken: """ Generate the next room stream token based on the currently returned data. Args: - direction: Whether pagination is going forwards or backwards. One of "f" or "b". + direction: Whether pagination is going forwards or backwards. last_topo_ordering: The last topological ordering being returned. last_stream_ordering: The last stream ordering being returned. Returns: A new RoomStreamToken to return to the client. """ - if direction == "b": + if direction == Direction.BACKWARDS: # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk @@ -1201,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, before_token, - direction="b", + direction=Direction.BACKWARDS, limit=before_limit, event_filter=event_filter, ) @@ -1211,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, after_token, - direction="f", + direction=Direction.FORWARDS, limit=after_limit, event_filter=event_filter, ) @@ -1374,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: @@ -1385,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_token: The token used to stream from to_token: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. @@ -1489,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - lower_token=to_token if direction == "b" else from_token, - upper_token=from_token if direction == "b" else to_token, + lower_token=to_token + if direction == Direction.BACKWARDS + else from_token, + upper_token=from_token + if direction == Direction.BACKWARDS + else to_token, instance_name=instance_name, topological_ordering=topological_ordering, stream_ordering=stream_ordering, @@ -1514,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_key: RoomStreamToken, to_key: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: @@ -1524,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_key: The token used to stream from to_key: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 6df2de919c..5cb7875181 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -16,6 +16,7 @@ from typing import Optional import attr +from synapse.api.constants import Direction from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -34,7 +35,7 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] - direction: str + direction: Direction limit: int @classmethod @@ -45,9 +46,13 @@ class PaginationConfig: default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": - direction = parse_string( - request, "dir", default=default_dir, allowed_values=["f", "b"] + direction_str = parse_string( + request, + "dir", + default=default_dir, + allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value], ) + direction = Direction(direction_str) from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") -- cgit 1.5.1 From 2a51f3ec36abeb1f5c1db795541988d1d9698e41 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 Jan 2023 10:16:21 -0500 Subject: Implement MSC3952: Intentional mentions (#14823) MSC3952 defines push rules which searches for mentions in a list of Matrix IDs in the event body, instead of searching the entire event body for display name / local part. This is implemented behind an experimental configuration flag and does not yet implement the backwards compatibility pieces of the MSC. --- changelog.d/14823.feature | 1 + rust/src/push/base_rules.rs | 21 +++++++ rust/src/push/evaluator.rs | 25 +++++++- rust/src/push/mod.rs | 34 +++++++++++ stubs/synapse/synapse_rust/push.pyi | 5 +- synapse/api/constants.py | 3 + synapse/config/experimental.py | 5 ++ synapse/push/bulk_push_rule_evaluator.py | 25 +++++++- synapse/storage/databases/main/push_rule.py | 1 + tests/push/test_bulk_push_rule_evaluator.py | 88 +++++++++++++++++++++++++++++ tests/push/test_push_rule_evaluator.py | 66 +++++++++++++++++++--- 11 files changed, 263 insertions(+), 11 deletions(-) create mode 100644 changelog.d/14823.feature (limited to 'synapse') diff --git a/changelog.d/14823.feature b/changelog.d/14823.feature new file mode 100644 index 0000000000..8293e99eff --- /dev/null +++ b/changelog.d/14823.feature @@ -0,0 +1 @@ +Experimental support for [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952): intentional mentions. diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 9140a69bb6..880eed0ef4 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -131,6 +131,14 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed(".org.matrix.msc3952.is_user_mentioned"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::IsUserMention)]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, @@ -139,6 +147,19 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mentioned"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::IsRoomMention), + Condition::Known(KnownCondition::SenderNotificationPermission { + key: Cow::Borrowed("room"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.roomnotif"), priority_class: 5, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 0242ee1c5f..aa71202e43 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -68,6 +68,11 @@ pub struct PushRuleEvaluator { /// The "content.body", if any. body: String, + /// The user mentions that were part of the message. + user_mentions: BTreeSet, + /// True if the message is a room message. + room_mention: bool, + /// The number of users in the room. room_member_count: u64, @@ -100,6 +105,8 @@ impl PushRuleEvaluator { #[new] pub fn py_new( flattened_keys: BTreeMap, + user_mentions: BTreeSet, + room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, @@ -116,6 +123,8 @@ impl PushRuleEvaluator { Ok(PushRuleEvaluator { flattened_keys, body, + user_mentions, + room_mention, room_member_count, notification_power_levels, sender_power_level, @@ -229,6 +238,14 @@ impl PushRuleEvaluator { KnownCondition::RelatedEventMatch(event_match) => { self.match_related_event_match(event_match, user_id)? } + KnownCondition::IsUserMention => { + if let Some(uid) = user_id { + self.user_mentions.contains(uid) + } else { + false + } + } + KnownCondition::IsRoomMention => self.room_mention, KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -424,6 +441,8 @@ fn push_rule_evaluator() { flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); let evaluator = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), + false, 10, Some(0), BTreeMap::new(), @@ -449,6 +468,8 @@ fn test_requires_room_version_supports_condition() { let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), + false, 10, Some(0), BTreeMap::new(), @@ -483,7 +504,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 842b13c88b..7e449f2433 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -269,6 +269,10 @@ pub enum KnownCondition { EventMatch(EventMatchCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), + #[serde(rename = "org.matrix.msc3952.is_user_mention")] + IsUserMention, + #[serde(rename = "org.matrix.msc3952.is_room_mention")] + IsRoomMention, ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -414,6 +418,7 @@ pub struct FilteredPushRules { msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3952_intentional_mentions: bool, } #[pymethods] @@ -425,6 +430,7 @@ impl FilteredPushRules { msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3952_intentional_mentions: bool, ) -> Self { Self { push_rules, @@ -432,6 +438,7 @@ impl FilteredPushRules { msc1767_enabled, msc3381_polls_enabled, msc3664_enabled, + msc3952_intentional_mentions, } } @@ -465,6 +472,11 @@ impl FilteredPushRules { return false; } + if !self.msc3952_intentional_mentions && rule.rule_id.contains("org.matrix.msc3952") + { + return false; + } + true }) .map(|r| { @@ -522,6 +534,28 @@ fn test_deserialize_unstable_msc3931_condition() { )); } +#[test] +fn test_deserialize_unstable_msc3952_user_condition() { + let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::IsUserMention) + )); +} + +#[test] +fn test_deserialize_unstable_msc3952_room_condition() { + let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::IsRoomMention) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 304ed7111c..588d90c25a 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union from synapse.types import JsonDict @@ -46,6 +46,7 @@ class FilteredPushRules: msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3952_intentional_mentions: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... @@ -55,6 +56,8 @@ class PushRuleEvaluator: def __init__( self, flattened_keys: Mapping[str, str], + user_mentions: Set[str], + room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6f9239d21c..0f224b34cd 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -233,6 +233,9 @@ class EventContentFields: # The authorising user for joining a restricted room. AUTHORISING_USER: Final = "join_authorised_via_users_server" + # Use for mentioning users. + MSC3952_MENTIONS: Final = "org.matrix.msc3952.mentions" + # an unspecced field added to to-device messages to identify them uniquely-ish TO_DEVICE_MSGID: Final = "org.matrix.msgid" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 2590c88cde..d2d0270ddd 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -168,3 +168,8 @@ class ExperimentalConfig(Config): # MSC3925: do not replace events with their edits self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) + + # MSC3952: Intentional mentions + self.msc3952_intentional_mentions = experimental.get( + "msc3952_intentional_mentions", False + ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f27ba64d53..deaec19564 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -22,13 +22,20 @@ from typing import ( List, Mapping, Optional, + Set, Tuple, Union, ) from prometheus_client import Counter -from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes +from synapse.api.constants import ( + MAIN_TIMELINE, + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event @@ -342,8 +349,24 @@ class BulkPushRuleEvaluator: for user_id, level in notification_levels.items(): notification_levels[user_id] = int(level) + # Pull out any user and room mentions. + mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) + user_mentions: Set[str] = set() + room_mention = False + if isinstance(mentions, dict): + # Remove out any non-string items and convert to a set. + user_mentions_raw = mentions.get("user_ids") + if isinstance(user_mentions_raw, list): + user_mentions = set( + filter(lambda item: isinstance(item, str), user_mentions_raw) + ) + # Room mention is only true if the value is exactly true. + room_mention = mentions.get("room") is True + evaluator = PushRuleEvaluator( _flatten_dict(event, room_version=event.room_version), + user_mentions, + room_mention, room_member_count, sender_power_level, notification_levels, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 14ca167b34..466a1145b7 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -89,6 +89,7 @@ def _load_rules( msc1767_enabled=experimental_config.msc1767_enabled, msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, + msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions, ) return filtered_rules diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 9c17a42b65..aba62b5dc8 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.rest import admin @@ -126,3 +128,89 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Ensure no actions are generated! self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) bulk_evaluator._action_for_event_by_user.assert_not_called() + + @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + def test_mentions(self) -> None: + """Test the behavior of an event which includes invalid mentions.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + + sentinel = object() + + def create_and_process(mentions: Any = sentinel) -> bool: + """Returns true iff the `mentions` trigger an event push action.""" + content = {} + if mentions is not sentinel: + content[EventContentFields.MSC3952_MENTIONS] = mentions + + # Create a new message event which should cause a notification. + event, context = self.get_success( + self.event_creation_handler.create_event( + self.requester, + { + "type": "test", + "room_id": self.room_id, + "content": content, + "sender": f"@bob:{self.hs.hostname}", + }, + ) + ) + + # Ensure no actions are generated! + self.get_success( + bulk_evaluator.action_for_events_by_user([(event, context)]) + ) + + # If any actions are generated for this event, return true. + result = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_list( + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcols=("*",), + desc="get_event_push_actions_staging", + ) + ) + return len(result) > 0 + + # Not including the mentions field should not notify. + self.assertFalse(create_and_process()) + # An empty mentions field should not notify. + self.assertFalse(create_and_process({})) + + # Non-dict mentions should be ignored. + mentions: Any + for mentions in (None, True, False, 1, "foo", []): + self.assertFalse(create_and_process(mentions)) + + # A non-list should be ignored. + for mentions in (None, True, False, 1, "foo", {}): + self.assertFalse(create_and_process({"user_ids": mentions})) + + # The Matrix ID appearing anywhere in the list should notify. + self.assertTrue(create_and_process({"user_ids": [self.alice]})) + self.assertTrue(create_and_process({"user_ids": ["@another:test", self.alice]})) + + # Duplicate user IDs should notify. + self.assertTrue(create_and_process({"user_ids": [self.alice, self.alice]})) + + # Invalid entries in the list are ignored. + self.assertFalse(create_and_process({"user_ids": [None, True, False, {}, []]})) + self.assertTrue( + create_and_process({"user_ids": [None, True, False, {}, [], self.alice]}) + ) + + # Room mentions from those without power should not notify. + self.assertFalse(create_and_process({"room": True})) + + # Room mentions from those with power should notify. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"notifications": {"room": 0}}, + self.token, + state_key="", + ) + self.assertTrue(create_and_process({"room": True})) + + # Invalid data should not notify. + for mentions in (None, False, 1, "foo", [], {}): + self.assertFalse(create_and_process({"room": mentions})) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 1b87756b75..9d01c989d4 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union, cast +from typing import Dict, List, Optional, Set, Union, cast import frozendict @@ -39,7 +39,12 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): def _get_evaluator( - self, content: JsonMapping, related_events: Optional[JsonDict] = None + self, + content: JsonMapping, + *, + user_mentions: Optional[Set[str]] = None, + room_mention: bool = False, + related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( { @@ -57,13 +62,15 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluator( _flatten_dict(event), + user_mentions or set(), + room_mention, room_member_count, sender_power_level, cast(Dict[str, int], power_levels.get("notifications", {})), {} if related_events is None else related_events, - True, - event.room_version.msc3931_push_features, - True, + related_event_match_enabled=True, + room_version_feature_flags=event.room_version.msc3931_push_features, + msc3931_enabled=True, ) def test_display_name(self) -> None: @@ -90,6 +97,51 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): # A display name with spaces should work fine. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + def test_user_mentions(self) -> None: + """Check for user mentions.""" + condition = {"kind": "org.matrix.msc3952.is_user_mention"} + + # No mentions shouldn't match. + evaluator = self._get_evaluator({}) + self.assertFalse(evaluator.matches(condition, "@user:test", None)) + + # An empty set shouldn't match + evaluator = self._get_evaluator({}, user_mentions=set()) + self.assertFalse(evaluator.matches(condition, "@user:test", None)) + + # The Matrix ID appearing anywhere in the mentions list should match + evaluator = self._get_evaluator({}, user_mentions={"@user:test"}) + self.assertTrue(evaluator.matches(condition, "@user:test", None)) + + evaluator = self._get_evaluator( + {}, user_mentions={"@another:test", "@user:test"} + ) + self.assertTrue(evaluator.matches(condition, "@user:test", None)) + + # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions + # since the BulkPushRuleEvaluator is what handles data sanitisation. + + def test_room_mentions(self) -> None: + """Check for room mentions.""" + condition = {"kind": "org.matrix.msc3952.is_room_mention"} + + # No room mention shouldn't match. + evaluator = self._get_evaluator({}) + self.assertFalse(evaluator.matches(condition, None, None)) + + # Room mention should match. + evaluator = self._get_evaluator({}, room_mention=True) + self.assertTrue(evaluator.matches(condition, None, None)) + + # A room mention and user mention is valid. + evaluator = self._get_evaluator( + {}, user_mentions={"@another:test"}, room_mention=True + ) + self.assertTrue(evaluator.matches(condition, None, None)) + + # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions + # since the BulkPushRuleEvaluator is what handles data sanitisation. + def _assert_matches( self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None ) -> None: @@ -308,7 +360,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): }, } }, - { + related_events={ "m.in_reply_to": { "event_id": "$parent_event_id", "type": "m.room.message", @@ -408,7 +460,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): }, } }, - { + related_events={ "m.in_reply_to": { "event_id": "$parent_event_id", "type": "m.room.message", -- cgit 1.5.1 From 510d4b06e7d346b4f94cb5598da90c9f668b62bb Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 Jan 2023 21:29:30 +0000 Subject: Handle malformed values of `notification.room` in power level events (#14942) * Better test for bad values in power levels events The previous test only checked that Synapse didn't raise an exception, but didn't check that we had correctly interpreted the value of the dodgy power level. It also conflated two things: bad room notification levels, and bad user levels. There _is_ logic for converting the latter to integers, but we should test it separately. * Check we ignore types that don't convert to int * Handle `None` values in `notifications.room` * Changelog * Also test that bad values are rejected by event auth * Docstring * linter scripttttttttt --- changelog.d/14942.bugfix | 1 + synapse/push/bulk_push_rule_evaluator.py | 19 +++++- tests/push/test_bulk_push_rule_evaluator.py | 93 +++++++++++++++++++++++++---- tests/test_event_auth.py | 32 +++++++++- 4 files changed, 128 insertions(+), 17 deletions(-) create mode 100644 changelog.d/14942.bugfix (limited to 'synapse') diff --git a/changelog.d/14942.bugfix b/changelog.d/14942.bugfix new file mode 100644 index 0000000000..a3ca3eb7e9 --- /dev/null +++ b/changelog.d/14942.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.68.0 where we were unable to service remote joins in rooms with `@room` notification levels set to `null` in their (malformed) power levels. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index deaec19564..88cfc05d05 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -69,6 +69,9 @@ STATE_EVENT_TYPES_TO_MARK_UNREAD = { } +SENTINEL = object() + + def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: # Exclude rejected and soft-failed events. if context.rejected or event.internal_metadata.is_soft_failed(): @@ -343,11 +346,21 @@ class BulkPushRuleEvaluator: related_events = await self._related_events(event) # It's possible that old room versions have non-integer power levels (floats or - # strings). Workaround this by explicitly converting to int. + # strings; even the occasional `null`). For old rooms, we interpret these as if + # they were integers. Do this here for the `@room` power level threshold. + # Note that this is done automatically for the sender's power level by + # _get_power_levels_and_sender_level in its call to get_user_power_level + # (even for room V10.) notification_levels = power_levels.get("notifications", {}) if not event.room_version.msc3667_int_only_power_levels: - for user_id, level in notification_levels.items(): - notification_levels[user_id] = int(level) + keys = list(notification_levels.keys()) + for key in keys: + level = notification_levels.get(key, SENTINEL) + if level is not SENTINEL and type(level) is not int: + try: + notification_levels[key] = int(level) + except (TypeError, ValueError): + del notification_levels[key] # Pull out any user and room mentions. mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index aba62b5dc8..fda48d9f61 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -15,6 +15,8 @@ from typing import Any from unittest.mock import patch +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventContentFields @@ -48,35 +50,84 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.requester = create_requester(self.alice) self.room_id = self.helper.create_room_as( - self.alice, room_version=RoomVersions.V9.identifier, tok=self.token + # This is deliberately set to V9, because we want to test the logic which + # handles stringy power levels. Stringy power levels were outlawed in V10. + self.alice, + room_version=RoomVersions.V9.identifier, + tok=self.token, ) self.event_creation_handler = self.hs.get_event_creation_handler() - def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None: - """We should convert floats and strings to integers before passing to Rust. + @parameterized.expand( + [ + # The historically-permitted bad values. Alice's notification should be + # allowed if this threshold is at or below her power level (60) + ("100", False), + ("0", True), + (12.34, True), + (60.0, True), + (67.89, False), + # Values that int(...) would not successfully cast should be ignored. + # The room notification level should then default to 50, per the spec, so + # Alice's notification is allowed. + (None, True), + # We haven't seen `"room": []` or `"room": {}` in the wild (yet), but + # let's check them for paranoia's sake. + ([], True), + ({}, True), + ] + ) + def test_action_for_event_by_user_handles_noninteger_room_power_levels( + self, bad_room_level: object, should_permit: bool + ) -> None: + """We should convert strings in `room` to integers before passing to Rust. + + Test this as follows: + - Create a room as Alice and invite two other users Bob and Charlie. + - Set PLs so that Alice has PL 60 and `notifications.room` is set to a bad value. + - Have Alice create a message notifying @room. + - Evaluate notification actions for that message. This should not raise. + - Look in the DB to see if that message triggered a highlight for Bob. + + The test is parameterised with two arguments: + - the bad power level value for "room", before JSON serisalistion + - whether Bob should expect the message to be highlighted Reproduces #14060. A lack of validation: the gift that keeps on giving. """ - - # Alter the power levels in that room to include stringy and floaty levels. - # We need to suppress the validation logic or else it will reject these dodgy - # values. (Presumably this validation was not always present.) + # Join another user to the room, so that there is someone to see Alice's + # @room notification. + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + self.helper.join(self.room_id, bob, tok=bob_token) + + # Alter the power levels in that room to include the bad @room notification + # level. We need to suppress + # + # - canonicaljson validation, because canonicaljson forbids floats; + # - the event jsonschema validation, because it will forbid bad values; and + # - the auth rules checks, because they stop us from creating power levels + # with `"room": null`. (We want to test this case, because we have seen it + # in the wild.) + # + # We have seen stringy and null values for "room" in the wild, so presumably + # some of this validation was missing in the past. with patch("synapse.events.validator.validate_canonicaljson"), patch( "synapse.events.validator.jsonschema.validate" - ): - self.helper.send_state( + ), patch("synapse.handlers.event_auth.check_state_dependent_auth_rules"): + pl_event_id = self.helper.send_state( self.room_id, "m.room.power_levels", { - "users": {self.alice: "100"}, # stringy - "notifications": {"room": 100.0}, # float + "users": {self.alice: 60}, + "notifications": {"room": bad_room_level}, }, self.token, state_key="", - ) + )["event_id"] # Create a new message event, and try to evaluate it under the dodgy # power level event. @@ -88,10 +139,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): "room_id": self.room_id, "content": { "msgtype": "m.text", - "body": "helo", + "body": "helo @room", }, "sender": self.alice, }, + prev_event_ids=[pl_event_id], ) ) @@ -99,6 +151,21 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # should not raise self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) + # Did Bob see Alice's @room notification? + highlighted_actions = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_list( + table="event_push_actions_staging", + keyvalues={ + "event_id": event.event_id, + "user_id": bob, + "highlight": 1, + }, + retcols=("*",), + desc="get_event_push_actions_staging", + ) + ) + self.assertEqual(len(highlighted_actions), int(should_permit)) + @override_config({"push": {"enabled": False}}) def test_action_for_event_by_user_disabled_by_config(self) -> None: """Ensure that push rules are not calculated when disabled in the config""" diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index f4d9fba0a1..0a7937f1cc 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -13,7 +13,7 @@ # limitations under the License. import unittest -from typing import Collection, Dict, Iterable, List, Optional +from typing import Any, Collection, Dict, Iterable, List, Optional from parameterized import parameterized @@ -728,6 +728,36 @@ class EventAuthTestCase(unittest.TestCase): pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event} ) + def test_room_v10_rejects_other_non_integer_power_levels(self) -> None: + """We should reject PLs that are non-integer, non-string JSON values. + + test_room_v10_rejects_string_power_levels above handles the string case. + """ + + def create_event(pl_event_content: Dict[str, Any]) -> EventBase: + return make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) + + contents: Iterable[Dict[str, Any]] = [ + {"notifications": {"room": None}}, + {"users": {"@alice:wonderland": []}}, + {"users_default": {}}, + ] + for content in contents: + event = create_event(content) + with self.assertRaises(SynapseError): + event_auth._check_power_levels(event.room_version, event, {}) + # helpers for making events TEST_DOMAIN = "example.com" -- cgit 1.5.1 From 796a4b74823b721c72de07e45718f05e78e1565d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 31 Jan 2023 10:33:07 +0000 Subject: Prefer `type(x) is int` to `isinstance(x, int)` (#14945) * Perfer `type(x) is int` to `isinstance(x, int)` This covered all additional instances I could see where `x` was user-controlled. The remaining cases are ``` $ rg -s 'isinstance.*[^_]int' tests/replication/_base.py 576: if isinstance(obj, int): synapse/util/caches/stream_change_cache.py 136: assert isinstance(stream_pos, int) 214: assert isinstance(stream_pos, int) 246: assert isinstance(stream_pos, int) 267: assert isinstance(stream_pos, int) synapse/replication/tcp/external_cache.py 133: if isinstance(result, int): synapse/metrics/__init__.py 100: if isinstance(calls, (int, float)): synapse/handlers/appservice.py 262: assert isinstance(new_token, int) synapse/config/_util.py 62: if isinstance(p, int): ``` which cover metrics, logic related to `jsonschema`, and replication and data streams. AFAICS these are all internal to Synapse * Changelog --- changelog.d/14945.misc | 1 + synapse/config/_base.py | 72 +++++++++++++++++++++---------- synapse/config/cache.py | 4 +- synapse/config/server.py | 2 +- synapse/events/validator.py | 4 +- synapse/federation/federation_client.py | 2 +- synapse/handlers/message.py | 2 +- synapse/rest/admin/__init__.py | 2 +- synapse/rest/admin/registration_tokens.py | 15 +++---- synapse/rest/admin/users.py | 6 +-- synapse/rest/client/report_event.py | 2 +- synapse/rest/media/v1/oembed.py | 2 +- synapse/rest/media/v1/thumbnailer.py | 2 +- synapse/storage/databases/main/events.py | 6 +-- 14 files changed, 75 insertions(+), 47 deletions(-) create mode 100644 changelog.d/14945.misc (limited to 'synapse') diff --git a/changelog.d/14945.misc b/changelog.d/14945.misc new file mode 100644 index 0000000000..654174f9a8 --- /dev/null +++ b/changelog.d/14945.misc @@ -0,0 +1 @@ +Fix various long-standing bugs in Synapse's config, event and request handling where booleans were unintentionally accepted where an integer was expected. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1f6362aedd..2ce60610ca 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -174,15 +174,29 @@ class Config: @staticmethod def parse_size(value: Union[str, int]) -> int: - if isinstance(value, int): + """Interpret `value` as a number of bytes. + + If an integer is provided it is treated as bytes and is unchanged. + + String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and + mebibytes respectively. No suffix is understood as a plain byte count. + + Raises: + TypeError, if given something other than an integer or a string + ValueError: if given a string not of the form described above. + """ + if type(value) is int: return value - sizes = {"K": 1024, "M": 1024 * 1024} - size = 1 - suffix = value[-1] - if suffix in sizes: - value = value[:-1] - size = sizes[suffix] - return int(value) * size + elif type(value) is str: + sizes = {"K": 1024, "M": 1024 * 1024} + size = 1 + suffix = value[-1] + if suffix in sizes: + value = value[:-1] + size = sizes[suffix] + return int(value) * size + else: + raise TypeError(f"Bad byte size {value!r}") @staticmethod def parse_duration(value: Union[str, int]) -> int: @@ -198,22 +212,36 @@ class Config: Returns: The number of milliseconds in the duration. + + Raises: + TypeError, if given something other than an integer or a string + ValueError: if given a string not of the form described above. """ - if isinstance(value, int): + if type(value) is int: return value - second = 1000 - minute = 60 * second - hour = 60 * minute - day = 24 * hour - week = 7 * day - year = 365 * day - sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year} - size = 1 - suffix = value[-1] - if suffix in sizes: - value = value[:-1] - size = sizes[suffix] - return int(value) * size + elif type(value) is str: + second = 1000 + minute = 60 * second + hour = 60 * minute + day = 24 * hour + week = 7 * day + year = 365 * day + sizes = { + "s": second, + "m": minute, + "h": hour, + "d": day, + "w": week, + "y": year, + } + size = 1 + suffix = value[-1] + if suffix in sizes: + value = value[:-1] + size = sizes[suffix] + return int(value) * size + else: + raise TypeError(f"Bad duration {value!r}") @staticmethod def abspath(file_path: str) -> str: diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 015b2a138e..05f69cb1ba 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -126,7 +126,7 @@ class CacheConfig(Config): cache_config = config.get("caches") or {} self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE) - if not isinstance(self.global_factor, (int, float)): + if type(self.global_factor) not in (int, float): raise ConfigError("caches.global_factor must be a number.") # Load cache factors from the config @@ -151,7 +151,7 @@ class CacheConfig(Config): ) for cache, factor in individual_factors.items(): - if not isinstance(factor, (int, float)): + if type(factor) not in (int, float): raise ConfigError( "caches.per_cache_factors.%s must be a number" % (cache,) ) diff --git a/synapse/config/server.py b/synapse/config/server.py index 80bcfa4080..ecdaa2d9dd 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -904,7 +904,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type")) port = listener.get("port") - if not isinstance(port, int): + if type(port) is not int: raise ConfigError("Listener configuration is lacking a valid 'port' option") tls = listener.get("tls", False) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index a6f0104396..fb1737b910 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -139,7 +139,7 @@ class EventValidator: max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if not isinstance(min_lifetime, int): + if type(min_lifetime) is not int: raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -147,7 +147,7 @@ class EventValidator: ) if max_lifetime is not None: - if not isinstance(max_lifetime, int): + if type(max_lifetime) is not int: raise SynapseError( code=400, msg="'max_lifetime' must be an integer", diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f185b6c1f9..feb32e40e5 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1864,7 +1864,7 @@ class TimestampToEventResponse: ) origin_server_ts = d.get("origin_server_ts") - if not isinstance(origin_server_ts, int): + if type(origin_server_ts) is not int: raise ValueError( "Invalid response: 'origin_server_ts' must be a int but received %r" % origin_server_ts diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3278a695ed..6290f7f523 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -377,7 +377,7 @@ class MessageHandler: """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if not isinstance(expiry_ts, int) or event.is_state(): + if type(expiry_ts) is not int or event.is_state(): return # _schedule_expiry_for_event won't actually schedule anything if there's already diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index fb73886df0..79f22a59f1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -152,7 +152,7 @@ class PurgeHistoryRestServlet(RestServlet): logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: ts = body["purge_up_to_ts"] - if not isinstance(ts, int): + if type(ts) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "purge_up_to_ts must be an int", diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index af606e9252..95e751288b 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet): else: # Get length of token to generate (default is 16) length = body.get("length", 16) - if not isinstance(length, int): + if type(length) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "length must be an integer", @@ -163,8 +163,7 @@ class NewRegistrationTokenRestServlet(RestServlet): uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None - or (isinstance(uses_allowed, int) and uses_allowed >= 0) + uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -173,13 +172,13 @@ class NewRegistrationTokenRestServlet(RestServlet): ) expiry_time = body.get("expiry_time", None) - if not isinstance(expiry_time, (int, type(None))): + if type(expiry_time) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + if type(expiry_time) is int and expiry_time < self.clock.time_msec(): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", @@ -284,7 +283,7 @@ class RegistrationTokenRestServlet(RestServlet): uses_allowed = body["uses_allowed"] if not ( uses_allowed is None - or (isinstance(uses_allowed, int) and uses_allowed >= 0) + or (type(uses_allowed) is int and uses_allowed >= 0) ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -295,13 +294,13 @@ class RegistrationTokenRestServlet(RestServlet): if "expiry_time" in body: expiry_time = body["expiry_time"] - if not isinstance(expiry_time, (int, type(None))): + if type(expiry_time) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + if type(expiry_time) is int and expiry_time < self.clock.time_msec(): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 6e0c44be2a..0841b89c1a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -973,7 +973,7 @@ class UserTokenRestServlet(RestServlet): body = parse_json_object_from_request(request, allow_empty_body=True) valid_until_ms = body.get("valid_until_ms") - if valid_until_ms and not isinstance(valid_until_ms, int): + if type(valid_until_ms) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" ) @@ -1125,14 +1125,14 @@ class RateLimitRestServlet(RestServlet): messages_per_second = body.get("messages_per_second", 0) burst_count = body.get("burst_count", 0) - if not isinstance(messages_per_second, int) or messages_per_second < 0: + if type(messages_per_second) is not int or messages_per_second < 0: raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) - if not isinstance(burst_count, int) or burst_count < 0: + if type(burst_count) is not int or burst_count < 0: raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index 6e962a4532..e2b410cf32 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -54,7 +54,7 @@ class ReportEventRestServlet(RestServlet): "Param 'reason' must be a string", Codes.BAD_JSON, ) - if not isinstance(body.get("score", 0), int): + if type(body.get("score", 0)) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index a3738a6250..7592aa5d47 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -200,7 +200,7 @@ class OEmbedProvider: calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) - if val is not None and isinstance(val, int): + if type(val) is int: open_graph_response[f"og:video:{size}"] = val elif oembed_type == "link": diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index a48a4de92a..9480cc5763 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -77,7 +77,7 @@ class Thumbnailer: image_exif = self.image._getexif() # type: ignore if image_exif is not None: image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert isinstance(image_orientation, int) + assert type(image_orientation) is int self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) except Exception as e: # A lot of parsing errors can happen when parsing EXIF diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0f097a2927..1536937b67 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1651,7 +1651,7 @@ class PersistEventsStore: if self._ephemeral_messages_enabled: # If there's an expiry timestamp on the event, store it. expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if isinstance(expiry_ts, int) and not event.is_state(): + if type(expiry_ts) is int and not event.is_state(): self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) # Insert into the room_memberships table. @@ -2133,10 +2133,10 @@ class PersistEventsStore: ): if ( "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), int) + and type(event.content["min_lifetime"]) is not int ) or ( "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), int) + and type(event.content["max_lifetime"]) is not int ): # Ignore the event if one of the value isn't an integer. return -- cgit 1.5.1 From a134e626e43e9c31a4618d4164ba7d6242c0f803 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 31 Jan 2023 10:57:02 +0000 Subject: Reject boolean power levels (#14944) * Better test for bad values in power levels events The previous test only checked that Synapse didn't raise an exception, but didn't check that we had correctly interpreted the value of the dodgy power level. It also conflated two things: bad room notification levels, and bad user levels. There _is_ logic for converting the latter to integers, but we should test it separately. * Check we ignore types that don't convert to int * Handle `None` values in `notifications.room` * Changelog * Also test that bad values are rejected by event auth * Docstring * linter scripttttttttt * Test boolean values in PL content * Reject boolean power levels * Changelog --- changelog.d/14944.bugfix | 1 + synapse/event_auth.py | 4 ++-- synapse/events/utils.py | 6 +++--- synapse/federation/federation_base.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14944.bugfix (limited to 'synapse') diff --git a/changelog.d/14944.bugfix b/changelog.d/14944.bugfix new file mode 100644 index 0000000000..5fe1fb322b --- /dev/null +++ b/changelog.d/14944.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.64 where boolean power levels were erroneously permitted in [v10 rooms](https://spec.matrix.org/v1.5/rooms/v10/). diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c4a7b16413..e0be9f88cc 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -875,11 +875,11 @@ def _check_power_levels( "kick", "invite", }: - if not isinstance(v, int): + if type(v) is not int: raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or not all( - isinstance(v, int) for v in v.values() + type(v) is int for v in v.values() ): raise SynapseError( 400, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 52e4b467e8..ebf8c7ed83 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -648,10 +648,10 @@ def _copy_power_level_value_as_integer( ) -> None: """Set `power_levels[key]` to the integer represented by `old_value`. - :raises TypeError: if `old_value` is not an integer, nor a base-10 string + :raises TypeError: if `old_value` is neither an integer nor a base-10 string representation of an integer. """ - if isinstance(old_value, int): + if type(old_value) is int: power_levels[key] = old_value return @@ -679,7 +679,7 @@ def validate_canonicaljson(value: Any) -> None: * Floats * NaN, Infinity, -Infinity """ - if isinstance(value, int): + if type(value) is int: if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 6bd4742140..29fae716f5 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -280,7 +280,7 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB _strip_unsigned_values(pdu_json) depth = pdu_json["depth"] - if not isinstance(depth, int): + if type(depth) is not int: raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: -- cgit 1.5.1 From 6d14fdc2710688014a7a66cc48485462c6e86a1e Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 31 Jan 2023 11:03:55 +0000 Subject: Make sqlite database migrations transactional again, part two (#14926) #14910 fixed the regression introduced by #13873 where sqlite database migrations would no longer run inside a transaction. However, it committed the transaction before Synapse updated its bookkeeping of which migrations have been run, which means that migrations may be run again after they have completed successfully. Leave the transaction open at the end of `executescript`, to restore the old, correct behaviour. Also make the PostgreSQL behaviour consistent with SQLite. Fixes #14909. Signed-off-by: Sean Quah --- changelog.d/14926.bugfix | 1 + synapse/storage/engines/_base.py | 5 +- synapse/storage/engines/postgres.py | 6 ++- synapse/storage/engines/sqlite.py | 6 ++- tests/storage/test_database.py | 96 +++++++++++++++++++++++++++++++++++++ 5 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 changelog.d/14926.bugfix (limited to 'synapse') diff --git a/changelog.d/14926.bugfix b/changelog.d/14926.bugfix new file mode 100644 index 0000000000..f1f34cd6ba --- /dev/null +++ b/changelog.d/14926.bugfix @@ -0,0 +1 @@ +Fix a regression introduced in Synapse 1.69.0 which can result in database corruption when database migrations are interrupted on sqlite. diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index bc9ca3a53c..0363cdc038 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -133,8 +133,9 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM This is not provided by DBAPI2, and so needs engine-specific support. - Some database engines may automatically COMMIT the ongoing transaction both - before and after executing the script. + Any ongoing transaction is committed before executing the script in its own + transaction. The script transaction is left open and it is the responsibility of + the caller to commit it. """ ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index f9f562ea45..b350f57ccb 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -220,5 +220,9 @@ class PostgresEngine( """Execute a chunk of SQL containing multiple semicolon-delimited statements. Psycopg2 seems happy to do this in DBAPI2's `execute()` function. + + For consistency with SQLite, any ongoing transaction is committed before + executing the script in its own transaction. The script transaction is + left open and it is the responsibility of the caller to commit it. """ - cursor.execute(script) + cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}") diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 2f7df85ce4..28751e89a5 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -135,14 +135,16 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): > than one statement with it, it will raise a Warning. Use executescript() if > you want to execute multiple SQL statements with one call. - The script is wrapped in transaction control statemnets, since the docs for + The script is prefixed with a `BEGIN TRANSACTION`, since the docs for `executescript` warn: > If there is a pending transaction, an implicit COMMIT statement is executed > first. No other implicit transaction control is performed; any transaction > control must be added to sql_script. """ - cursor.executescript(f"BEGIN TRANSACTION;\n{script}\nCOMMIT;") + # The implementation of `executescript` can be found at + # https://github.com/python/cpython/blob/3.11/Modules/_sqlite/cursor.c#L1035. + cursor.executescript(f"BEGIN TRANSACTION; {script}") # Following functions taken from: https://github.com/coleifer/peewee diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 543cce6b3e..8cd7c89ca2 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -37,6 +38,101 @@ class TupleComparisonClauseTestCase(unittest.TestCase): self.assertEqual(args, [1, 2]) +class ExecuteScriptTestCase(unittest.HomeserverTestCase): + """Tests for `BaseDatabaseEngine.executescript` implementations.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + self.get_success( + self.db_pool.runInteraction( + "create", + lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"), + ) + ) + + def test_transaction(self) -> None: + """Test that all statements are run in a single transaction.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_transaction") + self.db_pool.engine.executescript( + cur, + ";".join( + [ + "INSERT INTO foo (name) VALUES ('transaction test')", + # This next statement will fail. When `executescript` is not + # transactional, the previous row will be observed later. + "INSERT INTO foo (name) VALUES ('transaction test')", + ] + ), + ) + + self.get_failure( + self.db_pool.runWithConnection(run), + self.db_pool.engine.module.IntegrityError, + ) + + self.assertIsNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "transaction test"}, + retcol="name", + allow_none=True, + ) + ), + "executescript is not running statements inside a transaction", + ) + + def test_commit(self) -> None: + """Test that the script transaction remains open and can be committed.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_commit") + self.db_pool.engine.executescript( + cur, "INSERT INTO foo (name) VALUES ('commit test')" + ) + cur.execute("COMMIT") + + self.get_success(self.db_pool.runWithConnection(run)) + + self.assertIsNotNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "commit test"}, + retcol="name", + allow_none=True, + ) + ), + ) + + def test_rollback(self) -> None: + """Test that the script transaction remains open and can be rolled back.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_rollback") + self.db_pool.engine.executescript( + cur, "INSERT INTO foo (name) VALUES ('rollback test')" + ) + cur.execute("ROLLBACK") + + self.get_success(self.db_pool.runWithConnection(run)) + + self.assertIsNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "rollback test"}, + retcol="name", + allow_none=True, + ) + ), + "executescript is not leaving the script transaction open", + ) + + class CallbacksTestCase(unittest.HomeserverTestCase): """Tests for transaction callbacks.""" -- cgit 1.5.1 From 805b641fb6b31e677278eaf6e27875eba5c2a3d3 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 31 Jan 2023 11:31:52 +0000 Subject: Fix "Re-starting finished log context" spam when creating events (#14947) `run_in_background` calls re-use the current logging context. When they are not awaited, they can complete after the current logging context has been marked as finished, which leads to log spam. Use `run_as_background_process` instead. Fixes one of the instances of #13090. Signed-off-by: Sean Quah --- changelog.d/14947.bugfix | 1 + synapse/handlers/message.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14947.bugfix (limited to 'synapse') diff --git a/changelog.d/14947.bugfix b/changelog.d/14947.bugfix new file mode 100644 index 0000000000..b9e768c44c --- /dev/null +++ b/changelog.d/14947.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where sending messages on servers with presence enabled would spam "Re-starting finished log context" log lines. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6290f7f523..e688e00575 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1939,7 +1939,9 @@ class EventCreationHandler: if event.type == EventTypes.Message: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. - run_in_background(self._bump_active_time, requester.user) + run_as_background_process( + "bump_presence_active_time", self._bump_active_time, requester.user + ) async def _notify() -> None: try: -- cgit 1.5.1 From 3b8574b4f250bac1e4d4cfbf6b1ceec83bc0bac2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 31 Jan 2023 12:43:20 +0000 Subject: Tag /send_join responses to detect faster joins (#14950) * Tag /send_join responses to detect faster joins * Changelog * Define a proper SynapseTag * isort --- changelog.d/14950.misc | 1 + synapse/federation/federation_server.py | 6 ++++++ synapse/logging/opentracing.py | 5 +++++ 3 files changed, 12 insertions(+) create mode 100644 changelog.d/14950.misc (limited to 'synapse') diff --git a/changelog.d/14950.misc b/changelog.d/14950.misc new file mode 100644 index 0000000000..6602776b3f --- /dev/null +++ b/changelog.d/14950.misc @@ -0,0 +1 @@ +Faster joins: tag `v2/send_join/` requests to indicate if they served a partial join response. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3197939a36..c9a6dfd1a4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -62,7 +62,9 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.opentracing import ( + SynapseTags, log_kv, + set_tag, start_active_span_from_edu, tag_args, trace, @@ -678,6 +680,10 @@ class FederationServer(FederationBase): room_id: str, caller_supports_partial_state: bool = False, ) -> Dict[str, Any]: + set_tag( + SynapseTags.SEND_JOIN_RESPONSE_IS_PARTIAL_STATE, + caller_supports_partial_state, + ) await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type] requester=None, key=room_id, diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index a705af8356..8ef9a0dda8 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -322,6 +322,11 @@ class SynapseTags: # The name of the external cache CACHE_NAME = "cache.name" + # Boolean. Present on /v2/send_join requests, omitted from all others. + # True iff partial state was requested and we provided (or intended to provide) + # partial state in the response. + SEND_JOIN_RESPONSE_IS_PARTIAL_STATE = "send_join.partial_state_response" + # Used to tag function arguments # # Tag a named arg. The name of the argument should be appended to this prefix. -- cgit 1.5.1 From bf82b56babc9e2cacba34f8878da3b3834914b3a Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 1 Feb 2023 16:45:19 +0100 Subject: Add more user information to export-data command. (#14894) * The user's profile information. * The user's devices. * The user's connections / IP address information. --- .ci/scripts/test_export_data_command.sh | 10 +++-- changelog.d/14894.feature | 1 + docs/usage/administration/admin_faq.md | 80 ++++++++++++++++++++++++++------- synapse/app/admin_cmd.py | 32 ++++++++++++- synapse/handlers/admin.py | 43 ++++++++++++++++++ tests/handlers/test_admin.py | 60 +++++++++++++++++++++++++ 6 files changed, 206 insertions(+), 20 deletions(-) create mode 100644 changelog.d/14894.feature (limited to 'synapse') diff --git a/.ci/scripts/test_export_data_command.sh b/.ci/scripts/test_export_data_command.sh index 9f6c49acff..36f836345c 100755 --- a/.ci/scripts/test_export_data_command.sh +++ b/.ci/scripts/test_export_data_command.sh @@ -23,8 +23,9 @@ poetry run python -m synapse.app.admin_cmd -c .ci/sqlite-config.yaml export-dat --output-directory /tmp/export_data # Test that the output directory exists and contains the rooms directory -dir="/tmp/export_data/rooms" -if [ -d "$dir" ]; then +dir_r="/tmp/export_data/rooms" +dir_u="/tmp/export_data/user_data" +if [ -d "$dir_r" ] && [ -d "$dir_u" ]; then echo "Command successful, this test passes" else echo "No output directories found, the command fails against a sqlite database." @@ -43,8 +44,9 @@ poetry run python -m synapse.app.admin_cmd -c .ci/postgres-config.yaml export-d --output-directory /tmp/export_data2 # Test that the output directory exists and contains the rooms directory -dir2="/tmp/export_data2/rooms" -if [ -d "$dir2" ]; then +dir_r2="/tmp/export_data2/rooms" +dir_u2="/tmp/export_data2/user_data" +if [ -d "$dir_r2" ] && [ -d "$dir_u2" ]; then echo "Command successful, this test passes" else echo "No output directories found, the command fails against a postgres database." diff --git a/changelog.d/14894.feature b/changelog.d/14894.feature new file mode 100644 index 0000000000..d22741d079 --- /dev/null +++ b/changelog.d/14894.feature @@ -0,0 +1 @@ +Adds profile information, devices and connections to the user data export via command line. \ No newline at end of file diff --git a/docs/usage/administration/admin_faq.md b/docs/usage/administration/admin_faq.md index 18ce6171db..7a27741199 100644 --- a/docs/usage/administration/admin_faq.md +++ b/docs/usage/administration/admin_faq.md @@ -2,13 +2,19 @@ How do I become a server admin? --- -If your server already has an admin account you should use the [User Admin API](../../admin_api/user_admin_api.md#change-whether-a-user-is-a-server-administrator-or-not) to promote other accounts to become admins. +If your server already has an admin account you should use the +[User Admin API](../../admin_api/user_admin_api.md#change-whether-a-user-is-a-server-administrator-or-not) +to promote other accounts to become admins. -If you don't have any admin accounts yet you won't be able to use the admin API, so you'll have to edit the database manually. Manually editing the database is generally not recommended so once you have an admin account: use the admin APIs to make further changes. +If you don't have any admin accounts yet you won't be able to use the admin API, +so you'll have to edit the database manually. Manually editing the database is +generally not recommended so once you have an admin account: use the admin APIs +to make further changes. ```sql UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'; ``` + What servers are my server talking to? --- Run this sql query on your db: @@ -36,8 +42,38 @@ How can I export user data? --- Synapse includes a Python command to export data for a specific user. It takes the homeserver configuration file and the full Matrix ID of the user to export: + ```console -python -m synapse.app.admin_cmd -c export-data +python -m synapse.app.admin_cmd -c export-data --output-directory +``` + +If you uses [Poetry](../../development/dependencies.md#managing-dependencies-with-poetry) +to run Synapse: + +```console +poetry run python -m synapse.app.admin_cmd -c export-data --output-directory +``` + +The directory to store the export data in can be customised with the +`--output-directory` parameter; ensure that the provided directory is +empty. If this parameter is not provided, Synapse defaults to creating +a temporary directory (which starts with "synapse-exfiltrate") in `/tmp`, +`/var/tmp`, or `/usr/tmp`, in that order. + +The exported data has the following layout: + +``` +output-directory +├───rooms +│ └─── +│ ├───events +│ ├───state +│ ├───invite_state +│ └───knock_state +└───user_data + ├───connections + ├───devices + └───profile ``` Manually resetting passwords @@ -50,21 +86,29 @@ I have a problem with my server. Can I just delete my database and start again? --- Deleting your database is unlikely to make anything better. -It's easy to make the mistake of thinking that you can start again from a clean slate by dropping your database, but things don't work like that in a federated network: lots of other servers have information about your server. +It's easy to make the mistake of thinking that you can start again from a clean +slate by dropping your database, but things don't work like that in a federated +network: lots of other servers have information about your server. -For example: other servers might think that you are in a room, your server will think that you are not, and you'll probably be unable to interact with that room in a sensible way ever again. +For example: other servers might think that you are in a room, your server will +think that you are not, and you'll probably be unable to interact with that room +in a sensible way ever again. -In general, there are better solutions to any problem than dropping the database. Come and seek help in https://matrix.to/#/#synapse:matrix.org. +In general, there are better solutions to any problem than dropping the database. +Come and seek help in https://matrix.to/#/#synapse:matrix.org. There are two exceptions when it might be sensible to delete your database and start again: -* You have *never* joined any rooms which are federated with other servers. For instance, a local deployment which the outside world can't talk to. -* You are changing the `server_name` in the homeserver configuration. In effect this makes your server a completely new one from the point of view of the network, so in this case it makes sense to start with a clean database. +* You have *never* joined any rooms which are federated with other servers. For +instance, a local deployment which the outside world can't talk to. +* You are changing the `server_name` in the homeserver configuration. In effect +this makes your server a completely new one from the point of view of the network, +so in this case it makes sense to start with a clean database. (In both cases you probably also want to clear out the media_store.) I've stuffed up access to my room, how can I delete it to free up the alias? --- Using the following curl command: -``` +```console curl -H 'Authorization: Bearer ' -X DELETE https://matrix.org/_matrix/client/r0/directory/room/ ``` `` - can be obtained in riot by looking in the riot settings, down the bottom is: @@ -75,19 +119,25 @@ Access Token:\ How can I find the lines corresponding to a given HTTP request in my homeserver log? --- -Synapse tags each log line according to the HTTP request it is processing. When it finishes processing each request, it logs a line containing the words `Processed request: `. For example: +Synapse tags each log line according to the HTTP request it is processing. When +it finishes processing each request, it logs a line containing the words +`Processed request: `. For example: ``` 2019-02-14 22:35:08,196 - synapse.access.http.8008 - 302 - INFO - GET-37 - ::1 - 8008 - {@richvdh:localhost} Processed request: 0.173sec/0.001sec (0.002sec, 0.000sec) (0.027sec/0.026sec/2) 687B 200 "GET /_matrix/client/r0/sync HTTP/1.1" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.100 Safari/537.36" [0 dbevts]" ``` -Here we can see that the request has been tagged with `GET-37`. (The tag depends on the method of the HTTP request, so might start with `GET-`, `PUT-`, `POST-`, `OPTIONS-` or `DELETE-`.) So to find all lines corresponding to this request, we can do: +Here we can see that the request has been tagged with `GET-37`. (The tag depends +on the method of the HTTP request, so might start with `GET-`, `PUT-`, `POST-`, +`OPTIONS-` or `DELETE-`.) So to find all lines corresponding to this request, we can do: -``` +```console grep 'GET-37' homeserver.log ``` -If you want to paste that output into a github issue or matrix room, please remember to surround it with triple-backticks (```) to make it legible (see [quoting code](https://help.github.com/en/articles/basic-writing-and-formatting-syntax#quoting-code)). +If you want to paste that output into a github issue or matrix room, please +remember to surround it with triple-backticks (```) to make it legible +(see [quoting code](https://help.github.com/en/articles/basic-writing-and-formatting-syntax#quoting-code)). What do all those fields in the 'Processed' line mean? @@ -127,7 +177,7 @@ This is normally caused by a misconfiguration in your reverse-proxy. See [the re Help!! Synapse is slow and eats all my RAM/CPU! ------------------------------------------------ +--- First, ensure you are running the latest version of Synapse, using Python 3 with a [PostgreSQL database](../../postgres.md). @@ -169,7 +219,7 @@ in the Synapse config file: [see here](../configuration/config_documentation.md# Running out of File Handles ---------------------------- +--- If Synapse runs out of file handles, it typically fails badly - live-locking at 100% CPU, and/or failing to accept new TCP connections (blocking the diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 165d1c5db0..fe7afb9475 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -35,6 +35,7 @@ from synapse.storage.databases.main.appservice import ( ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) +from synapse.storage.databases.main.client_ips import ClientIpWorkerStore from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.event_federation import EventFederationWorkerStore @@ -43,6 +44,7 @@ from synapse.storage.databases.main.event_push_actions import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.profile import ProfileWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore @@ -54,7 +56,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.types import StateMap +from synapse.types import JsonDict, StateMap from synapse.util import SYNAPSE_VERSION from synapse.util.logcontext import LoggingContext @@ -63,6 +65,7 @@ logger = logging.getLogger("synapse.app.admin_cmd") class AdminCmdSlavedStore( FilteringWorkerStore, + ClientIpWorkerStore, DeviceWorkerStore, TagsWorkerStore, DeviceInboxWorkerStore, @@ -82,6 +85,7 @@ class AdminCmdSlavedStore( EventsWorkerStore, RegistrationWorkerStore, RoomWorkerStore, + ProfileWorkerStore, ): def __init__( self, @@ -192,6 +196,32 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event), file=f) + def write_profile(self, profile: JsonDict) -> None: + user_directory = os.path.join(self.base_directory, "user_data") + os.makedirs(user_directory, exist_ok=True) + profile_file = os.path.join(user_directory, "profile") + + with open(profile_file, "a") as f: + print(json.dumps(profile), file=f) + + def write_devices(self, devices: List[JsonDict]) -> None: + user_directory = os.path.join(self.base_directory, "user_data") + os.makedirs(user_directory, exist_ok=True) + device_file = os.path.join(user_directory, "devices") + + for device in devices: + with open(device_file, "a") as f: + print(json.dumps(device), file=f) + + def write_connections(self, connections: List[JsonDict]) -> None: + user_directory = os.path.join(self.base_directory, "user_data") + os.makedirs(user_directory, exist_ok=True) + connection_file = os.path.join(user_directory, "connections") + + for connection in connections: + with open(connection_file, "a") as f: + print(json.dumps(connection), file=f) + def finished(self) -> str: return self.base_directory diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index c81ea34758..b03c214b14 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,6 +30,7 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._device_handler = hs.get_device_handler() self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state self._msc3866_enabled = hs.config.experimental.msc3866.enabled @@ -247,6 +248,21 @@ class AdminHandler: ) writer.write_state(room_id, event_id, state) + # Get the user profile + profile = await self.get_user(UserID.from_string(user_id)) + if profile is not None: + writer.write_profile(profile) + + # Get all devices the user has + devices = await self._device_handler.get_devices_by_user(user_id) + writer.write_devices(devices) + + # Get all connections the user has + connections = await self.get_whois(UserID.from_string(user_id)) + writer.write_connections( + connections["devices"][""]["sessions"][0]["connections"] + ) + return writer.finished() @@ -297,6 +313,33 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): """ raise NotImplementedError() + @abc.abstractmethod + def write_profile(self, profile: JsonDict) -> None: + """Write the profile of a user. + + Args: + profile: The user profile. + """ + raise NotImplementedError() + + @abc.abstractmethod + def write_devices(self, devices: List[JsonDict]) -> None: + """Write the devices of a user. + + Args: + devices: The list of devices. + """ + raise NotImplementedError() + + @abc.abstractmethod + def write_connections(self, connections: List[JsonDict]) -> None: + """Write the connections of a user. + + Args: + connections: The list of connections / sessions. + """ + raise NotImplementedError() + @abc.abstractmethod def finished(self) -> Any: """Called when all data has successfully been exported and written. diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index c1579dac61..6f300b8e11 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -38,6 +38,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_handler = hs.get_admin_handler() + self._store = hs.get_datastores().main self.user1 = self.register_user("user1", "password") self.token1 = self.login("user1", "password") @@ -236,3 +237,62 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(args[0], room_id) self.assertEqual(args[1].content["membership"], "knock") self.assertTrue(args[2]) # Assert there is at least one bit of state + + def test_profile(self) -> None: + """Tests that user profile get exported.""" + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + writer.write_events.assert_not_called() + writer.write_profile.assert_called_once() + + # check only a few values, not all available + args = writer.write_profile.call_args[0] + self.assertEqual(args[0]["name"], self.user2) + self.assertIn("displayname", args[0]) + self.assertIn("avatar_url", args[0]) + self.assertIn("threepids", args[0]) + self.assertIn("external_ids", args[0]) + self.assertIn("creation_ts", args[0]) + + def test_devices(self) -> None: + """Tests that user devices get exported.""" + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + writer.write_events.assert_not_called() + writer.write_devices.assert_called_once() + + args = writer.write_devices.call_args[0] + self.assertEqual(len(args[0]), 1) + self.assertEqual(args[0][0]["user_id"], self.user2) + self.assertIn("device_id", args[0][0]) + self.assertIsNone(args[0][0]["display_name"]) + self.assertIsNone(args[0][0]["last_seen_user_agent"]) + self.assertIsNone(args[0][0]["last_seen_ts"]) + self.assertIsNone(args[0][0]["last_seen_ip"]) + + def test_connections(self) -> None: + """Tests that user sessions / connections get exported.""" + # Insert a user IP + self.get_success( + self._store.insert_client_ip( + self.user2, "access_token", "ip", "user_agent", "MY_DEVICE" + ) + ) + + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + writer.write_events.assert_not_called() + writer.write_connections.assert_called_once() + + args = writer.write_connections.call_args[0] + self.assertEqual(len(args[0]), 1) + self.assertEqual(args[0][0]["ip"], "ip") + self.assertEqual(args[0][0]["user_agent"], "user_agent") + self.assertGreater(args[0][0]["last_seen"], 0) + self.assertNotIn("access_token", args[0][0]) -- cgit 1.5.1 From 230a831c734246aa4db7bd842947c7ea277ca126 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Feb 2023 15:45:10 -0500 Subject: Attempt to delete more duplicate rows in receipts_linearized table. (#14915) The previous assumption was that the stream_id column was unique (for a room ID, receipt type, user ID tuple), but this turned out to be incorrect. Now find the max stream ID, then map this back to a database-specific row identifier and delete other rows which match the (room ID, receipt type, user ID) tuple, but *not* the row ID. --- changelog.d/14915.bugfix | 1 + synapse/storage/databases/main/receipts.py | 34 ++++++++++++++++++++------- tests/storage/databases/main/test_receipts.py | 4 +++- 3 files changed, 30 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14915.bugfix (limited to 'synapse') diff --git a/changelog.d/14915.bugfix b/changelog.d/14915.bugfix new file mode 100644 index 0000000000..4969e5450c --- /dev/null +++ b/changelog.d/14915.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 where the background updates to add non-thread unique indexes on receipts could fail when upgrading from 1.67.0 or earlier. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 3468f354e6..29972d5204 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -941,10 +941,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): receipts.""" def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + if isinstance(self.database_engine, PostgresEngine): + ROW_ID_NAME = "ctid" + else: + ROW_ID_NAME = "rowid" + # Identify any duplicate receipts arising from # https://github.com/matrix-org/synapse/issues/14406. - # We expect the following query to use the per-thread receipt index and take - # less than a minute. + # The following query takes less than a minute on matrix.org. sql = """ SELECT MAX(stream_id), room_id, receipt_type, user_id FROM receipts_linearized @@ -956,19 +960,33 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn)) # Then remove duplicate receipts, keeping the one with the highest - # `stream_id`. There should only be a single receipt with any given - # `stream_id`. - for max_stream_id, room_id, receipt_type, user_id in duplicate_keys: - sql = """ + # `stream_id`. Since there might be duplicate rows with the same + # `stream_id`, we delete by the ctid instead. + for stream_id, room_id, receipt_type, user_id in duplicate_keys: + sql = f""" + SELECT {ROW_ID_NAME} + FROM receipts_linearized + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL AND + stream_id = ? + LIMIT 1 + """ + txn.execute(sql, (room_id, receipt_type, user_id, stream_id)) + row_id = cast(Tuple[str], txn.fetchone())[0] + + sql = f""" DELETE FROM receipts_linearized WHERE room_id = ? AND receipt_type = ? AND user_id = ? AND thread_id IS NULL AND - stream_id < ? + {ROW_ID_NAME} != ? """ - txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id)) + txn.execute(sql, (room_id, receipt_type, user_id, row_id)) await self.db_pool.runInteraction( self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index 68026e2830..ac77aec003 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -168,7 +168,9 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): {"stream_id": 6, "event_id": "$some_event"}, ], (self.other_room_id, "m.read", self.user_id): [ - {"stream_id": 7, "event_id": "$some_event"} + # It is possible for stream IDs to be duplicated. + {"stream_id": 7, "event_id": "$some_event"}, + {"stream_id": 7, "event_id": "$some_event"}, ], }, expected_unique_receipts={ -- cgit 1.5.1 From 1182ae50635db94d3c9c47990a0befcbf6306b62 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Feb 2023 16:35:24 -0500 Subject: Add helper to parse an enum from query args & use it. (#14956) The `parse_enum` helper pulls an enum value from the query string (by delegating down to the parse_string helper with values generated from the enum). This is used to pull out "f" and "b" in most places and then we thread the resulting Direction enum throughout more code. --- changelog.d/14956.misc | 1 + synapse/federation/federation_client.py | 15 +++-- synapse/federation/federation_server.py | 12 +++- synapse/federation/transport/client.py | 8 +-- synapse/federation/transport/server/federation.py | 7 ++- synapse/handlers/account_data.py | 2 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/room.py | 9 +-- synapse/http/servlet.py | 70 ++++++++++++++++++++++ synapse/rest/admin/event_reports.py | 12 +--- synapse/rest/admin/federation.py | 7 ++- synapse/rest/admin/media.py | 21 ++++--- synapse/rest/admin/rooms.py | 16 ++--- synapse/rest/admin/statistics.py | 11 +--- synapse/rest/admin/users.py | 5 +- synapse/rest/client/relations.py | 3 +- synapse/rest/client/room.py | 5 +- synapse/storage/databases/main/__init__.py | 5 +- synapse/storage/databases/main/events_worker.py | 11 ++-- synapse/storage/databases/main/media_repository.py | 5 +- synapse/storage/databases/main/room.py | 9 +-- synapse/storage/databases/main/stats.py | 6 +- synapse/storage/databases/main/transactions.py | 13 ++-- synapse/streams/config.py | 12 +--- tests/rest/admin/test_event_reports.py | 5 +- 25 files changed, 176 insertions(+), 96 deletions(-) create mode 100644 changelog.d/14956.misc (limited to 'synapse') diff --git a/changelog.d/14956.misc b/changelog.d/14956.misc new file mode 100644 index 0000000000..9f5384e60e --- /dev/null +++ b/changelog.d/14956.misc @@ -0,0 +1 @@ +Add missing type hints. \ No newline at end of file diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index feb32e40e5..8493ffc2e5 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -37,7 +37,7 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership from synapse.api.errors import ( CodeMessageException, Codes, @@ -1680,7 +1680,12 @@ class FederationClient(FederationBase): return result async def timestamp_to_event( - self, *, destinations: List[str], room_id: str, timestamp: int, direction: str + self, + *, + destinations: List[str], + room_id: str, + timestamp: int, + direction: Direction, ) -> Optional["TimestampToEventResponse"]: """ Calls each remote federating server from `destinations` asking for their closest @@ -1693,7 +1698,7 @@ class FederationClient(FederationBase): room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: @@ -1738,7 +1743,7 @@ class FederationClient(FederationBase): return None async def _timestamp_to_event_from_destination( - self, destination: str, room_id: str, timestamp: int, direction: str + self, destination: str, room_id: str, timestamp: int, direction: Direction ) -> "TimestampToEventResponse": """ Calls a remote federating server at `destination` asking for their @@ -1751,7 +1756,7 @@ class FederationClient(FederationBase): room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index c9a6dfd1a4..8d36172484 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,13 @@ from prometheus_client import Counter, Gauge, Histogram from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + Direction, + EduTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.errors import ( AuthError, Codes, @@ -218,7 +224,7 @@ class FederationServer(FederationBase): return 200, res async def on_timestamp_to_event_request( - self, origin: str, room_id: str, timestamp: int, direction: str + self, origin: str, room_id: str, timestamp: int, direction: Direction ) -> Tuple[int, Dict[str, Any]]: """When we receive a federated `/timestamp_to_event` request, handle all of the logic for validating and fetching the event. @@ -228,7 +234,7 @@ class FederationServer(FederationBase): room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 682666ab36..c05d598b70 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -32,7 +32,7 @@ from typing import ( import attr import ijson -from synapse.api.constants import Membership +from synapse.api.constants import Direction, Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.room_versions import RoomVersion from synapse.api.urls import ( @@ -169,7 +169,7 @@ class TransportLayerClient: ) async def timestamp_to_event( - self, destination: str, room_id: str, timestamp: int, direction: str + self, destination: str, room_id: str, timestamp: int, direction: Direction ) -> Union[JsonDict, List]: """ Calls a remote federating server at `destination` asking for their @@ -180,7 +180,7 @@ class TransportLayerClient: room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: @@ -194,7 +194,7 @@ class TransportLayerClient: room_id, ) - args = {"ts": [str(timestamp)], "dir": [direction]} + args = {"ts": [str(timestamp)], "dir": [direction.value]} remote_response = await self.client.get_json( destination, path=path, args=args, try_trailing_slash_on_400=True diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 17c427387e..f7ca87adc4 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -26,7 +26,7 @@ from typing import ( from typing_extensions import Literal -from synapse.api.constants import EduTypes +from synapse.api.constants import Direction, EduTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX @@ -234,9 +234,10 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet): room_id: str, ) -> Tuple[int, JsonDict]: timestamp = parse_integer_from_args(query, "ts", required=True) - direction = parse_string_from_args( - query, "dir", default="f", allowed_values=["f", "b"], required=True + direction_str = parse_string_from_args( + query, "dir", allowed_values=["f", "b"], required=True ) + direction = Direction(direction_str) return await self.handler.on_timestamp_to_event_request( origin, room_id, timestamp, direction diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index d500b21809..67e789eef7 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -314,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - def get_current_key(self, direction: str = "f") -> int: + def get_current_key(self) -> int: return self.store.get_max_account_data_stream_id() async def get_new_events( diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6a4fed1156..04c61ae3dd 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -315,5 +315,5 @@ class ReceiptEventSource(EventSource[int, JsonDict]): return events, to_key - def get_current_key(self, direction: str = "f") -> int: + def get_current_key(self) -> int: return self.store.get_max_receipt_stream_id() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 60a6d9cf3c..7ba7c4ff07 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -27,6 +27,7 @@ from typing_extensions import TypedDict import synapse.events.snapshot from synapse.api.constants import ( + Direction, EventContentFields, EventTypes, GuestAccess, @@ -1487,7 +1488,7 @@ class TimestampLookupHandler: requester: Requester, room_id: str, timestamp: int, - direction: str, + direction: Direction, ) -> Tuple[str, int]: """Find the closest event to the given timestamp in the given direction. If we can't find an event locally or the event we have locally is next to a gap, @@ -1498,7 +1499,7 @@ class TimestampLookupHandler: room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: @@ -1533,13 +1534,13 @@ class TimestampLookupHandler: local_event_id, allow_none=False, allow_rejected=False ) - if direction == "f": + if direction == Direction.FORWARDS: # We only need to check for a backward gap if we're looking forwards # to ensure there is nothing in between. is_event_next_to_backward_gap = ( await self.store.is_event_next_to_backward_gap(local_event) ) - elif direction == "b": + elif direction == Direction.BACKWARDS: # We only need to check for a forward gap if we're looking backwards # to ensure there is nothing in between is_event_next_to_forward_gap = ( diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index dead02cd5c..0070bd2940 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -13,6 +13,7 @@ # limitations under the License. """ This module contains base REST classes for constructing REST servlets. """ +import enum import logging from http import HTTPStatus from typing import ( @@ -362,6 +363,7 @@ def parse_string( request: Request, name: str, *, + default: Optional[str] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -413,6 +415,74 @@ def parse_string( ) +EnumT = TypeVar("EnumT", bound=enum.Enum) + + +@overload +def parse_enum( + request: Request, + name: str, + E: Type[EnumT], + default: EnumT, +) -> EnumT: + ... + + +@overload +def parse_enum( + request: Request, + name: str, + E: Type[EnumT], + *, + required: Literal[True], +) -> EnumT: + ... + + +def parse_enum( + request: Request, + name: str, + E: Type[EnumT], + default: Optional[EnumT] = None, + required: bool = False, +) -> Optional[EnumT]: + """ + Parse an enum parameter from the request query string. + + Note that the enum *must only have string values*. + + Args: + request: the twisted HTTP request. + name: the name of the query parameter. + E: the enum which represents valid values + default: enum value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + An enum value. + + Raises: + SynapseError if the parameter is absent and required, or if the + parameter is present, must be one of a list of allowed values and + is not one of those allowed values. + """ + # Assert the enum values are strings. + assert all( + isinstance(e.value, str) for e in E + ), "parse_enum only works with string values" + str_value = parse_string( + request, + name, + default=default.value if default is not None else None, + required=required, + allowed_values=[e.value for e in E], + ) + if str_value is None: + return None + return E(str_value) + + def _parse_string_value( value: bytes, allowed_values: Optional[Iterable[str]], diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 6d634eef70..a3beb74e2c 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -16,8 +16,9 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import JsonDict @@ -60,7 +61,7 @@ class EventReportsRestServlet(RestServlet): start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - direction = parse_string(request, "dir", default="b") + direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS) user_id = parse_string(request, "user_id") room_id = parse_string(request, "room_id") @@ -78,13 +79,6 @@ class EventReportsRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) - if direction not in ("f", "b"): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, - ) - event_reports, total = await self.store.get_event_reports_paginate( start, limit, direction, user_id, room_id ) diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 023ed92144..e0ee55bd0e 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -15,9 +15,10 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.federation.transport.server import Authenticator -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.storage.databases.main.transactions import DestinationSortOrder @@ -79,7 +80,7 @@ class ListDestinationsRestServlet(RestServlet): allowed_values=[dest.value for dest in DestinationSortOrder], ) - direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) destinations, total = await self._store.get_destinations_paginate( start, limit, destination, order_by, direction @@ -192,7 +193,7 @@ class DestinationMembershipRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) - direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) rooms, total = await self._store.get_destination_rooms_paginate( destination, start, limit, direction diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 73470f09ae..0d072c42a7 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,9 +17,16 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string +from synapse.http.servlet import ( + RestServlet, + parse_boolean, + parse_enum, + parse_integer, + parse_string, +) from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, @@ -389,7 +396,7 @@ class UserMediaRestServlet(RestServlet): # to newest media is on top for backward compatibility. if b"order_by" not in request.args and b"dir" not in request.args: order_by = MediaSortOrder.CREATED_TS.value - direction = "b" + direction = Direction.BACKWARDS else: order_by = parse_string( request, @@ -397,8 +404,8 @@ class UserMediaRestServlet(RestServlet): default=MediaSortOrder.CREATED_TS.value, allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) - direction = parse_string( - request, "dir", default="f", allowed_values=("f", "b") + direction = parse_enum( + request, "dir", Direction, default=Direction.FORWARDS ) media, total = await self.store.get_local_media_by_user_paginate( @@ -447,7 +454,7 @@ class UserMediaRestServlet(RestServlet): # to newest media is on top for backward compatibility. if b"order_by" not in request.args and b"dir" not in request.args: order_by = MediaSortOrder.CREATED_TS.value - direction = "b" + direction = Direction.BACKWARDS else: order_by = parse_string( request, @@ -455,8 +462,8 @@ class UserMediaRestServlet(RestServlet): default=MediaSortOrder.CREATED_TS.value, allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) - direction = parse_string( - request, "dir", default="f", allowed_values=("f", "b") + direction = parse_enum( + request, "dir", Direction, default=Direction.FORWARDS ) media, _ = await self.store.get_local_media_by_user_paginate( diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index e957aa28ca..1d6e4982d7 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -16,13 +16,14 @@ from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple, cast from urllib import parse as urlparse -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import Direction, EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.http.servlet import ( ResolveRoomIdMixin, RestServlet, assert_params_in_dict, + parse_enum, parse_integer, parse_json_object_from_request, parse_string, @@ -224,15 +225,8 @@ class ListRoomRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) - direction = parse_string(request, "dir", default="f") - if direction not in ("f", "b"): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, - ) - - reverse_order = True if direction == "b" else False + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) + reverse_order = True if direction == Direction.BACKWARDS else False # Return list of rooms according to parameters rooms, total_rooms = await self.store.get_rooms_paginate( @@ -949,7 +943,7 @@ class RoomTimestampToEventRestServlet(RestServlet): await assert_user_is_admin(self._auth, requester) timestamp = parse_integer(request, "ts", required=True) - direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) ( event_id, diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 3b142b8402..9c45f4650d 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -16,8 +16,9 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import Direction from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.storage.databases.main.stats import UserSortOrder @@ -102,13 +103,7 @@ class UserMediaStatisticsRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) - direction = parse_string(request, "dir", default="f") - if direction not in ("f", "b"): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, - ) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) users_media, total = await self.store.get_users_media_usage_paginate( start, limit, from_ts, until_ts, order_by, direction, search_term diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 0841b89c1a..b9dca8ef3a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -18,12 +18,13 @@ import secrets from http import HTTPStatus from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from synapse.api.constants import UserTypes +from synapse.api.constants import Direction, UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_boolean, + parse_enum, parse_integer, parse_json_object_from_request, parse_string, @@ -120,7 +121,7 @@ class UsersRestServletV2(RestServlet): ), ) - direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) users, total = await self.store.get_users_paginate( start, diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 9dd59196d9..7456d6f507 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -16,6 +16,7 @@ import logging import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.api.constants import Direction from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -59,7 +60,7 @@ class RelationPaginationServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request( - self._store, request, default_limit=5, default_dir="b" + self._store, request, default_limit=5, default_dir=Direction.BACKWARDS ) # The unstable version of this API returns an extra field for client diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 790614d721..d0db85cca7 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -26,7 +26,7 @@ from prometheus_client.core import Histogram from twisted.web.server import Request from synapse import event_auth -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -44,6 +44,7 @@ from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_boolean, + parse_enum, parse_integer, parse_json_object_from_request, parse_string, @@ -1297,7 +1298,7 @@ class TimestampLookupRestServlet(RestServlet): await self._auth.check_user_in_room_or_world_readable(room_id, requester) timestamp = parse_integer(request, "ts", required=True) - direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) ( event_id, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0e47592be3..837dc7646e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast +from synapse.api.constants import Direction from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import ( DatabasePool, @@ -167,7 +168,7 @@ class DataStore( guests: bool = True, deactivated: bool = False, order_by: str = UserSortOrder.NAME.value, - direction: str = "f", + direction: Direction = Direction.FORWARDS, approved: bool = True, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from @@ -197,7 +198,7 @@ class DataStore( # Set ordering order_by_column = UserSortOrder(order_by).value - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f42af34a2f..d7d08369ca 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -38,7 +38,7 @@ from typing_extensions import Literal from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import Direction, EventTypes from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, @@ -2240,7 +2240,7 @@ class EventsWorkerStore(SQLBaseStore): ) async def get_event_id_for_timestamp( - self, room_id: str, timestamp: int, direction: str + self, room_id: str, timestamp: int, direction: Direction ) -> Optional[str]: """Find the closest event to the given timestamp in the given direction. @@ -2248,14 +2248,14 @@ class EventsWorkerStore(SQLBaseStore): room_id: Room to fetch the event from timestamp: The point in time (inclusive) we should navigate from in the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward + direction: indicates whether we should navigate forward or backward from the given timestamp to find the closest event. Returns: The closest event_id otherwise None if we can't find any event in the given direction. """ - if direction == "b": + if direction == Direction.BACKWARDS: # Find closest event *before* a given timestamp. We use descending # (which gives values largest to smallest) because we want the # largest possible timestamp *before* the given timestamp. @@ -2307,9 +2307,6 @@ class EventsWorkerStore(SQLBaseStore): return None - if direction not in ("f", "b"): - raise ValueError("Unknown direction: %s" % (direction,)) - return await self.db_pool.runInteraction( "get_event_id_for_timestamp_txn", get_event_id_for_timestamp_txn, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 9b172a64d8..b202c5eb87 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -26,6 +26,7 @@ from typing import ( cast, ) +from synapse.api.constants import Direction from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -176,7 +177,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): limit: int, user_id: str, order_by: str = MediaSortOrder.CREATED_TS.value, - direction: str = "f", + direction: Direction = Direction.FORWARDS, ) -> Tuple[List[Dict[str, Any]], int]: """Get a paginated list of metadata for a local piece of media which an user_id has uploaded @@ -199,7 +200,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): # Set ordering order_by_column = MediaSortOrder(order_by).value - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index fbbc018887..4ddb27f686 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -35,6 +35,7 @@ from typing import ( import attr from synapse.api.constants import ( + Direction, EventContentFields, EventTypes, JoinRules, @@ -2204,7 +2205,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self, start: int, limit: int, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, user_id: Optional[str] = None, room_id: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: @@ -2213,8 +2214,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): Args: start: event offset to begin the query from limit: number of rows to retrieve - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`) + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards) user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None Returns: @@ -2236,7 +2237,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): filters.append("er.room_id LIKE ?") args.extend(["%" + room_id + "%"]) - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 0c1cbd540d..d7b7d0c3c9 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -22,7 +22,7 @@ from typing_extensions import Counter from twisted.internet.defer import DeferredLock -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership from synapse.api.errors import StoreError from synapse.storage.database import ( DatabasePool, @@ -663,7 +663,7 @@ class StatsStore(StateDeltasStore): from_ts: Optional[int] = None, until_ts: Optional[int] = None, order_by: Optional[str] = UserSortOrder.USER_ID.value, - direction: Optional[str] = "f", + direction: Direction = Direction.FORWARDS, search_term: Optional[str] = None, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users and their uploaded local media @@ -714,7 +714,7 @@ class StatsStore(StateDeltasStore): 500, "Incorrect value for order_by provided: %s" % order_by ) - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index f8c6877ee8..6b33d809b6 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json +from synapse.api.constants import Direction from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -496,7 +497,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): limit: int, destination: Optional[str] = None, order_by: str = DestinationSortOrder.DESTINATION.value, - direction: str = "f", + direction: Direction = Direction.FORWARDS, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of destinations. This will return a json list of destinations and the @@ -518,7 +519,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) -> Tuple[List[JsonDict], int]: order_by_column = DestinationSortOrder(order_by).value - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" @@ -550,7 +551,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) async def get_destination_rooms_paginate( - self, destination: str, start: int, limit: int, direction: str = "f" + self, + destination: str, + start: int, + limit: int, + direction: Direction = Direction.FORWARDS, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of destination's rooms. This will return a json list of rooms and the @@ -569,7 +574,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 5cb7875181..a044280410 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -18,7 +18,7 @@ import attr from synapse.api.constants import Direction from synapse.api.errors import SynapseError -from synapse.http.servlet import parse_integer, parse_string +from synapse.http.servlet import parse_enum, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.storage.databases.main import DataStore from synapse.types import StreamToken @@ -44,15 +44,9 @@ class PaginationConfig: store: "DataStore", request: SynapseRequest, default_limit: int, - default_dir: str = "f", + default_dir: Direction = Direction.FORWARDS, ) -> "PaginationConfig": - direction_str = parse_string( - request, - "dir", - default=default_dir, - allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value], - ) - direction = Direction(direction_str) + direction = parse_enum(request, "dir", Direction, default=default_dir) from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 8a4e5c3f77..233eba3516 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -280,7 +280,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - self.assertEqual("Unknown direction: bar", channel.json_body["error"]) + self.assertEqual( + "Query parameter 'dir' must be one of ['b', 'f']", + channel.json_body["error"], + ) def test_limit_is_negative(self) -> None: """ -- cgit 1.5.1 From 58214dbb9b8a85c0dafc65162e9c20ee1885ce4e Mon Sep 17 00:00:00 2001 From: realtyem Date: Wed, 1 Feb 2023 17:42:45 -0600 Subject: Allow enabling the asyncio reactor in complement (#14858) Signed-off-by: Jason Little realtyem@gmail.com --- .github/workflows/tests.yml | 5 ++++- changelog.d/14858.misc | 1 + docker/complement/conf/start_for_complement.sh | 13 ++++++++++++- docs/development/contributing_guide.md | 1 + scripts-dev/complement.sh | 5 +++++ synapse/app/complement_fork_starter.py | 21 +++++++++++++++++++-- 6 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14858.misc (limited to 'synapse') diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f184727ced..6561b490bc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -541,8 +541,11 @@ jobs: - run: | set -o pipefail - POSTGRES=${{ (matrix.database == 'Postgres') && 1 || '' }} WORKERS=${{ (matrix.arrangement == 'workers') && 1 || '' }} COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | synapse/.ci/scripts/gotestfmt + COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | synapse/.ci/scripts/gotestfmt shell: bash + env: + POSTGRES: ${{ (matrix.database == 'Postgres') && 1 || '' }} + WORKERS: ${{ (matrix.arrangement == 'workers') && 1 || '' }} name: Run Complement Tests cargo-test: diff --git a/changelog.d/14858.misc b/changelog.d/14858.misc new file mode 100644 index 0000000000..c48f40cd38 --- /dev/null +++ b/changelog.d/14858.misc @@ -0,0 +1 @@ +Run the integration test suites with the asyncio reactor enabled in CI. diff --git a/docker/complement/conf/start_for_complement.sh b/docker/complement/conf/start_for_complement.sh index 49d79745b0..af13209c54 100755 --- a/docker/complement/conf/start_for_complement.sh +++ b/docker/complement/conf/start_for_complement.sh @@ -6,7 +6,7 @@ set -e echo "Complement Synapse launcher" echo " Args: $@" -echo " Env: SYNAPSE_COMPLEMENT_DATABASE=$SYNAPSE_COMPLEMENT_DATABASE SYNAPSE_COMPLEMENT_USE_WORKERS=$SYNAPSE_COMPLEMENT_USE_WORKERS" +echo " Env: SYNAPSE_COMPLEMENT_DATABASE=$SYNAPSE_COMPLEMENT_DATABASE SYNAPSE_COMPLEMENT_USE_WORKERS=$SYNAPSE_COMPLEMENT_USE_WORKERS SYNAPSE_COMPLEMENT_USE_ASYNCIO_REACTOR=$SYNAPSE_COMPLEMENT_USE_ASYNCIO_REACTOR" function log { d=$(date +"%Y-%m-%d %H:%M:%S,%3N") @@ -76,6 +76,17 @@ else fi +if [[ -n "$SYNAPSE_COMPLEMENT_USE_ASYNCIO_REACTOR" ]]; then + if [[ -n "$SYNAPSE_USE_EXPERIMENTAL_FORKING_LAUNCHER" ]]; then + export SYNAPSE_COMPLEMENT_FORKING_LAUNCHER_ASYNC_IO_REACTOR="1" + else + export SYNAPSE_ASYNC_IO_REACTOR="1" + fi +else + export SYNAPSE_ASYNC_IO_REACTOR="0" +fi + + # Add Complement's appservice registration directory, if there is one # (It can be absent when there are no application services in this test!) if [ -d /complement/appservice ]; then diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md index 3cbfe96987..36bc884684 100644 --- a/docs/development/contributing_guide.md +++ b/docs/development/contributing_guide.md @@ -332,6 +332,7 @@ The above will run a monolithic (single-process) Synapse with SQLite as the data [here](https://github.com/matrix-org/synapse/blob/develop/docker/configure_workers_and_start.py#L54). A safe example would be `WORKER_TYPES="federation_inbound, federation_sender, synchrotron"`. See the [worker documentation](../workers.md) for additional information on workers. +- Passing `ASYNCIO_REACTOR=1` as an environment variable to use the Twisted asyncio reactor instead of the default one. To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`, e.g: ```sh diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index e72d96fd16..66aaa3d848 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -228,6 +228,11 @@ else test_tags="$test_tags,msc2716" fi +if [[ -n "$ASYNCIO_REACTOR" ]]; then + # Enable the Twisted asyncio reactor + export PASS_SYNAPSE_COMPLEMENT_USE_ASYNCIO_REACTOR=true +fi + if [[ -n "$SYNAPSE_TEST_LOG_LEVEL" ]]; then # Set the log level to what is desired diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index 8c0f4a57e7..920538f44d 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -110,6 +110,8 @@ def _worker_entrypoint( and then kick off the worker's main() function. """ + from synapse.util.stringutils import strtobool + sys.argv = args # reset the custom signal handlers that we installed, so that the children start @@ -117,9 +119,24 @@ def _worker_entrypoint( for sig, handler in _original_signal_handlers.items(): signal.signal(sig, handler) - from twisted.internet.epollreactor import EPollReactor + # Install the asyncio reactor if the + # SYNAPSE_COMPLEMENT_FORKING_LAUNCHER_ASYNC_IO_REACTOR is set to 1. The + # SYNAPSE_ASYNC_IO_REACTOR variable would be used, but then causes + # synapse/__init__.py to also try to install an asyncio reactor. + if strtobool( + os.environ.get("SYNAPSE_COMPLEMENT_FORKING_LAUNCHER_ASYNC_IO_REACTOR", "0") + ): + import asyncio + + from twisted.internet.asyncioreactor import AsyncioSelectorReactor + + reactor = AsyncioSelectorReactor(asyncio.get_event_loop()) + proxy_reactor._install_real_reactor(reactor) + else: + from twisted.internet.epollreactor import EPollReactor + + proxy_reactor._install_real_reactor(EPollReactor()) - proxy_reactor._install_real_reactor(EPollReactor()) func() -- cgit 1.5.1 From 2186ebed6c9dfe15cfa8a8a7c97c2a89a907f9a8 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 2 Feb 2023 16:49:14 +0000 Subject: Fetch fewer events when getting hosts in room (#14962) --- changelog.d/14962.feature | 1 + synapse/storage/databases/main/roommember.py | 46 ++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14962.feature (limited to 'synapse') diff --git a/changelog.d/14962.feature b/changelog.d/14962.feature new file mode 100644 index 0000000000..38f26012f2 --- /dev/null +++ b/changelog.d/14962.feature @@ -0,0 +1 @@ +Improve performance when joining or sending an event large rooms. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 8e2ba7b7b4..ea6a5e2f34 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from itertools import chain from typing import ( TYPE_CHECKING, AbstractSet, @@ -1131,12 +1132,33 @@ class RoomMemberWorkerStore(EventsWorkerStore): else: # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. + # + # We need to fetch all hosts joined to the room according to `state` by + # inspecting all join memberships in `state`. However, if the `state` is + # relatively recent then many of its events are likely to be held in + # the current state of the room, which is easily available and likely + # cached. + # + # We therefore compute the set of `state` events not in the + # current state and only fetch those. + current_memberships = ( + await self._get_approximate_current_memberships_in_room(room_id) + ) + unknown_state_events = {} + joined_users_in_current_state = [] + + for (type, state_key), event_id in state.items(): + if event_id not in current_memberships: + unknown_state_events[type, state_key] = event_id + elif current_memberships[event_id] == Membership.JOIN: + joined_users_in_current_state.append(state_key) + joined_user_ids = await self.get_joined_user_ids_from_state( - room_id, state + room_id, unknown_state_events ) cache.hosts_to_joined_users = {} - for user_id in joined_user_ids: + for user_id in chain(joined_user_ids, joined_users_in_current_state): host = intern_string(get_domain_from_id(user_id)) cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) @@ -1147,6 +1169,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): return frozenset(cache.hosts_to_joined_users) + async def _get_approximate_current_memberships_in_room( + self, room_id: str + ) -> Mapping[str, Optional[str]]: + """Build a map from event id to membership, for all events in the current state. + + The event ids of non-memberships events (e.g. `m.room.power_levels`) are present + in the result, mapped to values of `None`. + + The result is approximate for partially-joined rooms. It is fully accurate + for fully-joined rooms. + """ + + rows = await self.db_pool.simple_select_list( + "current_state_events", + keyvalues={"room_id": room_id}, + retcols=("event_id", "membership"), + desc="has_completed_background_updates", + ) + return {row["event_id"]: row["membership"] for row in rows} + @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache() -- cgit 1.5.1 From f36da501be4287e723a0a53ac4568d836676a15d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 2 Feb 2023 11:58:20 -0500 Subject: Do not calculate presence or ephemeral events when they are filtered out (#14970) This expands the previous optimisation from being only for initial sync to being for all sync requests. It also inverts some of the logic to be inclusive instead of exclusive. --- changelog.d/14970.misc | 1 + synapse/handlers/sync.py | 19 +++++++++---------- 2 files changed, 10 insertions(+), 10 deletions(-) create mode 100644 changelog.d/14970.misc (limited to 'synapse') diff --git a/changelog.d/14970.misc b/changelog.d/14970.misc new file mode 100644 index 0000000000..3657623602 --- /dev/null +++ b/changelog.d/14970.misc @@ -0,0 +1 @@ +Improve performance of `/sync` in a few situations. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5235e29460..0cb8d5ef4b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1459,10 +1459,12 @@ class SyncHandler: sync_result_builder, account_data_by_room ) - block_all_presence_data = ( - since_token is None and sync_config.filter_collection.blocks_all_presence() + # Presence data is included if the server has it enabled and not filtered out. + include_presence_data = ( + self.hs_config.server.use_presence + and not sync_config.filter_collection.blocks_all_presence() ) - if self.hs_config.server.use_presence and not block_all_presence_data: + if include_presence_data: logger.debug("Fetching presence data") await self._generate_sync_entry_for_presence( sync_result_builder, @@ -1841,15 +1843,12 @@ class SyncHandler: """ since_token = sync_result_builder.since_token - - # 1. Start by fetching all ephemeral events in rooms we've joined (if required). user_id = sync_result_builder.sync_config.user.to_string() - block_all_room_ephemeral = ( - since_token is None - and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() - ) - if block_all_room_ephemeral: + # 1. Start by fetching all ephemeral events in rooms we've joined (if required). + if ( + sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + ): ephemeral_by_room: Dict[str, List[JsonDict]] = {} else: now_token, ephemeral_by_room = await self.ephemeral_by_room( -- cgit 1.5.1 From da05b70af5bf84825332b2ac0d63c6deda4b376f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 2 Feb 2023 13:45:12 -0500 Subject: Skip unused calculations in sync handler. (#14908) If a sync request does not need to calculate per-room entries & is not generating presence & is not generating device list data (e.g. during initial sync) avoid the expensive calculation of room specific data. This is a micro-optimisation for clients syncing simply to receive to-device information. --- changelog.d/14908.misc | 1 + synapse/api/filtering.py | 3 + synapse/handlers/sync.py | 258 ++++++++++++++++++++++++----------------------- 3 files changed, 137 insertions(+), 125 deletions(-) create mode 100644 changelog.d/14908.misc (limited to 'synapse') diff --git a/changelog.d/14908.misc b/changelog.d/14908.misc new file mode 100644 index 0000000000..3657623602 --- /dev/null +++ b/changelog.d/14908.misc @@ -0,0 +1 @@ +Improve performance of `/sync` in a few situations. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 4cf8f0cc8e..2b5af264b4 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -283,6 +283,9 @@ class FilterCollection: await self._room_filter.filter(events) ) + def blocks_all_rooms(self) -> bool: + return self._room_filter.filters_all_rooms() + def blocks_all_presence(self) -> bool: return ( self._presence_filter.filters_all_types() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0cb8d5ef4b..3566537894 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1448,41 +1448,67 @@ class SyncHandler: sync_result_builder ) - logger.debug("Fetching room data") - - ( - newly_joined_rooms, - newly_joined_or_invited_or_knocked_users, - newly_left_rooms, - newly_left_users, - ) = await self._generate_sync_entry_for_rooms( - sync_result_builder, account_data_by_room - ) - # Presence data is included if the server has it enabled and not filtered out. - include_presence_data = ( + include_presence_data = bool( self.hs_config.server.use_presence and not sync_config.filter_collection.blocks_all_presence() ) - if include_presence_data: - logger.debug("Fetching presence data") - await self._generate_sync_entry_for_presence( - sync_result_builder, + # Device list updates are sent if a since token is provided. + include_device_list_updates = bool(since_token and since_token.device_list_key) + + # If we do not care about the rooms or things which depend on the room + # data (namely presence and device list updates), then we can skip + # this process completely. + device_lists = DeviceListUpdates() + if ( + not sync_result_builder.sync_config.filter_collection.blocks_all_rooms() + or include_presence_data + or include_device_list_updates + ): + logger.debug("Fetching room data") + + # Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which + # is used in calculate_user_changes below. + ( newly_joined_rooms, - newly_joined_or_invited_or_knocked_users, + newly_left_rooms, + ) = await self._generate_sync_entry_for_rooms( + sync_result_builder, account_data_by_room ) + # Work out which users have joined or left rooms we're in. We use this + # to build the presence and device_list parts of the sync response in + # `_generate_sync_entry_for_presence` and + # `_generate_sync_entry_for_device_list` respectively. + if include_presence_data or include_device_list_updates: + # This uses the sync_result_builder.joined which is set in + # `_generate_sync_entry_for_rooms`, if that didn't find any joined + # rooms for some reason it is a no-op. + ( + newly_joined_or_invited_or_knocked_users, + newly_left_users, + ) = sync_result_builder.calculate_user_changes() + + if include_presence_data: + logger.debug("Fetching presence data") + await self._generate_sync_entry_for_presence( + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_or_knocked_users, + ) + + if include_device_list_updates: + device_lists = await self._generate_sync_entry_for_device_list( + sync_result_builder, + newly_joined_rooms=newly_joined_rooms, + newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, + newly_left_rooms=newly_left_rooms, + newly_left_users=newly_left_users, + ) + logger.debug("Fetching to-device data") await self._generate_sync_entry_for_to_device(sync_result_builder) - device_lists = await self._generate_sync_entry_for_device_list( - sync_result_builder, - newly_joined_rooms=newly_joined_rooms, - newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, - newly_left_rooms=newly_left_rooms, - newly_left_users=newly_left_users, - ) - logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_keys_count: JsonDict = {} @@ -1551,6 +1577,7 @@ class SyncHandler: user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token + assert since_token is not None # Take a copy since these fields will be mutated later. newly_joined_or_invited_or_knocked_users = set( @@ -1558,92 +1585,85 @@ class SyncHandler: ) newly_left_users = set(newly_left_users) - if since_token and since_token.device_list_key: - # We want to figure out what user IDs the client should refetch - # device keys for, and which users we aren't going to track changes - # for anymore. - # - # For the first step we check: - # a. if any users we share a room with have updated their devices, - # and - # b. we also check if we've joined any new rooms, or if a user has - # joined a room we're in. - # - # For the second step we just find any users we no longer share a - # room with by looking at all users that have left a room plus users - # that were in a room we've left. + # We want to figure out what user IDs the client should refetch + # device keys for, and which users we aren't going to track changes + # for anymore. + # + # For the first step we check: + # a. if any users we share a room with have updated their devices, + # and + # b. we also check if we've joined any new rooms, or if a user has + # joined a room we're in. + # + # For the second step we just find any users we no longer share a + # room with by looking at all users that have left a room plus users + # that were in a room we've left. - users_that_have_changed = set() + users_that_have_changed = set() - joined_rooms = sync_result_builder.joined_room_ids + joined_rooms = sync_result_builder.joined_room_ids - # Step 1a, check for changes in devices of users we share a room - # with - # - # We do this in two different ways depending on what we have cached. - # If we already have a list of all the user that have changed since - # the last sync then it's likely more efficient to compare the rooms - # they're in with the rooms the syncing user is in. - # - # If we don't have that info cached then we get all the users that - # share a room with our user and check if those users have changed. - cache_result = self.store.get_cached_device_list_changes( - since_token.device_list_key - ) - if cache_result.hit: - changed_users = cache_result.entities - - result = await self.store.get_rooms_for_users(changed_users) - - for changed_user_id, entries in result.items(): - # Check if the changed user shares any rooms with the user, - # or if the changed user is the syncing user (as we always - # want to include device list updates of their own devices). - if user_id == changed_user_id or any( - rid in joined_rooms for rid in entries - ): - users_that_have_changed.add(changed_user_id) - else: - users_that_have_changed = ( - await self._device_handler.get_device_changes_in_shared_rooms( - user_id, - sync_result_builder.joined_room_ids, - from_token=since_token, - ) - ) - - # Step 1b, check for newly joined rooms - for room_id in newly_joined_rooms: - joined_users = await self.store.get_users_in_room(room_id) - newly_joined_or_invited_or_knocked_users.update(joined_users) + # Step 1a, check for changes in devices of users we share a room + # with + # + # We do this in two different ways depending on what we have cached. + # If we already have a list of all the user that have changed since + # the last sync then it's likely more efficient to compare the rooms + # they're in with the rooms the syncing user is in. + # + # If we don't have that info cached then we get all the users that + # share a room with our user and check if those users have changed. + cache_result = self.store.get_cached_device_list_changes( + since_token.device_list_key + ) + if cache_result.hit: + changed_users = cache_result.entities - # TODO: Check that these users are actually new, i.e. either they - # weren't in the previous sync *or* they left and rejoined. - users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) + result = await self.store.get_rooms_for_users(changed_users) - user_signatures_changed = ( - await self.store.get_users_whose_signatures_changed( - user_id, since_token.device_list_key + for changed_user_id, entries in result.items(): + # Check if the changed user shares any rooms with the user, + # or if the changed user is the syncing user (as we always + # want to include device list updates of their own devices). + if user_id == changed_user_id or any( + rid in joined_rooms for rid in entries + ): + users_that_have_changed.add(changed_user_id) + else: + users_that_have_changed = ( + await self._device_handler.get_device_changes_in_shared_rooms( + user_id, + sync_result_builder.joined_room_ids, + from_token=since_token, ) ) - users_that_have_changed.update(user_signatures_changed) - # Now find users that we no longer track - for room_id in newly_left_rooms: - left_users = await self.store.get_users_in_room(room_id) - newly_left_users.update(left_users) + # Step 1b, check for newly joined rooms + for room_id in newly_joined_rooms: + joined_users = await self.store.get_users_in_room(room_id) + newly_joined_or_invited_or_knocked_users.update(joined_users) - # Remove any users that we still share a room with. - left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) - for user_id, entries in left_users_rooms.items(): - if any(rid in joined_rooms for rid in entries): - newly_left_users.discard(user_id) + # TODO: Check that these users are actually new, i.e. either they + # weren't in the previous sync *or* they left and rejoined. + users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) - return DeviceListUpdates( - changed=users_that_have_changed, left=newly_left_users - ) - else: - return DeviceListUpdates() + user_signatures_changed = await self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) + users_that_have_changed.update(user_signatures_changed) + + # Now find users that we no longer track + for room_id in newly_left_rooms: + left_users = await self.store.get_users_in_room(room_id) + newly_left_users.update(left_users) + + # Remove any users that we still share a room with. + left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) + for user_id, entries in left_users_rooms.items(): + if any(rid in joined_rooms for rid in entries): + newly_left_users.discard(user_id) + + return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users) @trace async def _generate_sync_entry_for_to_device( @@ -1720,6 +1740,7 @@ class SyncHandler: since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: + # TODO Do not fetch room account data if it will be unused. ( global_account_data, account_data_by_room, @@ -1736,6 +1757,7 @@ class SyncHandler: sync_config.user ) else: + # TODO Do not fetch room account data if it will be unused. ( global_account_data, account_data_by_room, @@ -1818,7 +1840,7 @@ class SyncHandler: self, sync_result_builder: "SyncResultBuilder", account_data_by_room: Dict[str, Dict[str, JsonDict]], - ) -> Tuple[AbstractSet[str], AbstractSet[str], AbstractSet[str], AbstractSet[str]]: + ) -> Tuple[AbstractSet[str], AbstractSet[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1831,24 +1853,22 @@ class SyncHandler: account_data_by_room: Dictionary of per room account data Returns: - Returns a 4-tuple describing rooms the user has joined or left, and users who've - joined or left rooms any rooms the user is in. This gets used later in - `_generate_sync_entry_for_device_list`. + Returns a 2-tuple describing rooms the user has joined or left. Its entries are: - newly_joined_rooms - - newly_joined_or_invited_or_knocked_users - newly_left_rooms - - newly_left_users """ since_token = sync_result_builder.since_token user_id = sync_result_builder.sync_config.user.to_string() # 1. Start by fetching all ephemeral events in rooms we've joined (if required). - if ( - sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() - ): + block_all_room_ephemeral = ( + sync_result_builder.sync_config.filter_collection.blocks_all_rooms() + or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + ) + if block_all_room_ephemeral: ephemeral_by_room: Dict[str, List[JsonDict]] = {} else: now_token, ephemeral_by_room = await self.ephemeral_by_room( @@ -1870,7 +1890,7 @@ class SyncHandler: ) if not tags_by_room: logger.debug("no-oping sync") - return set(), set(), set(), set() + return set(), set() # 3. Work out which rooms need reporting in the sync response. ignored_users = await self.store.ignored_users(user_id) @@ -1899,6 +1919,7 @@ class SyncHandler: # joined or archived). async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: logger.debug("Generating room entry for %s", room_entry.room_id) + # Note that this mutates sync_result_builder.{joined,archived}. await self._generate_room_entry( sync_result_builder, room_entry, @@ -1915,20 +1936,7 @@ class SyncHandler: sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) - # 5. Work out which users have joined or left rooms we're in. We use this - # to build the device_list part of the sync response in - # `_generate_sync_entry_for_device_list`. - ( - newly_joined_or_invited_or_knocked_users, - newly_left_users, - ) = sync_result_builder.calculate_user_changes() - - return ( - set(newly_joined_rooms), - newly_joined_or_invited_or_knocked_users, - set(newly_left_rooms), - newly_left_users, - ) + return set(newly_joined_rooms), set(newly_left_rooms) async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" -- cgit 1.5.1 From 8e9fc28c6aff6bb1aa960dfde4f9736fee1ae4fb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Feb 2023 08:27:31 -0500 Subject: Reload the pyo3-log config when the Python logging config changes. (#14976) Since pyo3-log is initialized very early in the Python start-up it caches the state of the loggers before they're fully initialized (and thus are essentially disabled). Whenever we reload the logging configuration we now also tell pyo3-log to discard any cached logging configuration it has; it will refetch the current logging configuration from Python at the next point it logs. This fixes Rust log lines not appearing in the homeserver logs. --- changelog.d/14976.bugfix | 1 + rust/src/lib.rs | 17 +++++++++++-- stubs/synapse/synapse_rust/__init__.pyi | 1 + synapse/config/logger.py | 42 +++++++++++++++++++-------------- tests/test_utils/logging_setup.py | 3 +++ 5 files changed, 44 insertions(+), 20 deletions(-) create mode 100644 changelog.d/14976.bugfix (limited to 'synapse') diff --git a/changelog.d/14976.bugfix b/changelog.d/14976.bugfix new file mode 100644 index 0000000000..0cde046c0e --- /dev/null +++ b/changelog.d/14976.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.68.0 where logging from the Rust module was not properly logged. diff --git a/rust/src/lib.rs b/rust/src/lib.rs index c7b60e58a7..ce67f58611 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,7 +1,13 @@ +use lazy_static::lazy_static; use pyo3::prelude::*; +use pyo3_log::ResetHandle; pub mod push; +lazy_static! { + static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init(); +} + /// Returns the hash of all the rust source files at the time it was compiled. /// /// Used by python to detect if the rust library is outdated. @@ -17,13 +23,20 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { Ok((a + b).to_string()) } +/// Reset the cached logging configuration of pyo3-log to pick up any changes +/// in the Python logging configuration. +/// +#[pyfunction] +fn reset_logging_config() { + LOGGING_HANDLE.reset(); +} + /// The entry point for defining the Python module. #[pymodule] fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { - pyo3_log::init(); - m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; + m.add_function(wrap_pyfunction!(reset_logging_config, m)?)?; push::register_module(py, m)?; diff --git a/stubs/synapse/synapse_rust/__init__.pyi b/stubs/synapse/synapse_rust/__init__.pyi index 8658d3138f..d25c609106 100644 --- a/stubs/synapse/synapse_rust/__init__.pyi +++ b/stubs/synapse/synapse_rust/__init__.pyi @@ -1,2 +1,3 @@ def sum_as_string(a: int, b: int) -> str: ... def get_rust_file_digest() -> str: ... +def reset_logging_config() -> None: ... diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 5468b963a2..56db875b25 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -34,6 +34,7 @@ from twisted.logger import ( from synapse.logging.context import LoggingContextFilter from synapse.logging.filter import MetadataFilter +from synapse.synapse_rust import reset_logging_config from synapse.types import JsonDict from ..util import SYNAPSE_VERSION @@ -200,24 +201,6 @@ def _setup_stdlib_logging( """ Set up Python standard library logging. """ - if log_config_path is None: - log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" - " - %(message)s" - ) - - logger = logging.getLogger("") - logger.setLevel(logging.INFO) - logging.getLogger("synapse.storage.SQL").setLevel(logging.INFO) - - formatter = logging.Formatter(log_format) - - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) - else: - # Load the logging configuration. - _load_logging_config(log_config_path) # We add a log record factory that runs all messages through the # LoggingContextFilter so that we get the context *at the time we log* @@ -237,6 +220,26 @@ def _setup_stdlib_logging( logging.setLogRecordFactory(factory) + # Configure the logger with the initial configuration. + if log_config_path is None: + log_format = ( + "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" + " - %(message)s" + ) + + logger = logging.getLogger("") + logger.setLevel(logging.INFO) + logging.getLogger("synapse.storage.SQL").setLevel(logging.INFO) + + formatter = logging.Formatter(log_format) + + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + else: + # Load the logging configuration. + _load_logging_config(log_config_path) + # Route Twisted's native logging through to the standard library logging # system. observer = STDLibLogObserver() @@ -294,6 +297,9 @@ def _load_logging_config(log_config_path: str) -> None: logging.config.dictConfig(log_config) + # Blow away the pyo3-log cache so that it reloads the configuration. + reset_logging_config() + def _reload_logging_config(log_config_path: Optional[str]) -> None: """ diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 9228454c9e..304c7b98c5 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -17,6 +17,7 @@ import os import twisted.logger from synapse.logging.context import LoggingContextFilter +from synapse.synapse_rust import reset_logging_config class ToTwistedHandler(logging.Handler): @@ -52,3 +53,5 @@ def setup_logging(): log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR") root_logger.setLevel(log_level) + + reset_logging_config() -- cgit 1.5.1 From 0a686d1d13c497af84f62ca192a401fdc18387ab Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 3 Feb 2023 15:39:59 +0000 Subject: Faster joins: Refactor handling of servers in room (#14954) Ensure that the list of servers in a partial state room always contains the server we joined off. Also refactor `get_partial_state_servers_at_join` to return `None` when the given room is no longer partial stated, to explicitly indicate when the room has partial state. Otherwise it's not clear whether an empty list means that the room has full state, or the room is partial stated, but the server we joined off told us that there are no servers in the room. Signed-off-by: Sean Quah --- changelog.d/14954.misc | 1 + synapse/federation/federation_client.py | 33 ++++++++++++++-------- synapse/federation/sender/__init__.py | 2 +- synapse/handlers/device.py | 1 + synapse/handlers/federation.py | 20 +++++++++---- synapse/storage/controllers/state.py | 3 +- synapse/storage/databases/main/room.py | 50 ++++++++++++++++++++++----------- tests/handlers/test_federation.py | 2 +- tests/handlers/test_room_member.py | 2 +- 9 files changed, 77 insertions(+), 37 deletions(-) create mode 100644 changelog.d/14954.misc (limited to 'synapse') diff --git a/changelog.d/14954.misc b/changelog.d/14954.misc new file mode 100644 index 0000000000..b86b6bf01e --- /dev/null +++ b/changelog.d/14954.misc @@ -0,0 +1 @@ +Faster room joins: Refactor internal handling of servers in room to never store an empty list. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 8493ffc2e5..0ac85a3be7 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,6 +19,7 @@ import itertools import logging from typing import ( TYPE_CHECKING, + AbstractSet, Awaitable, Callable, Collection, @@ -110,8 +111,9 @@ class SendJoinResult: # True if 'state' elides non-critical membership events partial_state: bool - # if 'partial_state' is set, a list of the servers in the room (otherwise empty) - servers_in_room: List[str] + # If 'partial_state' is set, a set of the servers in the room (otherwise empty). + # Always contains the server we joined off. + servers_in_room: AbstractSet[str] class FederationClient(FederationBase): @@ -1152,15 +1154,24 @@ class FederationClient(FederationBase): % (auth_chain_create_events,) ) - if response.members_omitted and not response.servers_in_room: - raise InvalidResponseError( - "members_omitted was set, but no servers were listed in the room" - ) + servers_in_room = None + if response.servers_in_room is not None: + servers_in_room = set(response.servers_in_room) - if response.members_omitted and not partial_state: - raise InvalidResponseError( - "members_omitted was set, but we asked for full state" - ) + if response.members_omitted: + if not servers_in_room: + raise InvalidResponseError( + "members_omitted was set, but no servers were listed in the room" + ) + + if not partial_state: + raise InvalidResponseError( + "members_omitted was set, but we asked for full state" + ) + + # `servers_in_room` is supposed to be a complete list. + # Fix things up in case the remote homeserver is badly behaved. + servers_in_room.add(destination) return SendJoinResult( event=event, @@ -1168,7 +1179,7 @@ class FederationClient(FederationBase): auth_chain=signed_auth, origin=destination, partial_state=response.members_omitted, - servers_in_room=response.servers_in_room or [], + servers_in_room=servers_in_room or frozenset(), ) # MSC3083 defines additional error codes for room joins. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 30ebd62883..43421a9c72 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -447,7 +447,7 @@ class FederationSender(AbstractFederationSender): ) ) - if len(partial_state_destinations) > 0: + if partial_state_destinations is not None: destinations = partial_state_destinations if destinations is None: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5c06073901..6f7963df43 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -859,6 +859,7 @@ class DeviceHandler(DeviceWorkerHandler): known_hosts_at_join = await self.store.get_partial_state_servers_at_join( room_id ) + assert known_hosts_at_join is not None potentially_changed_hosts.difference_update(known_hosts_at_join) potentially_changed_hosts.discard(self.server_name) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index dc1cbf5c3d..7f64130e0a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -20,7 +20,17 @@ import itertools import logging from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + AbstractSet, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr from prometheus_client import Histogram @@ -169,7 +179,7 @@ class FederationHandler: # A dictionary mapping room IDs to (initial destination, other destinations) # tuples. self._partial_state_syncs_maybe_needing_restart: Dict[ - str, Tuple[Optional[str], StrCollection] + str, Tuple[Optional[str], AbstractSet[str]] ] = {} # A lock guarding the partial state flag for rooms. # When the lock is held for a given room, no other concurrent code may @@ -1720,7 +1730,7 @@ class FederationHandler: def _start_partial_state_room_sync( self, initial_destination: Optional[str], - other_destinations: StrCollection, + other_destinations: AbstractSet[str], room_id: str, ) -> None: """Starts the background process to resync the state of a partial state room, @@ -1802,7 +1812,7 @@ class FederationHandler: async def _sync_partial_state_room( self, initial_destination: Optional[str], - other_destinations: StrCollection, + other_destinations: AbstractSet[str], room_id: str, ) -> None: """Background process to resync the state of a partial-state room @@ -1939,7 +1949,7 @@ class FederationHandler: def _prioritise_destinations_for_partial_state_resync( initial_destination: Optional[str], - other_destinations: StrCollection, + other_destinations: AbstractSet[str], room_id: str, ) -> StrCollection: """Work out the order in which we should ask servers to resync events. diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 2045169b9a..52efd4a171 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -569,10 +569,11 @@ class StateStorageController: is arbitrary for rooms with partial state. """ # We have to read this list first to mitigate races with un-partial stating. - # This will be empty for rooms with full state. hosts_at_join = await self.stores.main.get_partial_state_servers_at_join( room_id ) + if hosts_at_join is None: + hosts_at_join = frozenset() hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4ddb27f686..644bbb8878 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -18,6 +18,7 @@ from abc import abstractmethod from enum import Enum from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Collection, @@ -25,7 +26,6 @@ from typing import ( List, Mapping, Optional, - Sequence, Set, Tuple, Union, @@ -109,7 +109,7 @@ class RoomSortOrder(Enum): @attr.s(slots=True, frozen=True, auto_attribs=True) class PartialStateResyncInfo: joined_via: Optional[str] - servers_in_room: List[str] = attr.ib(factory=list) + servers_in_room: Set[str] = attr.ib(factory=set) class RoomWorkerStore(CacheInvalidationWorkerStore): @@ -1193,21 +1193,35 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_rooms_for_retention_period_in_range_txn, ) - @cached(iterable=True) - async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]: - """Gets the list of servers in a partial state room at the time we joined it. + async def get_partial_state_servers_at_join( + self, room_id: str + ) -> Optional[AbstractSet[str]]: + """Gets the set of servers in a partial state room at the time we joined it. Returns: The `servers_in_room` list from the `/send_join` response for partial state rooms. May not be accurate or complete, as it comes from a remote homeserver. - An empty list for full state rooms. + `None` for full state rooms. """ - return await self.db_pool.simple_select_onecol( - "partial_state_rooms_servers", - keyvalues={"room_id": room_id}, - retcol="server_name", - desc="get_partial_state_servers_at_join", + servers_in_room = await self._get_partial_state_servers_at_join(room_id) + + if len(servers_in_room) == 0: + return None + + return servers_in_room + + @cached(iterable=True) + async def _get_partial_state_servers_at_join( + self, room_id: str + ) -> AbstractSet[str]: + return frozenset( + await self.db_pool.simple_select_onecol( + "partial_state_rooms_servers", + keyvalues={"room_id": room_id}, + retcol="server_name", + desc="get_partial_state_servers_at_join", + ) ) async def get_partial_state_room_resync_info( @@ -1252,7 +1266,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): # partial-joined between the two SELECTs, but this is unlikely to happen # in practice.) continue - entry.servers_in_room.append(server_name) + entry.servers_in_room.add(server_name) return room_servers @@ -1942,7 +1956,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): async def store_partial_state_room( self, room_id: str, - servers: Collection[str], + servers: AbstractSet[str], device_lists_stream_id: int, joined_via: str, ) -> None: @@ -1957,11 +1971,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): Args: room_id: the ID of the room - servers: other servers known to be in the room + servers: other servers known to be in the room. must include `joined_via`. device_lists_stream_id: the device_lists stream ID at the time when we first joined the room. joined_via: the server name we requested a partial join from. """ + assert joined_via in servers + await self.db_pool.runInteraction( "store_partial_state_room", self._store_partial_state_room_txn, @@ -1975,7 +1991,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self, txn: LoggingTransaction, room_id: str, - servers: Collection[str], + servers: AbstractSet[str], device_lists_stream_id: int, joined_via: str, ) -> None: @@ -1998,7 +2014,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) self._invalidate_cache_and_stream( - txn, self.get_partial_state_servers_at_join, (room_id,) + txn, self._get_partial_state_servers_at_join, (room_id,) ) async def write_partial_state_rooms_join_event_id( @@ -2409,7 +2425,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) self._invalidate_cache_and_stream( - txn, self.get_partial_state_servers_at_join, (room_id,) + txn, self._get_partial_state_servers_at_join, (room_id,) ) DatabasePool.simple_insert_txn( diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index c1558c40c3..57675fa407 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -656,7 +656,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): EVENT_INVITATION_MEMBERSHIP, ], partial_state=True, - servers_in_room=["example.com"], + servers_in_room={"example.com"}, ) ) ) diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 6bbfd5dc84..6a38893b68 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -171,7 +171,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): state=[create_event], auth_chain=[create_event], partial_state=False, - servers_in_room=[], + servers_in_room=frozenset(), ) ) ) -- cgit 1.5.1 From 52700a0bcf2caaa792b94e2a8c12f29d1c61b91e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Feb 2023 11:28:20 -0500 Subject: Support the backwards compatibility features in MSC3952. (#14958) If the feature is enabled and the event has a `m.mentions` property, skip processing of the legacy mentions rules. --- changelog.d/14958.feature | 1 + rust/benches/evaluator.rs | 4 + rust/src/push/evaluator.rs | 19 +++ stubs/synapse/synapse_rust/push.pyi | 1 + synapse/push/bulk_push_rule_evaluator.py | 9 +- tests/push/test_bulk_push_rule_evaluator.py | 191 ++++++++++++++++++++-------- tests/push/test_push_rule_evaluator.py | 18 ++- 7 files changed, 184 insertions(+), 59 deletions(-) create mode 100644 changelog.d/14958.feature (limited to 'synapse') diff --git a/changelog.d/14958.feature b/changelog.d/14958.feature new file mode 100644 index 0000000000..8293e99eff --- /dev/null +++ b/changelog.d/14958.feature @@ -0,0 +1 @@ +Experimental support for [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952): intentional mentions. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 6b16a3f75b..859d54961c 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -33,6 +33,7 @@ fn bench_match_exact(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + false, BTreeSet::new(), false, 10, @@ -71,6 +72,7 @@ fn bench_match_word(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + false, BTreeSet::new(), false, 10, @@ -109,6 +111,7 @@ fn bench_match_word_miss(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + false, BTreeSet::new(), false, 10, @@ -147,6 +150,7 @@ fn bench_eval_message(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + false, BTreeSet::new(), false, 10, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index aa71202e43..da6f704c0e 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -68,6 +68,8 @@ pub struct PushRuleEvaluator { /// The "content.body", if any. body: String, + /// True if the event has a mentions property and MSC3952 support is enabled. + has_mentions: bool, /// The user mentions that were part of the message. user_mentions: BTreeSet, /// True if the message is a room message. @@ -105,6 +107,7 @@ impl PushRuleEvaluator { #[new] pub fn py_new( flattened_keys: BTreeMap, + has_mentions: bool, user_mentions: BTreeSet, room_mention: bool, room_member_count: u64, @@ -123,6 +126,7 @@ impl PushRuleEvaluator { Ok(PushRuleEvaluator { flattened_keys, body, + has_mentions, user_mentions, room_mention, room_member_count, @@ -155,6 +159,19 @@ impl PushRuleEvaluator { } let rule_id = &push_rule.rule_id().to_string(); + + // For backwards-compatibility the legacy mention rules are disabled + // if the event contains the 'm.mentions' property (and if the + // experimental feature is enabled, both of these are represented + // by the has_mentions flag). + if self.has_mentions + && (rule_id == "global/override/.m.rule.contains_display_name" + || rule_id == "global/content/.m.rule.contains_user_name" + || rule_id == "global/override/.m.rule.roomnotif") + { + continue; + } + let extev_flag = &RoomVersionFeatures::ExtensibleEvents.as_str().to_string(); let supports_extensible_events = self.room_version_feature_flags.contains(extev_flag); let safe_from_rver_condition = SAFE_EXTENSIBLE_EVENTS_RULE_IDS.contains(rule_id); @@ -441,6 +458,7 @@ fn push_rule_evaluator() { flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); let evaluator = PushRuleEvaluator::py_new( flattened_keys, + false, BTreeSet::new(), false, 10, @@ -468,6 +486,7 @@ fn test_requires_room_version_supports_condition() { let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( flattened_keys, + false, BTreeSet::new(), false, 10, diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 588d90c25a..c0af2af3df 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -56,6 +56,7 @@ class PushRuleEvaluator: def __init__( self, flattened_keys: Mapping[str, str], + has_mentions: bool, user_mentions: Set[str], room_mention: bool, room_member_count: int, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 88cfc05d05..9bf92b9765 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -119,6 +119,9 @@ class BulkPushRuleEvaluator: self.should_calculate_push_rules = self.hs.config.push.enable_push self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled + self._intentional_mentions_enabled = ( + self.hs.config.experimental.msc3952_intentional_mentions + ) self.room_push_rule_cache_metrics = register_cache( "cache", @@ -364,9 +367,12 @@ class BulkPushRuleEvaluator: # Pull out any user and room mentions. mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) + has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict) user_mentions: Set[str] = set() room_mention = False - if isinstance(mentions, dict): + if has_mentions: + # mypy seems to have lost the type even though it must be a dict here. + assert isinstance(mentions, dict) # Remove out any non-string items and convert to a set. user_mentions_raw = mentions.get("user_ids") if isinstance(user_mentions_raw, list): @@ -378,6 +384,7 @@ class BulkPushRuleEvaluator: evaluator = PushRuleEvaluator( _flatten_dict(event, room_version=event.room_version), + has_mentions, user_mentions, room_mention, room_member_count, diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index fda48d9f61..3b2d082dcb 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional from unittest.mock import patch from parameterized import parameterized @@ -25,7 +25,7 @@ from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.rest import admin from synapse.rest.client import login, register, room from synapse.server import HomeServer -from synapse.types import create_requester +from synapse.types import JsonDict, create_requester from synapse.util import Clock from tests.test_utils import simple_async_mock @@ -196,77 +196,144 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) bulk_evaluator._action_for_event_by_user.assert_not_called() - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) - def test_mentions(self) -> None: - """Test the behavior of an event which includes invalid mentions.""" - bulk_evaluator = BulkPushRuleEvaluator(self.hs) - - sentinel = object() - - def create_and_process(mentions: Any = sentinel) -> bool: - """Returns true iff the `mentions` trigger an event push action.""" - content = {} - if mentions is not sentinel: - content[EventContentFields.MSC3952_MENTIONS] = mentions - - # Create a new message event which should cause a notification. - event, context = self.get_success( - self.event_creation_handler.create_event( - self.requester, - { - "type": "test", - "room_id": self.room_id, - "content": content, - "sender": f"@bob:{self.hs.hostname}", - }, - ) + def _create_and_process( + self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None + ) -> bool: + """Returns true iff the `mentions` trigger an event push action.""" + # Create a new message event which should cause a notification. + event, context = self.get_success( + self.event_creation_handler.create_event( + self.requester, + { + "type": "test", + "room_id": self.room_id, + "content": content or {}, + "sender": f"@bob:{self.hs.hostname}", + }, ) + ) - # Ensure no actions are generated! - self.get_success( - bulk_evaluator.action_for_events_by_user([(event, context)]) - ) + # Execute the push rule machinery. + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) - # If any actions are generated for this event, return true. - result = self.get_success( - self.hs.get_datastores().main.db_pool.simple_select_list( - table="event_push_actions_staging", - keyvalues={"event_id": event.event_id}, - retcols=("*",), - desc="get_event_push_actions_staging", - ) + # If any actions are generated for this event, return true. + result = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_list( + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcols=("*",), + desc="get_event_push_actions_staging", ) - return len(result) > 0 + ) + return len(result) > 0 + + @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + def test_user_mentions(self) -> None: + """Test the behavior of an event which includes invalid user mentions.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Not including the mentions field should not notify. - self.assertFalse(create_and_process()) + self.assertFalse(self._create_and_process(bulk_evaluator)) # An empty mentions field should not notify. - self.assertFalse(create_and_process({})) + self.assertFalse( + self._create_and_process( + bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {}} + ) + ) # Non-dict mentions should be ignored. mentions: Any for mentions in (None, True, False, 1, "foo", []): - self.assertFalse(create_and_process(mentions)) + self.assertFalse( + self._create_and_process( + bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: mentions} + ) + ) # A non-list should be ignored. for mentions in (None, True, False, 1, "foo", {}): - self.assertFalse(create_and_process({"user_ids": mentions})) + self.assertFalse( + self._create_and_process( + bulk_evaluator, + {EventContentFields.MSC3952_MENTIONS: {"user_ids": mentions}}, + ) + ) # The Matrix ID appearing anywhere in the list should notify. - self.assertTrue(create_and_process({"user_ids": [self.alice]})) - self.assertTrue(create_and_process({"user_ids": ["@another:test", self.alice]})) + self.assertTrue( + self._create_and_process( + bulk_evaluator, + {EventContentFields.MSC3952_MENTIONS: {"user_ids": [self.alice]}}, + ) + ) + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + EventContentFields.MSC3952_MENTIONS: { + "user_ids": ["@another:test", self.alice] + } + }, + ) + ) # Duplicate user IDs should notify. - self.assertTrue(create_and_process({"user_ids": [self.alice, self.alice]})) + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + EventContentFields.MSC3952_MENTIONS: { + "user_ids": [self.alice, self.alice] + } + }, + ) + ) # Invalid entries in the list are ignored. - self.assertFalse(create_and_process({"user_ids": [None, True, False, {}, []]})) + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + EventContentFields.MSC3952_MENTIONS: { + "user_ids": [None, True, False, {}, []] + } + }, + ) + ) self.assertTrue( - create_and_process({"user_ids": [None, True, False, {}, [], self.alice]}) + self._create_and_process( + bulk_evaluator, + { + EventContentFields.MSC3952_MENTIONS: { + "user_ids": [None, True, False, {}, [], self.alice] + } + }, + ) ) + # The legacy push rule should not mention if the mentions field exists. + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "body": self.alice, + "msgtype": "m.text", + EventContentFields.MSC3952_MENTIONS: {}, + }, + ) + ) + + @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + def test_room_mentions(self) -> None: + """Test the behavior of an event which includes invalid room mentions.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + # Room mentions from those without power should not notify. - self.assertFalse(create_and_process({"room": True})) + self.assertFalse( + self._create_and_process( + bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}} + ) + ) # Room mentions from those with power should notify. self.helper.send_state( @@ -276,8 +343,30 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.token, state_key="", ) - self.assertTrue(create_and_process({"room": True})) + self.assertTrue( + self._create_and_process( + bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}} + ) + ) # Invalid data should not notify. + mentions: Any for mentions in (None, False, 1, "foo", [], {}): - self.assertFalse(create_and_process({"room": mentions})) + self.assertFalse( + self._create_and_process( + bulk_evaluator, + {EventContentFields.MSC3952_MENTIONS: {"room": mentions}}, + ) + ) + + # The legacy push rule should not mention if the mentions field exists. + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "body": "@room", + "msgtype": "m.text", + EventContentFields.MSC3952_MENTIONS: {}, + }, + ) + ) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 9d01c989d4..81661e181b 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -42,6 +42,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): self, content: JsonMapping, *, + has_mentions: bool = False, user_mentions: Optional[Set[str]] = None, room_mention: bool = False, related_events: Optional[JsonDict] = None, @@ -62,6 +63,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluator( _flatten_dict(event), + has_mentions, user_mentions or set(), room_mention, room_member_count, @@ -102,19 +104,21 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): condition = {"kind": "org.matrix.msc3952.is_user_mention"} # No mentions shouldn't match. - evaluator = self._get_evaluator({}) + evaluator = self._get_evaluator({}, has_mentions=True) self.assertFalse(evaluator.matches(condition, "@user:test", None)) # An empty set shouldn't match - evaluator = self._get_evaluator({}, user_mentions=set()) + evaluator = self._get_evaluator({}, has_mentions=True, user_mentions=set()) self.assertFalse(evaluator.matches(condition, "@user:test", None)) # The Matrix ID appearing anywhere in the mentions list should match - evaluator = self._get_evaluator({}, user_mentions={"@user:test"}) + evaluator = self._get_evaluator( + {}, has_mentions=True, user_mentions={"@user:test"} + ) self.assertTrue(evaluator.matches(condition, "@user:test", None)) evaluator = self._get_evaluator( - {}, user_mentions={"@another:test", "@user:test"} + {}, has_mentions=True, user_mentions={"@another:test", "@user:test"} ) self.assertTrue(evaluator.matches(condition, "@user:test", None)) @@ -126,16 +130,16 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): condition = {"kind": "org.matrix.msc3952.is_room_mention"} # No room mention shouldn't match. - evaluator = self._get_evaluator({}) + evaluator = self._get_evaluator({}, has_mentions=True) self.assertFalse(evaluator.matches(condition, None, None)) # Room mention should match. - evaluator = self._get_evaluator({}, room_mention=True) + evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True) self.assertTrue(evaluator.matches(condition, None, None)) # A room mention and user mention is valid. evaluator = self._get_evaluator( - {}, user_mentions={"@another:test"}, room_mention=True + {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True ) self.assertTrue(evaluator.matches(condition, None, None)) -- cgit 1.5.1 From f0cae26d58f6f907236112be5f4eaecc376b1304 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Feb 2023 11:48:13 -0500 Subject: Add a docstring & tests for _flatten_dict. (#14981) --- changelog.d/14981.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 23 +++++++++++++++++++++++ tests/push/test_push_rule_evaluator.py | 26 +++++++++++++++++++++++++- 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14981.misc (limited to 'synapse') diff --git a/changelog.d/14981.misc b/changelog.d/14981.misc new file mode 100644 index 0000000000..68ac8335fc --- /dev/null +++ b/changelog.d/14981.misc @@ -0,0 +1 @@ +Add tests for `_flatten_dict`. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9bf92b9765..20369f3dfe 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -473,6 +473,29 @@ def _flatten_dict( prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: + """ + Given a JSON dictionary (or event) which might contain sub dictionaries, + flatten it into a single layer dictionary by combining the keys & sub-keys. + + Any (non-dictionary), non-string value is dropped. + + Transforms: + + {"foo": {"bar": "test"}} + + To: + + {"foo.bar": "test"} + + Args: + d: The event or content to continue flattening. + room_version: The room version object. + prefix: The key prefix (from outer dictionaries). + result: The result to mutate. + + Returns: + The resulting dictionary. + """ if prefix is None: prefix = [] if result is None: diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 81661e181b..7c430c4ecb 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Set, Union, cast +from typing import Any, Dict, List, Optional, Set, Union, cast import frozendict @@ -37,6 +37,30 @@ from tests import unittest from tests.test_utils.event_injection import create_event, inject_member_event +class FlattenDictTestCase(unittest.TestCase): + def test_simple(self) -> None: + """Test a dictionary that isn't modified.""" + input = {"foo": "abc"} + self.assertEqual(input, _flatten_dict(input)) + + def test_nested(self) -> None: + """Nested dictionaries become dotted paths.""" + input = {"foo": {"bar": "abc"}} + self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input)) + + def test_non_string(self) -> None: + """Non-string items are dropped.""" + input: Dict[str, Any] = { + "woo": "woo", + "foo": True, + "bar": 1, + "baz": None, + "fuzz": [], + "boo": {}, + } + self.assertEqual({"woo": "woo"}, _flatten_dict(input)) + + class PushRuleEvaluatorTestCase(unittest.TestCase): def _get_evaluator( self, -- cgit 1.5.1 From b2d97bac0910c4730ea83fbee50abbdce2ba23be Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Feb 2023 14:31:14 -0500 Subject: Implement MSC3958: suppress notifications from edits (#14960) Co-authored-by: Brad Murray Co-authored-by: Nick Barrett Copy the suppress_edits push rule from Beeper to implement MSC3958. https://github.com/beeper/synapse/blame/9415a1284b1bfb558bd66f28c24ca1611e6c6fa2/rust/src/push/base_rules.rs#L98-L114 --- changelog.d/14960.feature | 1 + rust/benches/evaluator.rs | 1 + rust/src/push/base_rules.rs | 17 ++++++++++++ rust/src/push/evaluator.rs | 2 +- rust/src/push/mod.rs | 8 ++++++ stubs/synapse/synapse_rust/push.pyi | 1 + synapse/config/experimental.py | 5 ++++ synapse/storage/databases/main/push_rule.py | 1 + tests/push/test_bulk_push_rule_evaluator.py | 42 ++++++++++++++++++++++++++++- 9 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14960.feature (limited to 'synapse') diff --git a/changelog.d/14960.feature b/changelog.d/14960.feature new file mode 100644 index 0000000000..b9bb331273 --- /dev/null +++ b/changelog.d/14960.feature @@ -0,0 +1 @@ +Experimental support to suppress notifications from message edits ([MSC3958](https://github.com/matrix-org/matrix-spec-proposals/pull/3958)). diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 859d54961c..35f7a50bce 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -170,6 +170,7 @@ fn bench_eval_message(b: &mut Bencher) { false, false, false, + false, ); b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 49add4e951..e9af26dd4f 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -63,6 +63,23 @@ pub const BASE_PREPEND_OVERRIDE_RULES: &[PushRule] = &[PushRule { }]; pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ + // We don't want to notify on edits. Not only can this be confusing in real + // time (2 notifications, one message) but it's especially confusing + // if a bridge needs to edit a previously backfilled message. + PushRule { + rule_id: Cow::Borrowed("global/override/.com.beeper.suppress_edits"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("content.m.relates_to.rel_type"), + pattern: Some(Cow::Borrowed("m.replace")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.suppress_notices"), priority_class: 5, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index da6f704c0e..ec7a8c4453 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -523,7 +523,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 7e449f2433..3c4f876cab 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -419,6 +419,7 @@ pub struct FilteredPushRules { msc3381_polls_enabled: bool, msc3664_enabled: bool, msc3952_intentional_mentions: bool, + msc3958_suppress_edits_enabled: bool, } #[pymethods] @@ -431,6 +432,7 @@ impl FilteredPushRules { msc3381_polls_enabled: bool, msc3664_enabled: bool, msc3952_intentional_mentions: bool, + msc3958_suppress_edits_enabled: bool, ) -> Self { Self { push_rules, @@ -439,6 +441,7 @@ impl FilteredPushRules { msc3381_polls_enabled, msc3664_enabled, msc3952_intentional_mentions, + msc3958_suppress_edits_enabled, } } @@ -476,6 +479,11 @@ impl FilteredPushRules { { return false; } + if !self.msc3958_suppress_edits_enabled + && rule.rule_id == "global/override/.com.beeper.suppress_edits" + { + return false; + } true }) diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index c0af2af3df..754acab2f9 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -47,6 +47,7 @@ class FilteredPushRules: msc3381_polls_enabled: bool, msc3664_enabled: bool, msc3952_intentional_mentions: bool, + msc3958_suppress_edits_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d2d0270ddd..53c0682dfd 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -173,3 +173,8 @@ class ExperimentalConfig(Config): self.msc3952_intentional_mentions = experimental.get( "msc3952_intentional_mentions", False ) + + # MSC3959: Do not generate notifications for edits. + self.msc3958_supress_edit_notifs = experimental.get( + "msc3958_supress_edit_notifs", False + ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 466a1145b7..9b2bbe060d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -90,6 +90,7 @@ def _load_rules( msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions, + msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs, ) return filtered_rules diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 3b2d082dcb..7567756135 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -19,7 +19,7 @@ from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import EventContentFields +from synapse.api.constants import EventContentFields, RelationTypes from synapse.api.room_versions import RoomVersions from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.rest import admin @@ -370,3 +370,43 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + + @override_config({"experimental_features": {"msc3958_supress_edit_notifs": True}}) + def test_suppress_edits(self) -> None: + """Under the default push rules, event edits should not generate notifications.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + + # Create & persist an event to use as the parent of the relation. + event, context = self.get_success( + self.event_creation_handler.create_event( + self.requester, + { + "type": "m.room.message", + "room_id": self.room_id, + "content": { + "msgtype": "m.text", + "body": "helo", + }, + "sender": self.alice, + }, + ) + ) + self.get_success( + self.event_creation_handler.handle_new_client_event( + self.requester, events_and_context=[(event, context)] + ) + ) + + # Room mentions from those without power should not notify. + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "body": self.alice, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": event.event_id, + }, + }, + ) + ) -- cgit 1.5.1 From 6e6edea6c15dc1a15f44d9e92d334e3ce0f827dd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 3 Feb 2023 20:03:23 +0000 Subject: Properly typecheck tests.api (#14983) --- changelog.d/14983.misc | 1 + mypy.ini | 4 +- synapse/api/filtering.py | 4 +- tests/api/test_auth.py | 64 +++++++++-------- tests/api/test_filtering.py | 157 +++++++++++++++++++++++------------------ tests/api/test_ratelimiting.py | 18 ++--- tests/events/test_utils.py | 2 + 7 files changed, 140 insertions(+), 110 deletions(-) create mode 100644 changelog.d/14983.misc (limited to 'synapse') diff --git a/changelog.d/14983.misc b/changelog.d/14983.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/14983.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index 57f43395bb..a6e37bc377 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,7 +32,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/schema/ - |tests/api/test_auth.py |tests/appservice/test_scheduler.py |tests/federation/test_federation_catch_up.py |tests/federation/test_federation_sender.py @@ -73,6 +72,9 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False +[mypy-tests.api.*] +disallow_untyped_defs = True + [mypy-tests.app.*] disallow_untyped_defs = True diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 2b5af264b4..83c42fc25a 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -252,9 +252,9 @@ class FilterCollection: return self._room_timeline_filter.unread_thread_notifications async def filter_presence( - self, events: Iterable[UserPresenceState] + self, presence_states: Iterable[UserPresenceState] ) -> List[UserPresenceState]: - return await self._presence_filter.filter(events) + return await self._presence_filter.filter(presence_states) async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: return await self._account_data.filter(events) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e0f363555b..6e36e73f0d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -31,7 +31,7 @@ from synapse.api.errors import ( from synapse.appservice import ApplicationService from synapse.server import HomeServer from synapse.storage.databases.main.registration import TokenLookupResult -from synapse.types import Requester +from synapse.types import Requester, UserID from synapse.util import Clock from tests import unittest @@ -41,10 +41,12 @@ from tests.utils import mock_getRawHeaders class AuthTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = Mock() - hs.datastores.main = self.store + # type-ignore: datastores is None until hs.setup() is called---but it'll + # have been called by the HomeserverTestCase machinery. + hs.datastores.main = self.store # type: ignore[union-attr] hs.get_auth_handler().store = self.store self.auth = Auth(hs) @@ -61,7 +63,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.insert_client_ip = simple_async_mock(None) self.store.is_support_user = simple_async_mock(False) - def test_get_user_by_req_user_valid_token(self): + def test_get_user_by_req_user_valid_token(self) -> None: user_info = TokenLookupResult( user_id=self.test_user, token_id=5, device_id="device" ) @@ -74,7 +76,7 @@ class AuthTestCase(unittest.HomeserverTestCase): requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEqual(requester.user.to_string(), self.test_user) - def test_get_user_by_req_user_bad_token(self): + def test_get_user_by_req_user_bad_token(self) -> None: self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) @@ -86,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") - def test_get_user_by_req_user_missing_token(self): + def test_get_user_by_req_user_missing_token(self) -> None: user_info = TokenLookupResult(user_id=self.test_user, token_id=5) self.store.get_user_by_access_token = simple_async_mock(user_info) @@ -98,7 +100,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") - def test_get_user_by_req_appservice_valid_token(self): + def test_get_user_by_req_appservice_valid_token(self) -> None: app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) @@ -112,7 +114,7 @@ class AuthTestCase(unittest.HomeserverTestCase): requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEqual(requester.user.to_string(), self.test_user) - def test_get_user_by_req_appservice_valid_token_good_ip(self): + def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None: from netaddr import IPSet app_service = Mock( @@ -131,7 +133,7 @@ class AuthTestCase(unittest.HomeserverTestCase): requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEqual(requester.user.to_string(), self.test_user) - def test_get_user_by_req_appservice_valid_token_bad_ip(self): + def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: from netaddr import IPSet app_service = Mock( @@ -153,7 +155,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") - def test_get_user_by_req_appservice_bad_token(self): + def test_get_user_by_req_appservice_bad_token(self) -> None: self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_user_by_access_token = simple_async_mock(None) @@ -166,7 +168,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") - def test_get_user_by_req_appservice_missing_token(self): + def test_get_user_by_req_appservice_missing_token(self) -> None: app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = simple_async_mock(None) @@ -179,7 +181,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") - def test_get_user_by_req_appservice_valid_token_valid_user_id(self): + def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None: masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None @@ -200,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase): requester.user.to_string(), masquerading_user_id.decode("utf8") ) - def test_get_user_by_req_appservice_valid_token_bad_user_id(self): + def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None: masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None @@ -217,7 +219,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.get_failure(self.auth.get_user_by_req(request), AuthError) @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) - def test_get_user_by_req_appservice_valid_token_valid_device_id(self): + def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None: """ Tests that when an application service passes the device_id URL parameter with the ID of a valid device for the user in question, @@ -249,7 +251,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8")) @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) - def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): + def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: """ Tests that when an application service passes the device_id URL parameter with an ID that is not a valid device ID for the user in question, @@ -279,7 +281,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(failure.value.code, 400) self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) - def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): + def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: self.store.get_user_by_access_token = simple_async_mock( TokenLookupResult( user_id="@baldrick:matrix.org", @@ -298,7 +300,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.get_success(self.auth.get_user_by_req(request)) self.store.insert_client_ip.assert_called_once() - def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self): + def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.auth._track_puppeted_user_ips = True self.store.get_user_by_access_token = simple_async_mock( TokenLookupResult( @@ -318,7 +320,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.get_success(self.auth.get_user_by_req(request)) self.assertEqual(self.store.insert_client_ip.call_count, 2) - def test_get_user_from_macaroon(self): + def test_get_user_from_macaroon(self) -> None: self.store.get_user_by_access_token = simple_async_mock(None) user_id = "@baldrick:matrix.org" @@ -336,7 +338,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth.get_user_by_access_token(serialized), InvalidClientTokenError ) - def test_get_guest_user_from_macaroon(self): + def test_get_guest_user_from_macaroon(self) -> None: self.store.get_user_by_id = simple_async_mock({"is_guest": True}) self.store.get_user_by_access_token = simple_async_mock(None) @@ -357,7 +359,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) - def test_blocking_mau(self): + def test_blocking_mau(self) -> None: self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._max_mau_value = 50 lots_of_users = 100 @@ -381,7 +383,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) self.get_success(self.auth_blocking.check_auth_blocking()) - def test_blocking_mau__depending_on_user_type(self): + def test_blocking_mau__depending_on_user_type(self) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True @@ -400,7 +402,9 @@ class AuthTestCase(unittest.HomeserverTestCase): # Real users not allowed self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) - def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): + def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips( + self, + ) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = False @@ -418,7 +422,7 @@ class AuthTestCase(unittest.HomeserverTestCase): sender="@appservice:sender", ) requester = Requester( - user="@appservice:server", + user=UserID.from_string("@appservice:server"), access_token_id=None, device_id="FOOBAR", is_guest=False, @@ -428,7 +432,9 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.get_success(self.auth_blocking.check_auth_blocking(requester=requester)) - def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): + def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips( + self, + ) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = True @@ -446,7 +452,7 @@ class AuthTestCase(unittest.HomeserverTestCase): sender="@appservice:sender", ) requester = Requester( - user="@appservice:server", + user=UserID.from_string("@appservice:server"), access_token_id=None, device_id="FOOBAR", is_guest=False, @@ -459,7 +465,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - def test_reserved_threepid(self): + def test_reserved_threepid(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 self.store.get_monthly_active_count = simple_async_mock(2) @@ -476,7 +482,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid)) - def test_hs_disabled(self): + def test_hs_disabled(self) -> None: self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure( @@ -486,7 +492,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) - def test_hs_disabled_no_server_notices_user(self): + def test_hs_disabled_no_server_notices_user(self) -> None: """Check that 'hs_disabled_message' works correctly when there is no server_notices user. """ @@ -503,7 +509,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) - def test_server_notices_mxid_special_cased(self): + def test_server_notices_mxid_special_cased(self) -> None: self.auth_blocking._hs_disabled = True user = "@user:server" self.auth_blocking._server_notices_mxid = user diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index d5524d296e..0f45615160 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -14,40 +14,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import List from unittest.mock import patch import jsonschema from frozendict import frozendict +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.events import make_event_from_dict +from synapse.api.presence import UserPresenceState +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest +from tests.events.test_utils import MockEvent user_localpart = "test_user" -def MockEvent(**kwargs): - if "event_id" not in kwargs: - kwargs["event_id"] = "fake_event_id" - if "type" not in kwargs: - kwargs["type"] = "fake_type" - if "content" not in kwargs: - kwargs["content"] = {} - return make_event_from_dict(kwargs) - - class FilteringTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filtering = hs.get_filtering() self.datastore = hs.get_datastores().main - def test_errors_on_invalid_filters(self): + def test_errors_on_invalid_filters(self) -> None: # See USER_FILTER_SCHEMA for the filter schema. - invalid_filters = [ + invalid_filters: List[JsonDict] = [ # `account_data` must be a dictionary {"account_data": "Hello World"}, # `event_fields` entries must not contain backslashes @@ -63,10 +59,10 @@ class FilteringTestCase(unittest.HomeserverTestCase): with self.assertRaises(SynapseError): self.filtering.check_valid_filter(filter) - def test_ignores_unknown_filter_fields(self): + def test_ignores_unknown_filter_fields(self) -> None: # For forward compatibility, we must ignore unknown filter fields. # See USER_FILTER_SCHEMA for the filter schema. - filters = [ + filters: List[JsonDict] = [ {"org.matrix.msc9999.future_option": True}, {"presence": {"org.matrix.msc9999.future_option": True}}, {"room": {"org.matrix.msc9999.future_option": True}}, @@ -76,8 +72,8 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.filtering.check_valid_filter(filter) # Must not raise. - def test_valid_filters(self): - valid_filters = [ + def test_valid_filters(self) -> None: + valid_filters: List[JsonDict] = [ { "room": { "timeline": {"limit": 20}, @@ -132,22 +128,22 @@ class FilteringTestCase(unittest.HomeserverTestCase): except jsonschema.ValidationError as e: self.fail(e) - def test_limits_are_applied(self): + def test_limits_are_applied(self) -> None: # TODO pass - def test_definition_types_works_with_literals(self): + def test_definition_types_works_with_literals(self) -> None: definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_types_works_with_wildcards(self): + def test_definition_types_works_with_wildcards(self) -> None: definition = {"types": ["m.*", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_types_works_with_unknowns(self): + def test_definition_types_works_with_unknowns(self) -> None: definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent( sender="@foo:bar", @@ -156,24 +152,24 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_types_works_with_literals(self): + def test_definition_not_types_works_with_literals(self) -> None: definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_types_works_with_wildcards(self): + def test_definition_not_types_works_with_wildcards(self) -> None: definition = {"not_types": ["m.room.message", "org.matrix.*"]} event = MockEvent( sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_types_works_with_unknowns(self): + def test_definition_not_types_works_with_unknowns(self) -> None: definition = {"not_types": ["m.*", "org.*"]} event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar") self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_not_types_takes_priority_over_types(self): + def test_definition_not_types_takes_priority_over_types(self) -> None: definition = { "not_types": ["m.*", "org.*"], "types": ["m.room.message", "m.room.topic"], @@ -181,35 +177,35 @@ class FilteringTestCase(unittest.HomeserverTestCase): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_senders_works_with_literals(self): + def test_definition_senders_works_with_literals(self) -> None: definition = {"senders": ["@flibble:wibble"]} event = MockEvent( sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_senders_works_with_unknowns(self): + def test_definition_senders_works_with_unknowns(self) -> None: definition = {"senders": ["@flibble:wibble"]} event = MockEvent( sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_senders_works_with_literals(self): + def test_definition_not_senders_works_with_literals(self) -> None: definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_senders_works_with_unknowns(self): + def test_definition_not_senders_works_with_unknowns(self) -> None: definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_not_senders_takes_priority_over_senders(self): + def test_definition_not_senders_takes_priority_over_senders(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets", "@misspiggy:muppets"], @@ -219,14 +215,14 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_rooms_works_with_literals(self): + def test_definition_rooms_works_with_literals(self) -> None: definition = {"rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_rooms_works_with_unknowns(self): + def test_definition_rooms_works_with_unknowns(self) -> None: definition = {"rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", @@ -235,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_rooms_works_with_literals(self): + def test_definition_not_rooms_works_with_literals(self) -> None: definition = {"not_rooms": ["!anothersecretbase:unknown"]} event = MockEvent( sender="@foo:bar", @@ -244,7 +240,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_not_rooms_works_with_unknowns(self): + def test_definition_not_rooms_works_with_unknowns(self) -> None: definition = {"not_rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", @@ -253,7 +249,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_not_rooms_takes_priority_over_rooms(self): + def test_definition_not_rooms_takes_priority_over_rooms(self) -> None: definition = { "not_rooms": ["!secretbase:unknown"], "rooms": ["!secretbase:unknown"], @@ -263,7 +259,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_combined_event(self): + def test_definition_combined_event(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets"], @@ -279,7 +275,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_definition_combined_event_bad_sender(self): + def test_definition_combined_event_bad_sender(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets"], @@ -295,7 +291,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_combined_event_bad_room(self): + def test_definition_combined_event_bad_room(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets"], @@ -311,7 +307,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_definition_combined_event_bad_type(self): + def test_definition_combined_event_bad_type(self) -> None: definition = { "not_senders": ["@misspiggy:muppets"], "senders": ["@kermit:muppets"], @@ -327,7 +323,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertFalse(Filter(self.hs, definition)._check(event)) - def test_filter_labels(self): + def test_filter_labels(self) -> None: definition = {"org.matrix.labels": ["#fun"]} event = MockEvent( sender="@foo:bar", @@ -356,7 +352,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_filter_not_labels(self): + def test_filter_not_labels(self) -> None: definition = {"org.matrix.not_labels": ["#fun"]} event = MockEvent( sender="@foo:bar", @@ -377,7 +373,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) - def test_filter_rel_type(self): + def test_filter_rel_type(self) -> None: definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} event = MockEvent( sender="@foo:bar", @@ -407,7 +403,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) - def test_filter_not_rel_type(self): + def test_filter_not_rel_type(self) -> None: definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} event = MockEvent( sender="@foo:bar", @@ -436,15 +432,25 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) - def test_filter_presence_match(self): - user_filter_json = {"presence": {"types": ["m.*"]}} + def test_filter_presence_match(self) -> None: + """Check that filter_presence return events which matches the filter.""" + user_filter_json = {"presence": {"senders": ["@foo:bar"]}} filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) ) - event = MockEvent(sender="@foo:bar", type="m.profile") - events = [event] + presence_states = [ + UserPresenceState( + user_id="@foo:bar", + state="unavailable", + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ), + ] user_filter = self.get_success( self.filtering.get_user_filter( @@ -452,23 +458,29 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEqual(events, results) + results = self.get_success(user_filter.filter_presence(presence_states)) + self.assertEqual(presence_states, results) - def test_filter_presence_no_match(self): - user_filter_json = {"presence": {"types": ["m.*"]}} + def test_filter_presence_no_match(self) -> None: + """Check that filter_presence does not return events rejected by the filter.""" + user_filter_json = {"presence": {"not_senders": ["@foo:bar"]}} filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart + "2", user_filter=user_filter_json ) ) - event = MockEvent( - event_id="$asdasd:localhost", - sender="@foo:bar", - type="custom.avatar.3d.crazy", - ) - events = [event] + presence_states = [ + UserPresenceState( + user_id="@foo:bar", + state="unavailable", + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ), + ] user_filter = self.get_success( self.filtering.get_user_filter( @@ -476,10 +488,10 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = self.get_success(user_filter.filter_presence(events=events)) + results = self.get_success(user_filter.filter_presence(presence_states)) self.assertEqual([], results) - def test_filter_room_state_match(self): + def test_filter_room_state_match(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( self.datastore.add_user_filter( @@ -498,7 +510,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): results = self.get_success(user_filter.filter_room_state(events=events)) self.assertEqual(events, results) - def test_filter_room_state_no_match(self): + def test_filter_room_state_no_match(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( self.datastore.add_user_filter( @@ -519,7 +531,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): results = self.get_success(user_filter.filter_room_state(events)) self.assertEqual([], results) - def test_filter_rooms(self): + def test_filter_rooms(self) -> None: definition = { "rooms": ["!allowed:example.com", "!excluded:example.com"], "not_rooms": ["!excluded:example.com"], @@ -535,7 +547,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) - def test_filter_relations(self): + def test_filter_relations(self) -> None: events = [ # An event without a relation. MockEvent( @@ -551,9 +563,8 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="org.matrix.custom.event", room_id="!foo:bar", ), - # Non-EventBase objects get passed through. - {}, ] + jsondicts: List[JsonDict] = [{}] # For the following tests we patch the datastore method (intead of injecting # events). This is a bit cheeky, but tests the logic of _check_event_relations. @@ -561,7 +572,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): # Filter for a particular sender. definition = {"related_by_senders": ["@foo:bar"]} - async def events_have_relations(*args, **kwargs): + async def events_have_relations(*args: object, **kwargs: object) -> List[str]: return ["$with_relation"] with patch.object( @@ -572,9 +583,17 @@ class FilteringTestCase(unittest.HomeserverTestCase): Filter(self.hs, definition)._check_event_relations(events) ) ) + # Non-EventBase objects get passed through. + filtered_jsondicts = list( + self.get_success( + Filter(self.hs, definition)._check_event_relations(jsondicts) + ) + ) + self.assertEqual(filtered_events, events[1:]) + self.assertEqual(filtered_jsondicts, [{}]) - def test_add_filter(self): + def test_add_filter(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( @@ -595,7 +614,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ), ) - def test_get_filter(self): + def test_get_filter(self) -> None: user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index b5fd08d437..fa6c1c02ce 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -6,7 +6,7 @@ from tests import unittest class TestRatelimiter(unittest.HomeserverTestCase): - def test_allowed_via_can_do_action(self): + def test_allowed_via_can_do_action(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, @@ -31,7 +31,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) self.assertEqual(20.0, time_allowed) - def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> None: appservice = ApplicationService( token="fake_token", id="foo", @@ -64,7 +64,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) self.assertEqual(20.0, time_allowed) - def test_allowed_appservice_via_can_requester_do_action(self): + def test_allowed_appservice_via_can_requester_do_action(self) -> None: appservice = ApplicationService( token="fake_token", id="foo", @@ -97,7 +97,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) self.assertEqual(-1, time_allowed) - def test_allowed_via_ratelimit(self): + def test_allowed_via_ratelimit(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, @@ -120,7 +120,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.ratelimit(None, key="test_id", _time_now_s=10) ) - def test_allowed_via_can_do_action_and_overriding_parameters(self): + def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None: """Test that we can override options of can_do_action that would otherwise fail an action """ @@ -169,7 +169,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) - def test_allowed_via_ratelimit_and_overriding_parameters(self): + def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None: """Test that we can override options of the ratelimit method that would otherwise fail an action """ @@ -204,7 +204,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) ) - def test_pruning(self): + def test_pruning(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, @@ -223,7 +223,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertNotIn("test_id_1", limiter.actions) - def test_db_user_override(self): + def test_db_user_override(self) -> None: """Test that users that have ratelimiting disabled in the DB aren't ratelimited. """ @@ -250,7 +250,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): for _ in range(20): self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) - def test_multiple_actions(self): + def test_multiple_actions(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index ff7b349d75..4174a237ec 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -35,6 +35,8 @@ def MockEvent(**kwargs: Any) -> EventBase: kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" + if "content" not in kwargs: + kwargs["content"] = {} return make_event_from_dict(kwargs) -- cgit 1.5.1 From b3bf58a8a5f56674cb0ea0ab6c29aba5775dec52 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 6 Feb 2023 11:29:51 +0000 Subject: Only notify the target of a membership event (#14971) * Only notify the target of a membership event Naughty, but should be a big speedup in large rooms --- changelog.d/14971.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 38 +++++++++++++++++++++++++------- 2 files changed, 31 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14971.misc (limited to 'synapse') diff --git a/changelog.d/14971.misc b/changelog.d/14971.misc new file mode 100644 index 0000000000..130045a123 --- /dev/null +++ b/changelog.d/14971.misc @@ -0,0 +1 @@ +Improve performance of joining and leaving large rooms with many local users. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 20369f3dfe..f73dceb128 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -142,15 +142,34 @@ class BulkPushRuleEvaluator: Returns: Mapping of user ID to their push rules. """ - # We get the users who may need to be notified by first fetching the - # local users currently in the room, finding those that have push rules, - # and *then* checking which users are actually allowed to see the event. - # - # The alternative is to first fetch all users that were joined at the - # event, but that requires fetching the full state at the event, which - # may be expensive for large rooms with few local users. + # If this is a membership event, only calculate push rules for the target. + # While it's possible for users to configure push rules to respond to such an + # event, in practise nobody does this. At the cost of violating the spec a + # little, we can skip fetching a huge number of push rules in large rooms. + # This helps make joins and leaves faster. + if event.type == EventTypes.Member: + local_users = [] + # We never notify a user about their own actions. This is enforced in + # `_action_for_event_by_user` in the loop over `rules_by_user`, but we + # do the same check here to avoid unnecessary DB queries. + if event.sender != event.state_key and self.hs.is_mine_id(event.state_key): + # Check the target is in the room, to avoid notifying them of + # e.g. a pre-emptive ban. + target_already_in_room = await self.store.check_local_user_in_room( + event.state_key, event.room_id + ) + if target_already_in_room: + local_users = [event.state_key] + else: + # We get the users who may need to be notified by first fetching the + # local users currently in the room, finding those that have push rules, + # and *then* checking which users are actually allowed to see the event. + # + # The alternative is to first fetch all users that were joined at the + # event, but that requires fetching the full state at the event, which + # may be expensive for large rooms with few local users. - local_users = await self.store.get_local_users_in_room(event.room_id) + local_users = await self.store.get_local_users_in_room(event.room_id) # Filter out appservice users. local_users = [ @@ -167,6 +186,9 @@ class BulkPushRuleEvaluator: local_users = list(local_users) local_users.append(invited) + if not local_users: + return {} + rules_by_user = await self.store.bulk_get_push_rules(local_users) logger.debug("Users in room: %s", local_users) -- cgit 1.5.1 From e8269ed391a199bbe0e43efc28c68c98b949b323 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 6 Feb 2023 12:49:06 +0000 Subject: Type hints for tests.appservice (#14990) * Accept a Sequence of events in synapse.appservice This avoids some casts/ignores in the tests I'm about to fixup. It seems that `List[Mock]` is not a subtype of `List[EventBase]`, but `Sequence[Mock]` is a subtype of `Sequence[EventBase]`. So presumably `Mock` is considered a subtype of anything, much like `Any`. * make tests.appservice.test_scheduler pass mypy * Extra hints in tests.appservice.test_scheduler * Extra hints in tests.appservice.test_api * Extra hints in tests.appservice.test_appservice * Disallow untyped defs * Changelog --- changelog.d/14990.misc | 1 + mypy.ini | 4 +- synapse/appservice/__init__.py | 4 +- synapse/appservice/api.py | 14 ++++- synapse/appservice/scheduler.py | 3 +- synapse/storage/databases/main/appservice.py | 14 ++++- tests/appservice/test_api.py | 4 +- tests/appservice/test_appservice.py | 55 ++++++++++++----- tests/appservice/test_scheduler.py | 92 ++++++++++++++++++---------- 9 files changed, 132 insertions(+), 59 deletions(-) create mode 100644 changelog.d/14990.misc (limited to 'synapse') diff --git a/changelog.d/14990.misc b/changelog.d/14990.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/14990.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index a6e37bc377..351b8ccade 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,7 +32,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/schema/ - |tests/appservice/test_scheduler.py |tests/federation/test_federation_catch_up.py |tests/federation/test_federation_sender.py |tests/http/federation/test_matrix_federation_agent.py @@ -78,6 +77,9 @@ disallow_untyped_defs = True [mypy-tests.app.*] disallow_untyped_defs = True +[mypy-tests.appservice.*] +disallow_untyped_defs = True + [mypy-tests.config.*] disallow_untyped_defs = True diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 65615f50b8..35c330a3c4 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -16,7 +16,7 @@ import logging import re from enum import Enum -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern, Sequence import attr from netaddr import IPSet @@ -377,7 +377,7 @@ class AppServiceTransaction: self, service: ApplicationService, id: int, - events: List[EventBase], + events: Sequence[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], one_time_keys_count: TransactionOneTimeKeysCount, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index edafd433cd..1a6f69e7d3 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,7 +14,17 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, +) from prometheus_client import Counter from typing_extensions import TypeGuard @@ -259,7 +269,7 @@ class ApplicationServiceApi(SimpleHttpClient): async def push_bulk( self, service: "ApplicationService", - events: List[EventBase], + events: Sequence[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], one_time_keys_count: TransactionOneTimeKeysCount, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 7b562795a3..3a319b0d42 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -57,6 +57,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, ) @@ -364,7 +365,7 @@ class _TransactionController: async def send( self, service: ApplicationService, - events: List[EventBase], + events: Sequence[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index c2c8018ee2..5fb152c4ff 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -14,7 +14,17 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Pattern, + Sequence, + Tuple, + cast, +) from synapse.appservice import ( ApplicationService, @@ -257,7 +267,7 @@ class ApplicationServiceTransactionWorkerStore( async def create_appservice_txn( self, service: ApplicationService, - events: List[EventBase], + events: Sequence[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], one_time_keys_count: TransactionOneTimeKeysCount, diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 89ee79396f..9d183b733e 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -29,7 +29,7 @@ URL = "http://mytestservice" class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.api = hs.get_application_service_api() self.service = ApplicationService( id="unique_identifier", @@ -39,7 +39,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): hs_token=TOKEN, ) - def test_query_3pe_authenticates_token(self): + def test_query_3pe_authenticates_token(self) -> None: """ Tests that 3pe queries to the appservice are authenticated with the appservice's token. diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index d4dccfc2f0..dee976356f 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re +from typing import Generator from unittest.mock import Mock from twisted.internet import defer @@ -27,7 +28,7 @@ def _regex(regex: str, exclusive: bool = True) -> Namespace: class ApplicationServiceTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.service = ApplicationService( id="unique_identifier", sender="@as:test", @@ -46,7 +47,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_local_users_in_room = simple_async_mock([]) @defer.inlineCallbacks - def test_regex_user_id_prefix_match(self): + def test_regex_user_id_prefix_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue( @@ -60,7 +63,9 @@ class ApplicationServiceTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_regex_user_id_prefix_no_match(self): + def test_regex_user_id_prefix_no_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse( @@ -74,7 +79,9 @@ class ApplicationServiceTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_regex_room_member_is_checked(self): + def test_regex_room_member_is_checked( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" @@ -90,7 +97,9 @@ class ApplicationServiceTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_regex_room_id_match(self): + def test_regex_room_id_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) @@ -106,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_regex_room_id_no_match(self): + def test_regex_room_id_no_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) @@ -122,7 +133,9 @@ class ApplicationServiceTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_regex_alias_match(self): + def test_regex_alias_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -140,44 +153,46 @@ class ApplicationServiceTestCase(unittest.TestCase): ) ) - def test_non_exclusive_alias(self): + def test_non_exclusive_alias(self) -> None: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) - def test_non_exclusive_room(self): + def test_non_exclusive_room(self) -> None: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org")) - def test_non_exclusive_user(self): + def test_non_exclusive_user(self) -> None: self.service.namespaces[ApplicationService.NS_USERS].append( _regex("@irc_.*:matrix.org", exclusive=False) ) self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org")) - def test_exclusive_alias(self): + def test_exclusive_alias(self) -> None: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) - def test_exclusive_user(self): + def test_exclusive_user(self) -> None: self.service.namespaces[ApplicationService.NS_USERS].append( _regex("@irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org")) - def test_exclusive_room(self): + def test_exclusive_room(self) -> None: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!irc_.*:matrix.org", exclusive=True) ) self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org")) @defer.inlineCallbacks - def test_regex_alias_no_match(self): + def test_regex_alias_no_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -196,7 +211,9 @@ class ApplicationServiceTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_regex_multiple_matches(self): + def test_regex_multiple_matches( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) @@ -215,7 +232,9 @@ class ApplicationServiceTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_interested_in_self(self): + def test_interested_in_self( + self, + ) -> Generator["defer.Deferred[object]", object, None]: # make sure invites get through self.service.sender = "@appservice:name" self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) @@ -233,7 +252,9 @@ class ApplicationServiceTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_member_list_match(self): + def test_member_list_match( + self, + ) -> Generator["defer.Deferred[object]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. self.store.get_local_users_in_room = simple_async_mock( diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 0a1ae83a2b..febcc1499d 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -11,20 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast from unittest.mock import Mock +from typing_extensions import TypeAlias + from twisted.internet import defer -from synapse.appservice import ApplicationServiceState +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeysCount, + TransactionUnusedFallbackKeys, +) from synapse.appservice.scheduler import ( ApplicationServiceScheduler, _Recoverer, _TransactionController, ) +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable from synapse.server import HomeServer -from synapse.types import DeviceListUpdates +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import Clock from tests import unittest @@ -37,18 +45,18 @@ if TYPE_CHECKING: class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() self.store = Mock() self.as_api = Mock() self.recoverer = Mock() self.recoverer_fn = Mock(return_value=self.recoverer) self.txnctrl = _TransactionController( - clock=self.clock, store=self.store, as_api=self.as_api + clock=cast(Clock, self.clock), store=self.store, as_api=self.as_api ) self.txnctrl.RECOVERER_CLASS = self.recoverer_fn - def test_single_service_up_txn_sent(self): + def test_single_service_up_txn_sent(self) -> None: # Test: The AS is up and the txn is successfully sent. service = Mock() events = [Mock(), Mock()] @@ -76,7 +84,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed - def test_single_service_down(self): + def test_single_service_down(self) -> None: # Test: The AS is down so it shouldn't push; Recoverers will do it. # It should still make a transaction though. service = Mock() @@ -103,7 +111,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.assertEqual(0, txn.send.call_count) # txn not sent though self.assertEqual(0, txn.complete.call_count) # or completed - def test_single_service_up_txn_not_sent(self): + def test_single_service_up_txn_not_sent(self) -> None: # Test: The AS is up and the txn is not sent. A Recoverer is made and # started. service = Mock() @@ -139,26 +147,28 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() self.as_api = Mock() self.store = Mock() self.service = Mock() self.callback = simple_async_mock() self.recoverer = _Recoverer( - clock=self.clock, + clock=cast(Clock, self.clock), as_api=self.as_api, store=self.store, service=self.service, callback=self.callback, ) - def test_recover_single_txn(self): + def test_recover_single_txn(self) -> None: txn = Mock() # return one txn to send, then no more old txns txns = [txn, None] - def take_txn(*args, **kwargs): + def take_txn( + *args: object, **kwargs: object + ) -> "defer.Deferred[Optional[Mock]]": return defer.succeed(txns.pop(0)) self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) @@ -177,12 +187,14 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.callback.assert_called_once_with(self.recoverer) self.assertEqual(self.recoverer.service, self.service) - def test_recover_retry_txn(self): + def test_recover_retry_txn(self) -> None: txn = Mock() txns = [txn, None] pop_txn = False - def take_txn(*args, **kwargs): + def take_txn( + *args: object, **kwargs: object + ) -> "defer.Deferred[Optional[Mock]]": if pop_txn: return defer.succeed(txns.pop(0)) else: @@ -214,8 +226,24 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.callback.assert_called_once_with(self.recoverer) +# Corresponds to synapse.appservice.scheduler._TransactionController.send +TxnCtrlArgs: TypeAlias = """ +defer.Deferred[ + Tuple[ + ApplicationService, + Sequence[EventBase], + Optional[List[JsonDict]], + Optional[List[JsonDict]], + Optional[TransactionOneTimeKeysCount], + Optional[TransactionUnusedFallbackKeys], + Optional[DeviceListUpdates], + ] +] +""" + + class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer): + def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() self.txn_ctrl.send = simple_async_mock() @@ -224,7 +252,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.txn_ctrl = self.txn_ctrl self.scheduler.queuer.txn_ctrl = self.txn_ctrl - def test_send_single_event_no_queue(self): + def test_send_single_event_no_queue(self) -> None: # Expect the event to be sent immediately. service = Mock(id=4) event = Mock() @@ -233,8 +261,8 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service, [event], [], [], None, None, DeviceListUpdates() ) - def test_send_single_event_with_queue(self): - d = defer.Deferred() + def test_send_single_event_with_queue(self) -> None: + d: TxnCtrlArgs = defer.Deferred() self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event = Mock(event_id="first") @@ -257,22 +285,22 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, self.txn_ctrl.send.call_count) - def test_multiple_service_queues(self): + def test_multiple_service_queues(self) -> None: # Tests that each service has its own queue, and that they don't block # on each other. srv1 = Mock(id=4) - srv_1_defer = defer.Deferred() + srv_1_defer: "defer.Deferred[EventBase]" = defer.Deferred() srv_1_event = Mock(event_id="srv1a") srv_1_event2 = Mock(event_id="srv1b") srv2 = Mock(id=6) - srv_2_defer = defer.Deferred() + srv_2_defer: "defer.Deferred[EventBase]" = defer.Deferred() srv_2_event = Mock(event_id="srv2a") srv_2_event2 = Mock(event_id="srv2b") send_return_list = [srv_1_defer, srv_2_defer] - def do_send(*args, **kwargs): + def do_send(*args: object, **kwargs: object) -> "defer.Deferred[EventBase]": return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) @@ -297,12 +325,12 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): ) self.assertEqual(3, self.txn_ctrl.send.call_count) - def test_send_large_txns(self): - srv_1_defer = defer.Deferred() - srv_2_defer = defer.Deferred() + def test_send_large_txns(self) -> None: + srv_1_defer: "defer.Deferred[EventBase]" = defer.Deferred() + srv_2_defer: "defer.Deferred[EventBase]" = defer.Deferred() send_return_list = [srv_1_defer, srv_2_defer] - def do_send(*args, **kwargs): + def do_send(*args: object, **kwargs: object) -> "defer.Deferred[EventBase]": return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) @@ -328,7 +356,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): ) self.assertEqual(3, self.txn_ctrl.send.call_count) - def test_send_single_ephemeral_no_queue(self): + def test_send_single_ephemeral_no_queue(self) -> None: # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event")] @@ -337,7 +365,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service, [], event_list, [], None, None, DeviceListUpdates() ) - def test_send_multiple_ephemeral_no_queue(self): + def test_send_multiple_ephemeral_no_queue(self) -> None: # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] @@ -346,8 +374,8 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service, [], event_list, [], None, None, DeviceListUpdates() ) - def test_send_single_ephemeral_with_queue(self): - d = defer.Deferred() + def test_send_single_ephemeral_with_queue(self) -> None: + d: TxnCtrlArgs = defer.Deferred() self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] @@ -377,8 +405,8 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, self.txn_ctrl.send.call_count) - def test_send_large_txns_ephemeral(self): - d = defer.Deferred() + def test_send_large_txns_ephemeral(self) -> None: + d: TxnCtrlArgs = defer.Deferred() self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) # Expect the event to be sent immediately. service = Mock(id=4, name="service") -- cgit 1.5.1 From 5fdc12f482c68e2cdbb78d7db5de2cfe621720d4 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 7 Feb 2023 01:10:54 +0100 Subject: Add `event_stream_ordering` column to membership state tables (#14979) This adds an `event_stream_ordering` column to `current_state_events`, `local_current_membership` and `room_memberships`. Each of these tables is regularly joined with the `events` table to get the stream ordering and denormalising this into each table will yield significant query performance improvements once used. Includes a background job to populate these values from the `events` table. Same idea as https://github.com/matrix-org/synapse/pull/13703. Signed off by Nick @ Beeper (@fizzadar). --- changelog.d/14979.misc | 1 + synapse/storage/databases/main/events.py | 23 +++-- .../storage/databases/main/events_bg_updates.py | 104 ++++++++++++++++++++- synapse/storage/databases/main/events_worker.py | 8 +- .../26membership_tables_event_stream_ordering.sql | 21 +++++ 5 files changed, 146 insertions(+), 11 deletions(-) create mode 100644 changelog.d/14979.misc create mode 100644 synapse/storage/schema/main/delta/73/26membership_tables_event_stream_ordering.sql (limited to 'synapse') diff --git a/changelog.d/14979.misc b/changelog.d/14979.misc new file mode 100644 index 0000000000..c09911e48d --- /dev/null +++ b/changelog.d/14979.misc @@ -0,0 +1 @@ +Add denormalised event stream ordering column to membership state tables for future use. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1536937b67..b6cce0a7cc 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1147,11 +1147,15 @@ class PersistEventsStore: # been inserted into room_memberships. txn.execute_batch( """INSERT INTO current_state_events - (room_id, type, state_key, event_id, membership) - VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + (room_id, type, state_key, event_id, membership, event_stream_ordering) + VALUES ( + ?, ?, ?, ?, + (SELECT membership FROM room_memberships WHERE event_id = ?), + (SELECT stream_ordering FROM events WHERE event_id = ?) + ) """, [ - (room_id, key[0], key[1], ev_id, ev_id) + (room_id, key[0], key[1], ev_id, ev_id, ev_id) for key, ev_id in to_insert.items() ], ) @@ -1178,11 +1182,15 @@ class PersistEventsStore: if to_insert: txn.execute_batch( """INSERT INTO local_current_membership - (room_id, user_id, event_id, membership) - VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + (room_id, user_id, event_id, membership, event_stream_ordering) + VALUES ( + ?, ?, ?, + (SELECT membership FROM room_memberships WHERE event_id = ?), + (SELECT stream_ordering FROM events WHERE event_id = ?) + ) """, [ - (room_id, key[1], ev_id, ev_id) + (room_id, key[1], ev_id, ev_id, ev_id) for key, ev_id in to_insert.items() if key[0] == EventTypes.Member and self.is_mine_id(key[1]) ], @@ -1790,6 +1798,7 @@ class PersistEventsStore: table="room_memberships", keys=( "event_id", + "event_stream_ordering", "user_id", "sender", "room_id", @@ -1800,6 +1809,7 @@ class PersistEventsStore: values=[ ( event.event_id, + event.internal_metadata.stream_ordering, event.state_key, event.user_id, event.room_id, @@ -1832,6 +1842,7 @@ class PersistEventsStore: keyvalues={"room_id": event.room_id, "user_id": event.state_key}, values={ "event_id": event.event_id, + "event_stream_ordering": event.internal_metadata.stream_ordering, "membership": event.membership, }, ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index b9d3c36d60..0e81d38cca 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, ca import attr -from synapse.api.constants import EventContentFields, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -71,6 +71,10 @@ class _BackgroundUpdates: EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index" + POPULATE_MEMBERSHIP_EVENT_STREAM_ORDERING = ( + "populate_membership_event_stream_ordering" + ) + @attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: @@ -99,6 +103,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.POPULATE_MEMBERSHIP_EVENT_STREAM_ORDERING, + self._populate_membership_event_stream_ordering, + ) self.db_pool.updates.register_background_update_handler( _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts, @@ -1498,3 +1506,97 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) return batch_size + + async def _populate_membership_event_stream_ordering( + self, progress: JsonDict, batch_size: int + ) -> int: + def _populate_membership_event_stream_ordering( + txn: LoggingTransaction, + ) -> bool: + + if "max_stream_ordering" in progress: + max_stream_ordering = progress["max_stream_ordering"] + else: + txn.execute("SELECT max(stream_ordering) FROM events") + res = txn.fetchone() + if res is None or res[0] is None: + return True + else: + max_stream_ordering = res[0] + + start = progress.get("stream_ordering", 0) + stop = start + batch_size + + sql = f""" + SELECT room_id, event_id, stream_ordering + FROM events + WHERE + type = '{EventTypes.Member}' + AND stream_ordering >= ? + AND stream_ordering < ? + """ + txn.execute(sql, (start, stop)) + + rows: List[Tuple[str, str, int]] = cast( + List[Tuple[str, str, int]], txn.fetchall() + ) + + event_ids: List[Tuple[str]] = [] + event_stream_orderings: List[Tuple[int]] = [] + + for _, event_id, event_stream_ordering in rows: + event_ids.append((event_id,)) + event_stream_orderings.append((event_stream_ordering,)) + + self.db_pool.simple_update_many_txn( + txn, + table="current_state_events", + key_names=("event_id",), + key_values=event_ids, + value_names=("event_stream_ordering",), + value_values=event_stream_orderings, + ) + + self.db_pool.simple_update_many_txn( + txn, + table="room_memberships", + key_names=("event_id",), + key_values=event_ids, + value_names=("event_stream_ordering",), + value_values=event_stream_orderings, + ) + + # NOTE: local_current_membership has no index on event_id, so only + # the room ID here will reduce the query rows read. + for room_id, event_id, event_stream_ordering in rows: + txn.execute( + """ + UPDATE local_current_membership + SET event_stream_ordering = ? + WHERE room_id = ? AND event_id = ? + """, + (event_stream_ordering, room_id, event_id), + ) + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.POPULATE_MEMBERSHIP_EVENT_STREAM_ORDERING, + { + "stream_ordering": stop, + "max_stream_ordering": max_stream_ordering, + }, + ) + + return stop > max_stream_ordering + + finished = await self.db_pool.runInteraction( + "_populate_membership_event_stream_ordering", + _populate_membership_event_stream_ordering, + ) + + if finished: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.POPULATE_MEMBERSHIP_EVENT_STREAM_ORDERING + ) + + return batch_size diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d7d08369ca..6d0ef10258 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1779,7 +1779,7 @@ class EventsWorkerStore(SQLBaseStore): txn: LoggingTransaction, ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + "SELECT out.event_stream_ordering, e.event_id, e.room_id, e.type," " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," " e.outlier" " FROM events AS e" @@ -1791,10 +1791,10 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" + " WHERE ? < out.event_stream_ordering" + " AND out.event_stream_ordering <= ?" " AND out.instance_name = ?" - " ORDER BY event_stream_ordering ASC" + " ORDER BY out.event_stream_ordering ASC" ) txn.execute(sql, (last_id, current_id, instance_name)) diff --git a/synapse/storage/schema/main/delta/73/26membership_tables_event_stream_ordering.sql b/synapse/storage/schema/main/delta/73/26membership_tables_event_stream_ordering.sql new file mode 100644 index 0000000000..7c30a67fc4 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/26membership_tables_event_stream_ordering.sql @@ -0,0 +1,21 @@ +/* Copyright 2022 Beeper + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT; +ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT; +ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_membership_event_stream_ordering', '{}'); -- cgit 1.5.1 From d0fed7a37b8b6ce166cae856fe243757aa7c7294 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Feb 2023 00:20:04 +0000 Subject: Properly typecheck types.http (#14988) * Tweak http types in Synapse AFACIS these are correct, and they make mypy happier on tests.http. * Type hints for test_proxyagent * type hints for test_srv_resolver * test_matrix_federation_agent * tests.http.server._base * tests.http.__init__ * tests.http.test_additional_resource * tests.http.test_client * tests.http.test_endpoint * tests.http.test_matrixfederationclient * tests.http.test_servlet * tests.http.test_simple_client * tests.http.test_site * One fixup in tests.server * Untyped defs * Changelog * Fixup syntax for Python 3.7 * Fix olddeps syntax * Use a twisted IPv4 addr for dummy_address * Fix typo, thanks Sean Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Remove redundant `Optional` --------- Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14988.misc | 1 + mypy.ini | 6 +- synapse/http/client.py | 5 +- synapse/http/proxyagent.py | 3 +- tests/http/__init__.py | 19 ++- .../federation/test_matrix_federation_agent.py | 142 +++++++++++++-------- tests/http/federation/test_srv_resolver.py | 60 +++++---- tests/http/server/_base.py | 2 +- tests/http/test_additional_resource.py | 18 ++- tests/http/test_client.py | 37 ++++-- tests/http/test_endpoint.py | 4 +- tests/http/test_matrixfederationclient.py | 53 ++++---- tests/http/test_proxyagent.py | 103 +++++++++------ tests/http/test_servlet.py | 8 +- tests/http/test_simple_client.py | 14 +- tests/http/test_site.py | 8 +- tests/server.py | 6 +- 17 files changed, 298 insertions(+), 191 deletions(-) create mode 100644 changelog.d/14988.misc (limited to 'synapse') diff --git a/changelog.d/14988.misc b/changelog.d/14988.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/14988.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index 93de1c97ea..11e683b704 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,9 +32,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/schema/ - |tests/http/federation/test_matrix_federation_agent.py - |tests/http/federation/test_srv_resolver.py - |tests/http/test_proxyagent.py |tests/module_api/test_api.py |tests/rest/media/v1/test_media_storage.py |tests/server.py @@ -92,6 +89,9 @@ disallow_untyped_defs = True [mypy-tests.handlers.*] disallow_untyped_defs = True +[mypy-tests.http.*] +disallow_untyped_defs = True + [mypy-tests.logging.*] disallow_untyped_defs = True diff --git a/synapse/http/client.py b/synapse/http/client.py index 4eb740c040..a05f297933 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -44,6 +44,7 @@ from twisted.internet.interfaces import ( IAddress, IDelayedCall, IHostResolution, + IReactorCore, IReactorPluggableNameResolver, IReactorTime, IResolutionReceiver, @@ -226,7 +227,9 @@ class _IPBlacklistingResolver: return recv -@implementer(ISynapseReactor) +# ISynapseReactor implies IReactorCore, but explicitly marking it this as an implementer +# of IReactorCore seems to keep mypy-zope happier. +@implementer(IReactorCore, ISynapseReactor) class BlacklistingReactorWrapper: """ A Reactor wrapper which will prevent DNS resolution to blacklisted IP diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 18899bc6d1..94ef737b9e 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -38,7 +38,6 @@ from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials -from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -84,7 +83,7 @@ class ProxyAgent(_AgentBase): def __init__( self, reactor: IReactorCore, - proxy_reactor: Optional[ISynapseReactor] = None, + proxy_reactor: Optional[IReactorCore] = None, contextFactory: Optional[IPolicyForHTTPS] = None, connectTimeout: Optional[float] = None, bindAddress: Optional[bytes] = None, diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 093537adef..528cdee34b 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -19,13 +19,15 @@ from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import Connection +from twisted.internet.address import IPv4Address from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.ssl import Certificate, trustRootFromCertificates +from twisted.protocols.tls import TLSMemoryBIOProtocol from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 -def get_test_https_policy(): +def get_test_https_policy() -> BrowserLikePolicyForHTTPS: """Get a test IPolicyForHTTPS which trusts the test CA cert Returns: @@ -39,7 +41,7 @@ def get_test_https_policy(): return BrowserLikePolicyForHTTPS(trustRoot=trust_root) -def get_test_ca_cert_file(): +def get_test_ca_cert_file() -> str: """Get the path to the test CA cert The keypair is generated with: @@ -51,7 +53,7 @@ def get_test_ca_cert_file(): return os.path.join(os.path.dirname(__file__), "ca.crt") -def get_test_key_file(): +def get_test_key_file() -> str: """get the path to the test key The key file is made with: @@ -137,15 +139,20 @@ class TestServerTLSConnectionFactory: """An SSL connection creator which returns connections which present a certificate signed by our test CA.""" - def __init__(self, sanlist): + def __init__(self, sanlist: List[bytes]): """ Args: - sanlist: list[bytes]: a list of subjectAltName values for the cert + sanlist: a list of subjectAltName values for the cert """ self._cert_file = create_test_cert_file(sanlist) - def serverConnectionForTLS(self, tlsProtocol): + def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection: ctx = SSL.Context(SSL.SSLv23_METHOD) ctx.use_certificate_file(self._cert_file) ctx.use_privatekey_file(get_test_key_file()) return Connection(ctx, None) + + +# A dummy address, useful for tests that use FakeTransport and don't care about where +# packets are going to/coming from. +dummy_address = IPv4Address("TCP", "127.0.0.1", 80) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 992d8f94fd..acfdcd3bca 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -14,7 +14,7 @@ import base64 import logging import os -from typing import Iterable, Optional +from typing import Any, Awaitable, Callable, Generator, List, Optional, cast from unittest.mock import Mock, patch import treq @@ -24,14 +24,19 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions -from twisted.internet.interfaces import IProtocolFactory +from twisted.internet.defer import Deferred +from twisted.internet.endpoints import _WrappingProtocol +from twisted.internet.interfaces import ( + IOpenSSLClientConnectionCreator, + IProtocolFactory, +) from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent from twisted.web.http import HTTPChannel, Request from twisted.web.http_headers import Headers -from twisted.web.iweb import IPolicyForHTTPS +from twisted.web.iweb import IPolicyForHTTPS, IResponse from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import FederationPolicyForHTTPS @@ -42,11 +47,21 @@ from synapse.http.federation.well_known_resolver import ( WellKnownResolver, _cache_period_from_headers, ) -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + LoggingContextOrSentinel, + current_context, +) +from synapse.types import ISynapseReactor from synapse.util.caches.ttlcache import TTLCache from tests import unittest -from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.http import ( + TestServerTLSConnectionFactory, + dummy_address, + get_test_ca_cert_file, +) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.utils import default_config @@ -54,15 +69,17 @@ logger = logging.getLogger(__name__) # Once Async Mocks or lambdas are supported this can go away. -def generate_resolve_service(result): - async def resolve_service(_): +def generate_resolve_service( + result: List[Server], +) -> Callable[[Any], Awaitable[List[Server]]]: + async def resolve_service(_: Any) -> List[Server]: return result return resolve_service class MatrixFederationAgentTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() self.mock_resolver = Mock() @@ -75,8 +92,12 @@ class MatrixFederationAgentTests(unittest.TestCase): self.tls_factory = FederationPolicyForHTTPS(config) - self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) - self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) + self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache( + "test_cache", timer=self.reactor.seconds + ) + self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache( + "test_cache", timer=self.reactor.seconds + ) self.well_known_resolver = WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=self.tls_factory), @@ -89,8 +110,8 @@ class MatrixFederationAgentTests(unittest.TestCase): self, client_factory: IProtocolFactory, ssl: bool = True, - expected_sni: bytes = None, - tls_sanlist: Optional[Iterable[bytes]] = None, + expected_sni: Optional[bytes] = None, + tls_sanlist: Optional[List[bytes]] = None, ) -> HTTPChannel: """Builds a test server, and completes the outgoing client connection Args: @@ -116,8 +137,8 @@ class MatrixFederationAgentTests(unittest.TestCase): if ssl: server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) - server_protocol = server_factory.buildProtocol(None) - + server_protocol = server_factory.buildProtocol(dummy_address) + assert server_protocol is not None # now, tell the client protocol factory to build the client protocol (it will be a # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an # HTTP11ClientProtocol) and wire the output of said protocol up to the server via @@ -125,7 +146,8 @@ class MatrixFederationAgentTests(unittest.TestCase): # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(None) + client_protocol = client_factory.buildProtocol(dummy_address) + assert isinstance(client_protocol, _WrappingProtocol) client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -136,6 +158,7 @@ class MatrixFederationAgentTests(unittest.TestCase): ) if ssl: + assert isinstance(server_protocol, TLSMemoryBIOProtocol) # fish the test server back out of the server-side TLS protocol. http_protocol = server_protocol.wrappedProtocol # grab a hold of the TLS connection, in case it gets torn down @@ -144,6 +167,7 @@ class MatrixFederationAgentTests(unittest.TestCase): http_protocol = server_protocol tls_connection = None + assert isinstance(http_protocol, HTTPChannel) # give the reactor a pump to get the TLS juices flowing (if needed) self.reactor.advance(0) @@ -159,12 +183,14 @@ class MatrixFederationAgentTests(unittest.TestCase): return http_protocol @defer.inlineCallbacks - def _make_get_request(self, uri: bytes): + def _make_get_request( + self, uri: bytes + ) -> Generator["Deferred[object]", object, IResponse]: """ Sends a simple GET request via the agent, and checks its logcontext management """ with LoggingContext("one") as context: - fetch_d = self.agent.request(b"GET", uri) + fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri) # Nothing happened yet self.assertNoResult(fetch_d) @@ -172,8 +198,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # should have reset logcontext to the sentinel _check_logcontext(SENTINEL_CONTEXT) + fetch_res: IResponse try: - fetch_res = yield fetch_d + fetch_res = yield fetch_d # type: ignore[misc, assignment] return fetch_res except Exception as e: logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e) @@ -216,7 +243,7 @@ class MatrixFederationAgentTests(unittest.TestCase): request: Request, content: bytes, headers: Optional[dict] = None, - ): + ) -> None: """Check that an incoming request looks like a valid .well-known request, and send back the response. """ @@ -237,16 +264,16 @@ class MatrixFederationAgentTests(unittest.TestCase): because it is created too early during setUp """ return MatrixFederationAgent( - reactor=self.reactor, + reactor=cast(ISynapseReactor, self.reactor), tls_client_options_factory=self.tls_factory, - user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. + user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided. ip_whitelist=IPSet(), ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) - def test_get(self): + def test_get(self) -> None: """happy-path test of a GET request with an explicit port""" self._do_get() @@ -254,11 +281,11 @@ class MatrixFederationAgentTests(unittest.TestCase): os.environ, {"https_proxy": "proxy.com", "no_proxy": "testserv"}, ) - def test_get_bypass_proxy(self): + def test_get_bypass_proxy(self) -> None: """test of a GET request with an explicit port and bypass proxy""" self._do_get() - def _do_get(self): + def _do_get(self) -> None: """test of a GET request with an explicit port""" self.agent = self._make_agent() @@ -318,7 +345,7 @@ class MatrixFederationAgentTests(unittest.TestCase): @patch.dict( os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"} ) - def test_get_via_http_proxy(self): + def test_get_via_http_proxy(self) -> None: """test for federation request through a http proxy""" self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None) @@ -326,7 +353,7 @@ class MatrixFederationAgentTests(unittest.TestCase): os.environ, {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"}, ) - def test_get_via_http_proxy_with_auth(self): + def test_get_via_http_proxy_with_auth(self) -> None: """test for federation request through a http proxy with authentication""" self._do_get_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=b"user:pass" @@ -335,7 +362,7 @@ class MatrixFederationAgentTests(unittest.TestCase): @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) - def test_get_via_https_proxy(self): + def test_get_via_https_proxy(self) -> None: """test for federation request through a https proxy""" self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None) @@ -343,7 +370,7 @@ class MatrixFederationAgentTests(unittest.TestCase): os.environ, {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"}, ) - def test_get_via_https_proxy_with_auth(self): + def test_get_via_https_proxy_with_auth(self) -> None: """test for federation request through a https proxy with authentication""" self._do_get_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"user:pass" @@ -353,7 +380,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a https federation request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. Args: @@ -418,10 +445,12 @@ class MatrixFederationAgentTests(unittest.TestCase): # now we make another test server to act as the upstream HTTP server. server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() - ).buildProtocol(None) + ).buildProtocol(dummy_address) + assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport + assert proxy_server_transport is not None server_ssl_protocol.makeConnection(proxy_server_transport) # ... and replace the protocol on the proxy's transport with the @@ -451,6 +480,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # now there should be a pending request http_server = server_ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -491,7 +521,7 @@ class MatrixFederationAgentTests(unittest.TestCase): json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) - def test_get_ip_address(self): + def test_get_ip_address(self) -> None: """ Test the behaviour when the server name contains an explicit IP (with no port) """ @@ -526,7 +556,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_ipv6_address(self): + def test_get_ipv6_address(self) -> None: """ Test the behaviour when the server name contains an explicit IPv6 address (with no port) @@ -562,7 +592,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_ipv6_address_with_port(self): + def test_get_ipv6_address_with_port(self) -> None: """ Test the behaviour when the server name contains an explicit IPv6 address (with explicit port) @@ -598,7 +628,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_hostname_bad_cert(self): + def test_get_hostname_bad_cert(self) -> None: """ Test the behaviour when the certificate on the server doesn't match the hostname """ @@ -651,7 +681,7 @@ class MatrixFederationAgentTests(unittest.TestCase): failure_reason = e.value.reasons[0] self.assertIsInstance(failure_reason.value, VerificationError) - def test_get_ip_address_bad_cert(self): + def test_get_ip_address_bad_cert(self) -> None: """ Test the behaviour when the server name contains an explicit IP, but the server cert doesn't cover it @@ -684,7 +714,7 @@ class MatrixFederationAgentTests(unittest.TestCase): failure_reason = e.value.reasons[0] self.assertIsInstance(failure_reason.value, VerificationError) - def test_get_no_srv_no_well_known(self): + def test_get_no_srv_no_well_known(self) -> None: """ Test the behaviour when the server name has no port, no SRV, and no well-known """ @@ -740,7 +770,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known(self): + def test_get_well_known(self) -> None: """Test the behaviour when the .well-known delegates elsewhere""" self.agent = self._make_agent() @@ -802,7 +832,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) - def test_get_well_known_redirect(self): + def test_get_well_known_redirect(self) -> None: """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ @@ -892,7 +922,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) - def test_get_invalid_well_known(self): + def test_get_invalid_well_known(self) -> None: """ Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ @@ -945,7 +975,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known_unsigned_cert(self): + def test_get_well_known_unsigned_cert(self) -> None: """Test the behaviour when the .well-known server presents a cert not signed by a CA """ @@ -969,7 +999,7 @@ class MatrixFederationAgentTests(unittest.TestCase): ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( - self.reactor, + cast(ISynapseReactor, self.reactor), Agent(self.reactor, contextFactory=tls_factory), b"test-agent", well_known_cache=self.well_known_cache, @@ -999,7 +1029,7 @@ class MatrixFederationAgentTests(unittest.TestCase): b"_matrix._tcp.testserv" ) - def test_get_hostname_srv(self): + def test_get_hostname_srv(self) -> None: """ Test the behaviour when there is a single SRV record """ @@ -1041,7 +1071,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known_srv(self): + def test_get_well_known_srv(self) -> None: """Test the behaviour when the .well-known redirects to a place where there is a SRV. """ @@ -1101,7 +1131,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_idna_servername(self): + def test_idna_servername(self) -> None: """test the behaviour when the server name has idna chars in""" self.agent = self._make_agent() @@ -1163,7 +1193,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_idna_srv_target(self): + def test_idna_srv_target(self) -> None: """test the behaviour when the target of a SRV record has idna chars""" self.agent = self._make_agent() @@ -1206,7 +1236,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_well_known_cache(self): + def test_well_known_cache(self) -> None: self.reactor.lookups["testserv"] = "1.2.3.4" fetch_d = defer.ensureDeferred( @@ -1262,7 +1292,7 @@ class MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"other-server") - def test_well_known_cache_with_temp_failure(self): + def test_well_known_cache_with_temp_failure(self) -> None: """Test that we refetch well-known before the cache expires, and that it ignores transient errors. """ @@ -1341,7 +1371,7 @@ class MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) - def test_well_known_too_large(self): + def test_well_known_too_large(self) -> None: """A well-known query that returns a result which is too large should be rejected.""" self.reactor.lookups["testserv"] = "1.2.3.4" @@ -1367,7 +1397,7 @@ class MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertIsNone(r.delegated_server) - def test_srv_fallbacks(self): + def test_srv_fallbacks(self) -> None: """Test that other SRV results are tried if the first one fails.""" self.agent = self._make_agent() @@ -1427,7 +1457,7 @@ class MatrixFederationAgentTests(unittest.TestCase): class TestCachePeriodFromHeaders(unittest.TestCase): - def test_cache_control(self): + def test_cache_control(self) -> None: # uppercase self.assertEqual( _cache_period_from_headers( @@ -1464,7 +1494,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase): 0, ) - def test_expires(self): + def test_expires(self) -> None: self.assertEqual( _cache_period_from_headers( Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}), @@ -1491,14 +1521,14 @@ class TestCachePeriodFromHeaders(unittest.TestCase): self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0) -def _check_logcontext(context): +def _check_logcontext(context: LoggingContextOrSentinel) -> None: current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) def _wrap_server_factory_for_tls( - factory: IProtocolFactory, sanlist: Iterable[bytes] = None + factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None ) -> IProtocolFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate @@ -1537,7 +1567,7 @@ def _get_test_protocol_factory() -> IProtocolFactory: return server_factory -def _log_request(request: str): +def _log_request(request: str) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info(f"Completed request {request}") @@ -1547,6 +1577,8 @@ class TrustingTLSPolicyForHTTPS: """An IPolicyForHTTPS which checks that the certificate belongs to the right server, but doesn't check the certificate chain.""" - def creatorForNetloc(self, hostname, port): + def creatorForNetloc( + self, hostname: bytes, port: int + ) -> IOpenSSLClientConnectionCreator: certificateOptions = OpenSSLCertificateOptions() return ClientTLSOptions(hostname, certificateOptions.getContext()) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 77ce8432ac..7748f56ee6 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Dict, Generator, List, Tuple, cast from unittest.mock import Mock from twisted.internet import defer @@ -20,7 +20,7 @@ from twisted.internet.defer import Deferred from twisted.internet.error import ConnectError from twisted.names import dns, error -from synapse.http.federation.srv_resolver import SrvResolver +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.logging.context import LoggingContext, current_context from tests import unittest @@ -28,7 +28,7 @@ from tests.utils import MockClock class SrvResolverTestCase(unittest.TestCase): - def test_resolve(self): + def test_resolve(self) -> None: dns_client_mock = Mock() service_name = b"test_service.example.com" @@ -38,18 +38,19 @@ class SrvResolverTestCase(unittest.TestCase): type=dns.SRV, payload=dns.Record_SRV(target=host_name) ) - result_deferred = Deferred() + result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock.lookupService.return_value = result_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @defer.inlineCallbacks - def do_lookup(): + def do_lookup() -> Generator["Deferred[object]", object, List[Server]]: with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) - result = yield defer.ensureDeferred(resolve_d) + result: List[Server] + result = yield defer.ensureDeferred(resolve_d) # type: ignore[assignment] # should have restored our context self.assertIs(current_context(), ctx) @@ -70,7 +71,9 @@ class SrvResolverTestCase(unittest.TestCase): self.assertEqual(servers[0].host, host_name) @defer.inlineCallbacks - def test_from_cache_expired_and_dns_fail(self): + def test_from_cache_expired_and_dns_fail( + self, + ) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) @@ -81,10 +84,13 @@ class SrvResolverTestCase(unittest.TestCase): entry.priority = 0 entry.weight = 0 - cache = {service_name: [entry]} + cache = {service_name: [cast(Server, entry)]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -92,7 +98,7 @@ class SrvResolverTestCase(unittest.TestCase): self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks - def test_from_cache(self): + def test_from_cache(self) -> Generator["Deferred[object]", object, None]: clock = MockClock() dns_client_mock = Mock(spec_set=["lookupService"]) @@ -105,12 +111,15 @@ class SrvResolverTestCase(unittest.TestCase): entry.priority = 0 entry.weight = 0 - cache = {service_name: [entry]} + cache = {service_name: [cast(Server, entry)]} resolver = SrvResolver( dns_client=dns_client_mock, cache=cache, get_time=clock.time ) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] self.assertFalse(dns_client_mock.lookupService.called) @@ -118,45 +127,48 @@ class SrvResolverTestCase(unittest.TestCase): self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks - def test_empty_cache(self): + def test_empty_cache(self) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) service_name = b"test_service.example.com" - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): yield defer.ensureDeferred(resolver.resolve_service(service_name)) @defer.inlineCallbacks - def test_name_error(self): + def test_name_error(self) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) service_name = b"test_service.example.com" - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] self.assertEqual(len(servers), 0) self.assertEqual(len(cache), 0) - def test_disabled_service(self): + def test_disabled_service(self) -> None: """ test the behaviour when there is a single record which is ".". """ service_name = b"test_service.example.com" - lookup_deferred = Deferred() + lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock = Mock() dns_client_mock.lookupService.return_value = lookup_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) # Old versions of Twisted don't have an ensureDeferred in failureResultOf. @@ -173,16 +185,16 @@ class SrvResolverTestCase(unittest.TestCase): self.failureResultOf(resolve_d, ConnectError) - def test_non_srv_answer(self): + def test_non_srv_answer(self) -> None: """ test the behaviour when the dns server gives us a spurious non-SRV response """ service_name = b"test_service.example.com" - lookup_deferred = Deferred() + lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock = Mock() dns_client_mock.lookupService.return_value = lookup_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) # Old versions of Twisted don't have an ensureDeferred in successResultOf. diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 5071f83574..36472e57a8 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -556,6 +556,6 @@ def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str: return method_name -def _hash_stack(stack: List[inspect.FrameInfo]): +def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]: """Turns a stack into a hashable value that can be put into a set.""" return tuple(_format_stack_frame(frame) for frame in stack) diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index 391196425c..ec6aacf235 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -11,28 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any +from twisted.web.server import Request from synapse.http.additional_resource import AdditionalResource from synapse.http.server import respond_with_json +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from tests.server import FakeSite, make_request from tests.unittest import HomeserverTestCase class _AsyncTestCustomEndpoint: - def __init__(self, config, module_api): + def __init__(self, config: JsonDict, module_api: Any) -> None: pass - async def handle_request(self, request): + async def handle_request(self, request: Request) -> None: + assert isinstance(request, SynapseRequest) respond_with_json(request, 200, {"some_key": "some_value_async"}) class _SyncTestCustomEndpoint: - def __init__(self, config, module_api): + def __init__(self, config: JsonDict, module_api: Any) -> None: pass - async def handle_request(self, request): + async def handle_request(self, request: Request) -> None: + assert isinstance(request, SynapseRequest) respond_with_json(request, 200, {"some_key": "some_value_sync"}) @@ -41,7 +47,7 @@ class AdditionalResourceTests(HomeserverTestCase): and async handlers. """ - def test_async(self): + def test_async(self) -> None: handler = _AsyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) @@ -52,7 +58,7 @@ class AdditionalResourceTests(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) - def test_sync(self): + def test_sync(self) -> None: handler = _SyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 7e2f2a01cc..9cfe1ad0de 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -13,10 +13,12 @@ # limitations under the License. from io import BytesIO +from typing import Tuple, Union from unittest.mock import Mock from netaddr import IPSet +from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol @@ -28,6 +30,7 @@ from synapse.http.client import ( BlacklistingAgentWrapper, BlacklistingReactorWrapper, BodyExceededMaxSize, + _DiscardBodyWithMaxSizeProtocol, read_body_with_max_size, ) @@ -36,7 +39,9 @@ from tests.unittest import TestCase class ReadBodyWithMaxSizeTests(TestCase): - def _build_response(self, length=UNKNOWN_LENGTH): + def _build_response( + self, length: Union[int, str] = UNKNOWN_LENGTH + ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]: """Start reading the body, returns the response, result and proto""" response = Mock(length=length) result = BytesIO() @@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase): return result, deferred, protocol - def _assert_error(self, deferred, protocol): + def _assert_error( + self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol + ) -> None: """Ensure that the expected error is received.""" - self.assertIsInstance(deferred.result, Failure) + assert isinstance(deferred.result, Failure) self.assertIsInstance(deferred.result.value, BodyExceededMaxSize) - protocol.transport.abortConnection.assert_called_once() + assert protocol.transport is not None + # type-ignore: presumably abortConnection has been replaced with a Mock. + protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined] - def _cleanup_error(self, deferred): + def _cleanup_error(self, deferred: "Deferred[int]") -> None: """Ensure that the error in the Deferred is handled gracefully.""" called = [False] - def errback(f): + def errback(f: Failure) -> None: called[0] = True deferred.addErrback(errback) self.assertTrue(called[0]) - def test_no_error(self): + def test_no_error(self) -> None: """A response that is NOT too large.""" result, deferred, protocol = self._build_response() @@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase): self.assertEqual(result.getvalue(), b"12345") self.assertEqual(deferred.result, 5) - def test_too_large(self): + def test_too_large(self) -> None: """A response which is too large raises an exception.""" result, deferred, protocol = self._build_response() @@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase): self._assert_error(deferred, protocol) self._cleanup_error(deferred) - def test_multiple_packets(self): + def test_multiple_packets(self) -> None: """Data should be accumulated through mutliple packets.""" result, deferred, protocol = self._build_response() @@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase): self.assertEqual(result.getvalue(), b"1234") self.assertEqual(deferred.result, 4) - def test_additional_data(self): + def test_additional_data(self) -> None: """A connection can receive data after being closed.""" result, deferred, protocol = self._build_response() @@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase): self._assert_error(deferred, protocol) self._cleanup_error(deferred) - def test_content_length(self): + def test_content_length(self) -> None: """The body shouldn't be read (at all) if the Content-Length header is too large.""" result, deferred, protocol = self._build_response(length=10) @@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase): class BlacklistingAgentTest(TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, self.clock = get_clock() self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" @@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase): self.ip_whitelist = IPSet([self.allowed_ip.decode()]) self.ip_blacklist = IPSet(["5.0.0.0/8"]) - def test_reactor(self): + def test_reactor(self) -> None: """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" agent = Agent( BlacklistingReactorWrapper( @@ -197,7 +206,7 @@ class BlacklistingAgentTest(TestCase): response = self.successResultOf(d) self.assertEqual(response.code, 200) - def test_agent(self): + def test_agent(self) -> None: """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" agent = BlacklistingAgentWrapper( Agent(self.reactor), diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index a801f002a0..8c18e56881 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -17,7 +17,7 @@ from tests import unittest class ServerNameTestCase(unittest.TestCase): - def test_parse_server_name(self): + def test_parse_server_name(self) -> None: test_data = { "localhost": ("localhost", None), "my-example.com:1234": ("my-example.com", 1234), @@ -32,7 +32,7 @@ class ServerNameTestCase(unittest.TestCase): for i, o in test_data.items(): self.assertEqual(parse_server_name(i), o) - def test_validate_bad_server_names(self): + def test_validate_bad_server_names(self) -> None: test_data = [ "", # empty "localhost:http", # non-numeric port diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index be9eaf34e8..fdd22a8e94 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Generator from unittest.mock import Mock from netaddr import IPSet from parameterized import parameterized from twisted.internet import defer -from twisted.internet.defer import TimeoutError +from twisted.internet.defer import Deferred, TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError -from twisted.test.proto_helpers import StringTransport +from twisted.test.proto_helpers import MemoryReactor, StringTransport from twisted.web.client import ResponseNeverReceived from twisted.web.http import HTTPChannel @@ -30,34 +30,43 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + LoggingContextOrSentinel, + current_context, +) +from synapse.server import HomeServer +from synapse.util import Clock from tests.server import FakeTransport from tests.unittest import HomeserverTestCase -def check_logcontext(context): +def check_logcontext(context: LoggingContextOrSentinel) -> None: current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) class FederationClientTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(reactor=reactor, clock=clock) return hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.cl = MatrixFederationHttpClient(self.hs, None) self.reactor.lookups["testserv"] = "1.2.3.4" - def test_client_get(self): + def test_client_get(self) -> None: """ happy-path test of a GET request """ @defer.inlineCallbacks - def do_request(): + def do_request() -> Generator["Deferred[object]", object, object]: with LoggingContext("one") as context: fetch_d = defer.ensureDeferred( self.cl.get_json("testserv:8008", "foo/bar") @@ -119,7 +128,7 @@ class FederationClientTests(HomeserverTestCase): # check the response is as expected self.assertEqual(res, {"a": 1}) - def test_dns_error(self): + def test_dns_error(self) -> None: """ If the DNS lookup returns an error, it will bubble up. """ @@ -132,7 +141,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, DNSLookupError) - def test_client_connection_refused(self): + def test_client_connection_refused(self) -> None: d = defer.ensureDeferred( self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) ) @@ -156,7 +165,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIs(f.value.inner_exception, e) - def test_client_never_connect(self): + def test_client_never_connect(self) -> None: """ If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. @@ -188,7 +197,7 @@ class FederationClientTests(HomeserverTestCase): f.value.inner_exception, (ConnectingCancelledError, TimeoutError) ) - def test_client_connect_no_response(self): + def test_client_connect_no_response(self) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. @@ -222,7 +231,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived) - def test_client_ip_range_blacklist(self): + def test_client_ip_range_blacklist(self) -> None: """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist @@ -292,7 +301,7 @@ class FederationClientTests(HomeserverTestCase): f = self.failureResultOf(d, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError) - def test_client_gets_headers(self): + def test_client_gets_headers(self) -> None: """ Once the client gets the headers, _request returns successfully. """ @@ -319,7 +328,7 @@ class FederationClientTests(HomeserverTestCase): self.assertEqual(r.code, 200) @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"]) - def test_timeout_reading_body(self, method_name: str): + def test_timeout_reading_body(self, method_name: str) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a RequestSendFailed with can_retry. @@ -351,7 +360,7 @@ class FederationClientTests(HomeserverTestCase): self.assertTrue(f.value.can_retry) self.assertIsInstance(f.value.inner_exception, defer.TimeoutError) - def test_client_requires_trailing_slashes(self): + def test_client_requires_trailing_slashes(self) -> None: """ If a connection is made to a client but the client rejects it due to requiring a trailing slash. We need to retry the request with a @@ -405,7 +414,7 @@ class FederationClientTests(HomeserverTestCase): r = self.successResultOf(d) self.assertEqual(r, {}) - def test_client_does_not_retry_on_400_plus(self): + def test_client_does_not_retry_on_400_plus(self) -> None: """ Another test for trailing slashes but now test that we don't retry on trailing slashes on a non-400/M_UNRECOGNIZED response. @@ -450,7 +459,7 @@ class FederationClientTests(HomeserverTestCase): # We should get a 404 failure response self.failureResultOf(d) - def test_client_sends_body(self): + def test_client_sends_body(self) -> None: defer.ensureDeferred( self.cl.post_json( "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} @@ -474,7 +483,7 @@ class FederationClientTests(HomeserverTestCase): content = request.content.read() self.assertEqual(content, b'{"a":"b"}') - def test_closes_connection(self): + def test_closes_connection(self) -> None: """Check that the client closes unused HTTP connections""" d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) @@ -514,7 +523,7 @@ class FederationClientTests(HomeserverTestCase): self.assertTrue(conn.disconnecting) @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)]) - def test_json_error(self, return_value): + def test_json_error(self, return_value: bytes) -> None: """ Test what happens if invalid JSON is returned from the remote endpoint. """ @@ -560,7 +569,7 @@ class FederationClientTests(HomeserverTestCase): f = self.failureResultOf(test_d) self.assertIsInstance(f.value, RequestSendFailed) - def test_too_big(self): + def test_too_big(self) -> None: """ Test what happens if a huge response is returned from the remote endpoint. """ diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 2db77c6a73..a817940730 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -14,7 +14,7 @@ import base64 import logging import os -from typing import Iterable, Optional +from typing import List, Optional from unittest.mock import patch import treq @@ -22,7 +22,11 @@ from netaddr import IPSet from parameterized import parameterized from twisted.internet import interfaces # noqa: F401 -from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint +from twisted.internet.endpoints import ( + HostnameEndpoint, + _WrapperEndpoint, + _WrappingProtocol, +) from twisted.internet.interfaces import IProtocol, IProtocolFactory from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol @@ -32,7 +36,11 @@ from synapse.http.client import BlacklistingReactorWrapper from synapse.http.connectproxyclient import ProxyCredentials from synapse.http.proxyagent import ProxyAgent, parse_proxy -from tests.http import TestServerTLSConnectionFactory, get_test_https_policy +from tests.http import ( + TestServerTLSConnectionFactory, + dummy_address, + get_test_https_policy, +) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase @@ -183,7 +191,7 @@ class ProxyParserTests(TestCase): expected_hostname: bytes, expected_port: int, expected_credentials: Optional[bytes], - ): + ) -> None: """ Tests that a given proxy URL will be broken into the components. Args: @@ -209,7 +217,7 @@ class ProxyParserTests(TestCase): class MatrixFederationAgentTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() def _make_connection( @@ -218,7 +226,7 @@ class MatrixFederationAgentTests(TestCase): server_factory: IProtocolFactory, ssl: bool = False, expected_sni: Optional[bytes] = None, - tls_sanlist: Optional[Iterable[bytes]] = None, + tls_sanlist: Optional[List[bytes]] = None, ) -> IProtocol: """Builds a test server, and completes the outgoing client connection @@ -244,7 +252,8 @@ class MatrixFederationAgentTests(TestCase): if ssl: server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) - server_protocol = server_factory.buildProtocol(None) + server_protocol = server_factory.buildProtocol(dummy_address) + assert server_protocol is not None # now, tell the client protocol factory to build the client protocol, # and wire the output of said protocol up to the server via @@ -252,7 +261,8 @@ class MatrixFederationAgentTests(TestCase): # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(None) + client_protocol = client_factory.buildProtocol(dummy_address) + assert client_protocol is not None client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -263,6 +273,7 @@ class MatrixFederationAgentTests(TestCase): ) if ssl: + assert isinstance(server_protocol, TLSMemoryBIOProtocol) http_protocol = server_protocol.wrappedProtocol tls_connection = server_protocol._tlsConnection else: @@ -288,7 +299,7 @@ class MatrixFederationAgentTests(TestCase): scheme: bytes, hostname: bytes, path: bytes, - ): + ) -> None: """Runs a test case for a direct connection not going through a proxy. Args: @@ -319,6 +330,7 @@ class MatrixFederationAgentTests(TestCase): ssl=is_https, expected_sni=hostname if is_https else None, ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -339,34 +351,34 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") - def test_http_request(self): + def test_http_request(self) -> None: agent = ProxyAgent(self.reactor) self._test_request_direct_connection(agent, b"http", b"test.com", b"") - def test_https_request(self): + def test_https_request(self) -> None: agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") - def test_http_request_use_proxy_empty_environment(self): + def test_http_request_use_proxy_empty_environment(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"}) - def test_http_request_via_uppercase_no_proxy(self): + def test_http_request_via_uppercase_no_proxy(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict( os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"} ) - def test_http_request_via_no_proxy(self): + def test_http_request_via_no_proxy(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict( os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"} ) - def test_https_request_via_no_proxy(self): + def test_https_request_via_no_proxy(self) -> None: agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -375,12 +387,12 @@ class MatrixFederationAgentTests(TestCase): self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"}) - def test_http_request_via_no_proxy_star(self): + def test_http_request_via_no_proxy_star(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"}) - def test_https_request_via_no_proxy_star(self): + def test_https_request_via_no_proxy_star(self) -> None: agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -389,7 +401,7 @@ class MatrixFederationAgentTests(TestCase): self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"}) - def test_http_request_via_proxy(self): + def test_http_request_via_proxy(self) -> None: """ Tests that requests can be made through a proxy. """ @@ -401,7 +413,7 @@ class MatrixFederationAgentTests(TestCase): os.environ, {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, ) - def test_http_request_via_proxy_with_auth(self): + def test_http_request_via_proxy_with_auth(self) -> None: """ Tests that authenticated requests can be made through a proxy. """ @@ -412,7 +424,7 @@ class MatrixFederationAgentTests(TestCase): @patch.dict( os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} ) - def test_http_request_via_https_proxy(self): + def test_http_request_via_https_proxy(self) -> None: self._do_http_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=None ) @@ -424,13 +436,13 @@ class MatrixFederationAgentTests(TestCase): "no_proxy": "unused.com", }, ) - def test_http_request_via_https_proxy_with_auth(self): + def test_http_request_via_https_proxy_with_auth(self) -> None: self._do_http_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" ) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) - def test_https_request_via_proxy(self): + def test_https_request_via_proxy(self) -> None: """Tests that TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=None @@ -440,7 +452,7 @@ class MatrixFederationAgentTests(TestCase): os.environ, {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, ) - def test_https_request_via_proxy_with_auth(self): + def test_https_request_via_proxy_with_auth(self) -> None: """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" @@ -449,7 +461,7 @@ class MatrixFederationAgentTests(TestCase): @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) - def test_https_request_via_https_proxy(self): + def test_https_request_via_https_proxy(self) -> None: """Tests that TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=None @@ -459,7 +471,7 @@ class MatrixFederationAgentTests(TestCase): os.environ, {"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, ) - def test_https_request_via_https_proxy_with_auth(self): + def test_https_request_via_https_proxy_with_auth(self) -> None: """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" @@ -469,7 +481,7 @@ class MatrixFederationAgentTests(TestCase): self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a http request via an agent and check that it is correctly received at the proxy. The proxy can use either http or https. Args: @@ -501,6 +513,7 @@ class MatrixFederationAgentTests(TestCase): tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -542,7 +555,7 @@ class MatrixFederationAgentTests(TestCase): self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a https request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. Args: @@ -606,10 +619,12 @@ class MatrixFederationAgentTests(TestCase): # now we make another test server to act as the upstream HTTP server. server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() - ).buildProtocol(None) + ).buildProtocol(dummy_address) + assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport + assert proxy_server_transport is not None server_ssl_protocol.makeConnection(proxy_server_transport) # ... and replace the protocol on the proxy's transport with the @@ -644,6 +659,7 @@ class MatrixFederationAgentTests(TestCase): # now there should be a pending request http_server = server_ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -667,7 +683,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(body, b"result") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) - def test_http_request_via_proxy_with_blacklist(self): + def test_http_request_via_proxy_with_blacklist(self) -> None: # The blacklist includes the configured proxy IP. agent = ProxyAgent( BlacklistingReactorWrapper( @@ -691,6 +707,7 @@ class MatrixFederationAgentTests(TestCase): http_server = self._make_connection( client_factory, _get_test_protocol_factory() ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -712,7 +729,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(body, b"result") @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"}) - def test_https_request_via_uppercase_proxy_with_blacklist(self): + def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None: # The blacklist includes the configured proxy IP. agent = ProxyAgent( BlacklistingReactorWrapper( @@ -737,11 +754,15 @@ class MatrixFederationAgentTests(TestCase): proxy_server = self._make_connection( client_factory, _get_test_protocol_factory() ) + assert isinstance(proxy_server, HTTPChannel) # fish the transports back out so that we can do the old switcheroo s2c_transport = proxy_server.transport + assert isinstance(s2c_transport, FakeTransport) client_protocol = s2c_transport.other + assert isinstance(client_protocol, _WrappingProtocol) c2s_transport = client_protocol.transport + assert isinstance(c2s_transport, FakeTransport) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -762,8 +783,10 @@ class MatrixFederationAgentTests(TestCase): # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) - ssl_protocol = ssl_factory.buildProtocol(None) + ssl_protocol = ssl_factory.buildProtocol(dummy_address) + assert isinstance(ssl_protocol, TLSMemoryBIOProtocol) http_server = ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) ssl_protocol.makeConnection( FakeTransport(client_protocol, self.reactor, ssl_protocol) @@ -797,28 +820,28 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(body, b"result") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) - def test_proxy_with_no_scheme(self): + def test_proxy_with_no_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) + assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) - def test_proxy_with_unsupported_scheme(self): + def test_proxy_with_unsupported_scheme(self) -> None: with self.assertRaises(ValueError): ProxyAgent(self.reactor, use_proxy=True) @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) - def test_proxy_with_http_scheme(self): + def test_proxy_with_http_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) + assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) - def test_proxy_with_https_scheme(self): + def test_proxy_with_https_scheme(self) -> None: https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) + assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) self.assertEqual( https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com" ) @@ -828,7 +851,7 @@ class MatrixFederationAgentTests(TestCase): def _wrap_server_factory_for_tls( - factory: IProtocolFactory, sanlist: Iterable[bytes] = None + factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None ) -> IProtocolFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory @@ -865,6 +888,6 @@ def _get_test_protocol_factory() -> IProtocolFactory: return server_factory -def _log_request(request: str): +def _log_request(request: str) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info(f"Completed request {request}") diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index 46166292fe..c8d215b6dc 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -14,7 +14,7 @@ import json from http import HTTPStatus from io import BytesIO -from typing import Tuple +from typing import Tuple, Union from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError @@ -33,7 +33,7 @@ from tests import unittest from tests.http.server._base import test_disconnect -def make_request(content): +def make_request(content: Union[bytes, JsonDict]) -> Mock: """Make an object that acts enough like a request.""" request = Mock(spec=["method", "uri", "content"]) @@ -47,7 +47,7 @@ def make_request(content): class TestServletUtils(unittest.TestCase): - def test_parse_json_value(self): + def test_parse_json_value(self) -> None: """Basic tests for parse_json_value_from_request.""" # Test round-tripping. obj = {"foo": 1} @@ -78,7 +78,7 @@ class TestServletUtils(unittest.TestCase): with self.assertRaises(SynapseError): parse_json_value_from_request(make_request(b'{"foo": Infinity}')) - def test_parse_json_object(self): + def test_parse_json_object(self) -> None: """Basic tests for parse_json_object_from_request.""" # Test empty. result = parse_json_object_from_request( diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py index c85a3665c1..010601da4b 100644 --- a/tests/http/test_simple_client.py +++ b/tests/http/test_simple_client.py @@ -17,22 +17,24 @@ from netaddr import IPSet from twisted.internet import defer from twisted.internet.error import DNSLookupError +from twisted.test.proto_helpers import MemoryReactor from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class SimpleHttpClientTests(HomeserverTestCase): - def prepare(self, reactor, clock, hs: "HomeServer"): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: "HomeServer") -> None: # Add a DNS entry for a test server self.reactor.lookups["testserv"] = "1.2.3.4" self.cl = hs.get_simple_http_client() - def test_dns_error(self): + def test_dns_error(self) -> None: """ If the DNS lookup returns an error, it will bubble up. """ @@ -42,7 +44,7 @@ class SimpleHttpClientTests(HomeserverTestCase): f = self.failureResultOf(d) self.assertIsInstance(f.value, DNSLookupError) - def test_client_connection_refused(self): + def test_client_connection_refused(self) -> None: d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) self.pump() @@ -63,7 +65,7 @@ class SimpleHttpClientTests(HomeserverTestCase): self.assertIs(f.value, e) - def test_client_never_connect(self): + def test_client_never_connect(self) -> None: """ If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. @@ -90,7 +92,7 @@ class SimpleHttpClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestTimedOutError) - def test_client_connect_no_response(self): + def test_client_connect_no_response(self) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. @@ -121,7 +123,7 @@ class SimpleHttpClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestTimedOutError) - def test_client_ip_range_blacklist(self): + def test_client_ip_range_blacklist(self) -> None: """Ensure that Synapse does not try to connect to blacklisted IPs""" # Add some DNS entries we'll blacklist diff --git a/tests/http/test_site.py b/tests/http/test_site.py index b2dbf76d33..9a78fede92 100644 --- a/tests/http/test_site.py +++ b/tests/http/test_site.py @@ -13,18 +13,20 @@ # limitations under the License. from twisted.internet.address import IPv6Address -from twisted.test.proto_helpers import StringTransport +from twisted.test.proto_helpers import MemoryReactor, StringTransport from synapse.app.homeserver import SynapseHomeServer +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class SynapseRequestTestCase(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer) - def test_large_request(self): + def test_large_request(self) -> None: """overlarge HTTP requests should be rejected""" self.hs.start_listening() diff --git a/tests/server.py b/tests/server.py index b1730fcc8d..237bcad8ba 100644 --- a/tests/server.py +++ b/tests/server.py @@ -70,7 +70,7 @@ from synapse.logging.context import ContextResourceUsage from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine -from synapse.types import JsonDict +from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock from tests.utils import ( @@ -401,7 +401,9 @@ def make_request( return channel -@implementer(IReactorPluggableNameResolver) +# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly +# marking this as an implementer of the latter seems to keep mypy-zope happier. +@implementer(IReactorPluggableNameResolver, ISynapseReactor) class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. -- cgit 1.5.1 From 5b55c32d610b2baec8622f0418519b130ab4fa30 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Feb 2023 06:56:09 -0500 Subject: Add tests for using _flatten_dict with an event. (#15002) --- changelog.d/15002.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 13 +++---- tests/push/test_push_rule_evaluator.py | 63 +++++++++++++++++++++++++++++++- 3 files changed, 68 insertions(+), 9 deletions(-) create mode 100644 changelog.d/15002.misc (limited to 'synapse') diff --git a/changelog.d/15002.misc b/changelog.d/15002.misc new file mode 100644 index 0000000000..68ac8335fc --- /dev/null +++ b/changelog.d/15002.misc @@ -0,0 +1 @@ +Add tests for `_flatten_dict`. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f73dceb128..d9c0a98f44 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -36,7 +36,7 @@ from synapse.api.constants import ( Membership, RelationTypes, ) -from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion +from synapse.api.room_versions import PushRuleRoomFlag from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -405,7 +405,7 @@ class BulkPushRuleEvaluator: room_mention = mentions.get("room") is True evaluator = PushRuleEvaluator( - _flatten_dict(event, room_version=event.room_version), + _flatten_dict(event), has_mentions, user_mentions, room_mention, @@ -491,7 +491,6 @@ StateGroup = Union[object, int] def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], - room_version: Optional[RoomVersion] = None, prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -511,7 +510,6 @@ def _flatten_dict( Args: d: The event or content to continue flattening. - room_version: The room version object. prefix: The key prefix (from outer dictionaries). result: The result to mutate. @@ -531,14 +529,13 @@ def _flatten_dict( # `room_version` should only ever be set when looking at the top level of an event if ( - room_version is not None - and PushRuleRoomFlag.EXTENSIBLE_EVENTS in room_version.msc3931_push_features - and isinstance(d, EventBase) + isinstance(d, EventBase) + and PushRuleRoomFlag.EXTENSIBLE_EVENTS in d.room_version.msc3931_push_features ): # Room supports extensible events: replace `content.body` with the plain text # representation from `m.markup`, as per MSC1767. markup = d.get("content").get("m.markup") - if room_version.identifier.startswith("org.matrix.msc1767."): + if d.room_version.identifier.startswith("org.matrix.msc1767."): markup = d.get("content").get("org.matrix.msc1767.markup") if markup is not None and isinstance(markup, list): text = "" diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 7c430c4ecb..da33423871 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -22,7 +22,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, make_event_from_dict from synapse.push.bulk_push_rule_evaluator import _flatten_dict from synapse.push.httppusher import tweaks_for_actions from synapse.rest import admin @@ -60,6 +60,67 @@ class FlattenDictTestCase(unittest.TestCase): } self.assertEqual({"woo": "woo"}, _flatten_dict(input)) + def test_event(self) -> None: + """Events can also be flattened.""" + event = make_event_from_dict( + { + "room_id": "!test:test", + "type": "m.room.message", + "sender": "@alice:test", + "content": { + "msgtype": "m.text", + "body": "Hello world!", + "format": "org.matrix.custom.html", + "formatted_body": "

Hello world!

", + }, + }, + room_version=RoomVersions.V8, + ) + expected = { + "content.msgtype": "m.text", + "content.body": "hello world!", + "content.format": "org.matrix.custom.html", + "content.formatted_body": "

hello world!

", + "room_id": "!test:test", + "sender": "@alice:test", + "type": "m.room.message", + } + self.assertEqual(expected, _flatten_dict(event)) + + def test_extensible_events(self) -> None: + """Extensible events has compatibility behaviour.""" + event_dict = { + "room_id": "!test:test", + "type": "m.room.message", + "sender": "@alice:test", + "content": { + "org.matrix.msc1767.markup": [ + {"mimetype": "text/plain", "body": "Hello world!"}, + {"mimetype": "text/html", "body": "

Hello world!

"}, + ] + }, + } + + # For a current room version, there's no special behavior. + event = make_event_from_dict(event_dict, room_version=RoomVersions.V8) + expected = { + "room_id": "!test:test", + "sender": "@alice:test", + "type": "m.room.message", + } + self.assertEqual(expected, _flatten_dict(event)) + + # For a room version with extensible events, they parse out the text/plain + # to a content.body property. + event = make_event_from_dict(event_dict, room_version=RoomVersions.MSC1767v10) + expected = { + "content.body": "hello world!", + "room_id": "!test:test", + "sender": "@alice:test", + "type": "m.room.message", + } + self.assertEqual(expected, _flatten_dict(event)) + class PushRuleEvaluatorTestCase(unittest.TestCase): def _get_evaluator( -- cgit 1.5.1 From 2dff93099b5aa7e213da76a9c4b3de84385b58e1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Feb 2023 15:24:44 +0000 Subject: Typecheck tests.rest.media.v1.test_media_storage (#15008) * Fix MediaStorage type hint * Typecheck tests.rest.media.v1.test_media_storage * Changelog * Remove assert and make the comment succinct * Fix syntax for olddeps --- changelog.d/15008.misc | 1 + mypy.ini | 1 - synapse/rest/media/v1/media_storage.py | 7 ++--- tests/rest/media/v1/test_media_storage.py | 49 +++++++++++++++++++------------ 4 files changed, 35 insertions(+), 23 deletions(-) create mode 100644 changelog.d/15008.misc (limited to 'synapse') diff --git a/changelog.d/15008.misc b/changelog.d/15008.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/15008.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index 0efafb26b6..4598002c4a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -33,7 +33,6 @@ exclude = (?x) |synapse/storage/schema/ |tests/module_api/test_api.py - |tests/rest/media/v1/test_media_storage.py |tests/server.py )$ diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index a5c3de192f..db25848744 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -46,10 +46,9 @@ from ._base import FileInfo, Responder from .filepath import MediaFilePaths if TYPE_CHECKING: + from synapse.rest.media.v1.storage_provider import StorageProvider from synapse.server import HomeServer - from .storage_provider import StorageProviderWrapper - logger = logging.getLogger(__name__) @@ -68,7 +67,7 @@ class MediaStorage: hs: "HomeServer", local_media_directory: str, filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProviderWrapper"], + storage_providers: Sequence["StorageProvider"], ): self.hs = hs self.reactor = hs.get_reactor() @@ -360,7 +359,7 @@ class ReadableFileWrapper: clock: Clock path: str - async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None: + async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: """Reads the file in chunks and calls the callback with each chunk.""" with open(self.path, "rb") as file: diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index d18fc13c21..17a3b06a8e 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -16,7 +16,7 @@ import shutil import tempfile from binascii import unhexlify from io import BytesIO -from typing import Any, BinaryIO, Dict, List, Optional, Union +from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union from unittest.mock import Mock from urllib import parse @@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers +from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable from synapse.module_api import ModuleApi from synapse.rest import admin @@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.server import HomeServer -from synapse.types import RoomAlias +from synapse.types import JsonDict, RoomAlias from synapse.util import Clock from tests import unittest @@ -201,36 +202,46 @@ class _TestImage: ], ) class MediaRepoTests(unittest.HomeserverTestCase): - + test_image: ClassVar[_TestImage] hijack_auth = True user_id = "@test:user" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches = [] + self.fetches: List[ + Tuple[ + "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", + str, + str, + Optional[QueryParams], + ] + ] = [] def get_file( destination: str, path: str, output_stream: BinaryIO, - args: Optional[Dict[str, Union[str, List[str]]]] = None, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, max_size: Optional[int] = None, - ) -> Deferred: - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. - """ + ignore_backoff: bool = False, + ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": + """A mock for MatrixFederationHttpClient.get_file.""" - def write_to(r): + def write_to( + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + ) -> Tuple[int, Dict[bytes, List[bytes]]]: data, response = r output_stream.write(data) return response - d = Deferred() - d.addCallback(write_to) + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + # Note that this callback changes the value held by d. + d_after_callback = d.addCallback(write_to) + return make_deferred_yieldable(d_after_callback) + # Mock out the homeserver's MatrixFederationHttpClient client = Mock() client.get_file = get_file @@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): # Synapse should regenerate missing thumbnails. origin, media_id = self.media_id.split("/") info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) + assert info is not None file_id = info["filesystem_id"] thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( @@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): "thumbnail_method": method, "thumbnail_type": self.test_image.content_type, "thumbnail_length": 256, - "filesystem_id": f"thumbnail1{self.test_image.extension}", + "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}", }, { "thumbnail_width": 32, @@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase): "thumbnail_method": method, "thumbnail_type": self.test_image.content_type, "thumbnail_length": 256, - "filesystem_id": f"thumbnail2{self.test_image.extension}", + "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}", }, ], - file_id=f"image{self.test_image.extension}", + file_id=f"image{self.test_image.extension.decode()}", url_cache=None, server_name=None, ) @@ -637,6 +649,7 @@ class TestSpamCheckerLegacy: self.config = config self.api = api + @staticmethod def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config @@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo - ) -> Union[Codes, Literal["NOT_SPAM"]]: + ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]: buf = BytesIO() await file_wrapper.write_chunks_to(buf.write) -- cgit 1.5.1 From 9cd7610f86ab5051c9365dd38d1eec405a5f8ca6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Feb 2023 15:26:55 +0000 Subject: Revert "Add `event_stream_ordering` column to membership state tables (#14979)" This reverts commit 5fdc12f482c68e2cdbb78d7db5de2cfe621720d4. --- synapse/storage/databases/main/events.py | 23 ++--- .../storage/databases/main/events_bg_updates.py | 104 +-------------------- synapse/storage/databases/main/events_worker.py | 8 +- .../26membership_tables_event_stream_ordering.sql | 21 ----- 4 files changed, 11 insertions(+), 145 deletions(-) delete mode 100644 synapse/storage/schema/main/delta/73/26membership_tables_event_stream_ordering.sql (limited to 'synapse') diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b6cce0a7cc..1536937b67 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1147,15 +1147,11 @@ class PersistEventsStore: # been inserted into room_memberships. txn.execute_batch( """INSERT INTO current_state_events - (room_id, type, state_key, event_id, membership, event_stream_ordering) - VALUES ( - ?, ?, ?, ?, - (SELECT membership FROM room_memberships WHERE event_id = ?), - (SELECT stream_ordering FROM events WHERE event_id = ?) - ) + (room_id, type, state_key, event_id, membership) + VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) """, [ - (room_id, key[0], key[1], ev_id, ev_id, ev_id) + (room_id, key[0], key[1], ev_id, ev_id) for key, ev_id in to_insert.items() ], ) @@ -1182,15 +1178,11 @@ class PersistEventsStore: if to_insert: txn.execute_batch( """INSERT INTO local_current_membership - (room_id, user_id, event_id, membership, event_stream_ordering) - VALUES ( - ?, ?, ?, - (SELECT membership FROM room_memberships WHERE event_id = ?), - (SELECT stream_ordering FROM events WHERE event_id = ?) - ) + (room_id, user_id, event_id, membership) + VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) """, [ - (room_id, key[1], ev_id, ev_id, ev_id) + (room_id, key[1], ev_id, ev_id) for key, ev_id in to_insert.items() if key[0] == EventTypes.Member and self.is_mine_id(key[1]) ], @@ -1798,7 +1790,6 @@ class PersistEventsStore: table="room_memberships", keys=( "event_id", - "event_stream_ordering", "user_id", "sender", "room_id", @@ -1809,7 +1800,6 @@ class PersistEventsStore: values=[ ( event.event_id, - event.internal_metadata.stream_ordering, event.state_key, event.user_id, event.room_id, @@ -1842,7 +1832,6 @@ class PersistEventsStore: keyvalues={"room_id": event.room_id, "user_id": event.state_key}, values={ "event_id": event.event_id, - "event_stream_ordering": event.internal_metadata.stream_ordering, "membership": event.membership, }, ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 0e81d38cca..b9d3c36d60 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, ca import attr -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, RelationTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -71,10 +71,6 @@ class _BackgroundUpdates: EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index" - POPULATE_MEMBERSHIP_EVENT_STREAM_ORDERING = ( - "populate_membership_event_stream_ordering" - ) - @attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: @@ -103,10 +99,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - self.db_pool.updates.register_background_update_handler( - _BackgroundUpdates.POPULATE_MEMBERSHIP_EVENT_STREAM_ORDERING, - self._populate_membership_event_stream_ordering, - ) self.db_pool.updates.register_background_update_handler( _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts, @@ -1506,97 +1498,3 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) return batch_size - - async def _populate_membership_event_stream_ordering( - self, progress: JsonDict, batch_size: int - ) -> int: - def _populate_membership_event_stream_ordering( - txn: LoggingTransaction, - ) -> bool: - - if "max_stream_ordering" in progress: - max_stream_ordering = progress["max_stream_ordering"] - else: - txn.execute("SELECT max(stream_ordering) FROM events") - res = txn.fetchone() - if res is None or res[0] is None: - return True - else: - max_stream_ordering = res[0] - - start = progress.get("stream_ordering", 0) - stop = start + batch_size - - sql = f""" - SELECT room_id, event_id, stream_ordering - FROM events - WHERE - type = '{EventTypes.Member}' - AND stream_ordering >= ? - AND stream_ordering < ? - """ - txn.execute(sql, (start, stop)) - - rows: List[Tuple[str, str, int]] = cast( - List[Tuple[str, str, int]], txn.fetchall() - ) - - event_ids: List[Tuple[str]] = [] - event_stream_orderings: List[Tuple[int]] = [] - - for _, event_id, event_stream_ordering in rows: - event_ids.append((event_id,)) - event_stream_orderings.append((event_stream_ordering,)) - - self.db_pool.simple_update_many_txn( - txn, - table="current_state_events", - key_names=("event_id",), - key_values=event_ids, - value_names=("event_stream_ordering",), - value_values=event_stream_orderings, - ) - - self.db_pool.simple_update_many_txn( - txn, - table="room_memberships", - key_names=("event_id",), - key_values=event_ids, - value_names=("event_stream_ordering",), - value_values=event_stream_orderings, - ) - - # NOTE: local_current_membership has no index on event_id, so only - # the room ID here will reduce the query rows read. - for room_id, event_id, event_stream_ordering in rows: - txn.execute( - """ - UPDATE local_current_membership - SET event_stream_ordering = ? - WHERE room_id = ? AND event_id = ? - """, - (event_stream_ordering, room_id, event_id), - ) - - self.db_pool.updates._background_update_progress_txn( - txn, - _BackgroundUpdates.POPULATE_MEMBERSHIP_EVENT_STREAM_ORDERING, - { - "stream_ordering": stop, - "max_stream_ordering": max_stream_ordering, - }, - ) - - return stop > max_stream_ordering - - finished = await self.db_pool.runInteraction( - "_populate_membership_event_stream_ordering", - _populate_membership_event_stream_ordering, - ) - - if finished: - await self.db_pool.updates._end_background_update( - _BackgroundUpdates.POPULATE_MEMBERSHIP_EVENT_STREAM_ORDERING - ) - - return batch_size diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6d0ef10258..d7d08369ca 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1779,7 +1779,7 @@ class EventsWorkerStore(SQLBaseStore): txn: LoggingTransaction, ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( - "SELECT out.event_stream_ordering, e.event_id, e.room_id, e.type," + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," " e.outlier" " FROM events AS e" @@ -1791,10 +1791,10 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" - " WHERE ? < out.event_stream_ordering" - " AND out.event_stream_ordering <= ?" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" " AND out.instance_name = ?" - " ORDER BY out.event_stream_ordering ASC" + " ORDER BY event_stream_ordering ASC" ) txn.execute(sql, (last_id, current_id, instance_name)) diff --git a/synapse/storage/schema/main/delta/73/26membership_tables_event_stream_ordering.sql b/synapse/storage/schema/main/delta/73/26membership_tables_event_stream_ordering.sql deleted file mode 100644 index 7c30a67fc4..0000000000 --- a/synapse/storage/schema/main/delta/73/26membership_tables_event_stream_ordering.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2022 Beeper - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT; -ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT; -ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT; - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_membership_event_stream_ordering', '{}'); -- cgit 1.5.1 From f10caa73eee0caa91cf373966104d1ededae2aee Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Feb 2023 15:33:33 +0000 Subject: Disambiguate `get_ex_outlier_stream_rows` query A backwards-compatible piece of #14979 that's safe to land now. --- synapse/storage/databases/main/events_worker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d7d08369ca..6d0ef10258 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1779,7 +1779,7 @@ class EventsWorkerStore(SQLBaseStore): txn: LoggingTransaction, ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + "SELECT out.event_stream_ordering, e.event_id, e.room_id, e.type," " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," " e.outlier" " FROM events AS e" @@ -1791,10 +1791,10 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" + " WHERE ? < out.event_stream_ordering" + " AND out.event_stream_ordering <= ?" " AND out.instance_name = ?" - " ORDER BY event_stream_ordering ASC" + " ORDER BY out.event_stream_ordering ASC" ) txn.execute(sql, (last_id, current_id, instance_name)) -- cgit 1.5.1 From c78c67c5a909c6749f25b251d46be3df8f56f8c2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 Feb 2023 17:41:55 +0100 Subject: Fix bug in replication where response is cached (#15024) --- changelog.d/15024.bugfix | 1 + synapse/replication/http/_base.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/15024.bugfix (limited to 'synapse') diff --git a/changelog.d/15024.bugfix b/changelog.d/15024.bugfix new file mode 100644 index 0000000000..dddd406322 --- /dev/null +++ b/changelog.d/15024.bugfix @@ -0,0 +1 @@ +Fix bug where retried replication requests would return a failure. Introduced in v1.76.0. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 908f3f1db7..c20d9c7e9d 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -426,6 +426,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): code, response = await self.response_cache.wrap( txn_id, self._handle_request, request, content, **kwargs ) + # Take a copy so we don't mutate things in the cache. + response = dict(response) else: # The `@cancellable` decorator may be applied to `_handle_request`. But we # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, -- cgit 1.5.1 From c951fbedcb81895c199c1f4cfe2251d6c3a7b5f4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:09:41 -0500 Subject: MSC3873: Escape keys when flattening dicts. (#15004) This disambiguates keys which attempt to match fields with a dot in them (e.g. m.relates_to). Disabled by default behind an experimental configuration flag. --- changelog.d/15004.feature | 1 + synapse/config/experimental.py | 5 +++++ synapse/push/bulk_push_rule_evaluator.py | 30 ++++++++++++++++++++++++++---- tests/push/test_push_rule_evaluator.py | 8 ++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) create mode 100644 changelog.d/15004.feature (limited to 'synapse') diff --git a/changelog.d/15004.feature b/changelog.d/15004.feature new file mode 100644 index 0000000000..d11d0aca91 --- /dev/null +++ b/changelog.d/15004.feature @@ -0,0 +1 @@ +Implement [MSC3873](https://github.com/matrix-org/matrix-spec-proposals/pull/3873) to unambiguate push rule keys with dots in them. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 53c0682dfd..5e3a889081 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -169,6 +169,11 @@ class ExperimentalConfig(Config): # MSC3925: do not replace events with their edits self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) + # MSC3873: Disambiguate event_match keys. + self.msc3783_escape_event_match_key = experimental.get( + "msc3783_escape_event_match_key", False + ) + # MSC3952: Intentional mentions self.msc3952_intentional_mentions = experimental.get( "msc3952_intentional_mentions", False diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d9c0a98f44..39d2f88f03 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -271,7 +271,10 @@ class BulkPushRuleEvaluator: related_event_id, allow_none=True ) if related_event is not None: - related_events[relation_type] = _flatten_dict(related_event) + related_events[relation_type] = _flatten_dict( + related_event, + msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + ) reply_event_id = ( event.content.get("m.relates_to", {}) @@ -286,7 +289,10 @@ class BulkPushRuleEvaluator: ) if related_event is not None: - related_events["m.in_reply_to"] = _flatten_dict(related_event) + related_events["m.in_reply_to"] = _flatten_dict( + related_event, + msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + ) # indicate that this is from a fallback relation. if relation_type == "m.thread" and event.content.get( @@ -405,7 +411,10 @@ class BulkPushRuleEvaluator: room_mention = mentions.get("room") is True evaluator = PushRuleEvaluator( - _flatten_dict(event), + _flatten_dict( + event, + msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + ), has_mentions, user_mentions, room_mention, @@ -493,6 +502,8 @@ def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, + *, + msc3783_escape_event_match_key: bool = False, ) -> Dict[str, str]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, @@ -521,11 +532,22 @@ def _flatten_dict( if result is None: result = {} for key, value in d.items(): + if msc3783_escape_event_match_key: + # Escape periods in the key with a backslash (and backslashes with an + # extra backslash). This is since a period is used as a separator between + # nested fields. + key = key.replace("\\", "\\\\").replace(".", "\\.") + if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() elif isinstance(value, Mapping): # do not set `room_version` due to recursion considerations below - _flatten_dict(value, prefix=(prefix + [key]), result=result) + _flatten_dict( + value, + prefix=(prefix + [key]), + result=result, + msc3783_escape_event_match_key=msc3783_escape_event_match_key, + ) # `room_version` should only ever be set when looking at the top level of an event if ( diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index da33423871..516b65cc3c 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -48,6 +48,14 @@ class FlattenDictTestCase(unittest.TestCase): input = {"foo": {"bar": "abc"}} self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input)) + # If a field has a dot in it, escape it. + input = {"m.foo": {"b\\ar": "abc"}} + self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input)) + self.assertEqual( + {"m\\.foo.b\\\\ar": "abc"}, + _flatten_dict(input, msc3783_escape_event_match_key=True), + ) + def test_non_string(self) -> None: """Non-string items are dropped.""" input: Dict[str, Any] = { -- cgit 1.5.1 From 55e4d27b36fd69a3cf3eceecbd42706579ef2dc7 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 8 Feb 2023 11:25:11 -0800 Subject: Limit concurrent event creation for a room to avoid state resolution when sending bursts of events to a local room (#14977) --- changelog.d/14977.misc | 1 + synapse/handlers/message.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14977.misc (limited to 'synapse') diff --git a/changelog.d/14977.misc b/changelog.d/14977.misc new file mode 100644 index 0000000000..4d551c52b7 --- /dev/null +++ b/changelog.d/14977.misc @@ -0,0 +1 @@ +Limit concurrent event creation for a room to avoid state resolution when sending bursts of events to a local room. \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e688e00575..5f6da2943f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -499,9 +499,9 @@ class EventCreationHandler: self.request_ratelimiter = hs.get_request_ratelimiter() - # We arbitrarily limit concurrent event creation for a room to 5. - # This is to stop us from diverging history *too* much. - self.limiter = Linearizer(max_count=5, name="room_event_creation_limit") + # We limit concurrent event creation for a room to 1. This prevents state resolution + # from occurring when sending bursts of events to a local room + self.limiter = Linearizer(max_count=1, name="room_event_creation_limit") self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() -- cgit 1.5.1 From 733531ee3e695da92f10e01b24f62ee35e09e4cd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Feb 2023 09:49:04 -0500 Subject: Add final type hint to synapse.server. (#15035) --- changelog.d/15035.misc | 1 + mypy.ini | 3 --- synapse/handlers/room.py | 2 +- synapse/server.py | 12 +++++------- synapse/storage/_base.py | 2 ++ synapse/storage/database.py | 1 + synapse/storage/databases/main/events.py | 2 +- 7 files changed, 11 insertions(+), 12 deletions(-) create mode 100644 changelog.d/15035.misc (limited to 'synapse') diff --git a/changelog.d/15035.misc b/changelog.d/15035.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/15035.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index 3f144e61fb..57f27ba4f7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -53,9 +53,6 @@ warn_unused_ignores = False [mypy-synapse.util.caches.treecache] disallow_untyped_defs = False -[mypy-synapse.server] -disallow_untyped_defs = False - [mypy-synapse.storage.database] disallow_untyped_defs = False diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7ba7c4ff07..0e759b8a5d 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1076,7 +1076,7 @@ class RoomCreationHandler: state_map: MutableStateMap[str] = {} # current_state_group of last event created. Used for computing event context of # events to be batched - current_state_group = None + current_state_group: Optional[int] = None def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: e = {"type": etype, "content": content} diff --git a/synapse/server.py b/synapse/server.py index 9d6d268f49..efc6b5f895 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -21,7 +21,7 @@ import abc import functools import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.tcp import Port @@ -144,10 +144,10 @@ if TYPE_CHECKING: from synapse.handlers.saml import SamlHandler -T = TypeVar("T", bound=Callable[..., Any]) +T = TypeVar("T") -def cache_in_self(builder: T) -> T: +def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]: """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and returning if so. If not, calls the given function and sets `self.foo` to it. @@ -166,7 +166,7 @@ def cache_in_self(builder: T) -> T: building = [False] @functools.wraps(builder) - def _get(self): + def _get(self: "HomeServer") -> T: try: return getattr(self, depname) except AttributeError: @@ -185,9 +185,7 @@ def cache_in_self(builder: T) -> T: return dep - # We cast here as we need to tell mypy that `_get` has the same signature as - # `builder`. - return cast(T, _get) + return _get class HomeServer(metaclass=abc.ABCMeta): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 41d9111019..481fec72fe 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -37,6 +37,8 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ + db_pool: DatabasePool + def __init__( self, database: DatabasePool, diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e20c5c5302..feaa6cdd07 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -499,6 +499,7 @@ class DatabasePool: """ _TXN_ID = 0 + engine: BaseDatabaseEngine def __init__( self, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1536937b67..cb66376fb4 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -306,7 +306,7 @@ class PersistEventsStore: # The set of event_ids to return. This includes all soft-failed events # and their prev events. - existing_prevs = set() + existing_prevs: Set[str] = set() def _get_prevs_before_rejected_txn( txn: LoggingTransaction, batch: Collection[str] -- cgit 1.5.1 From cd2484dc2e943e40242337dae61f5170638116a2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 9 Feb 2023 15:28:26 +0000 Subject: Bump schema version (#15036) * Bump schema version This should have been included in f10caa73eee0caa91cf373966104d1ededae2aee (and #14979). * Changelog --- changelog.d/15036.misc | 1 + synapse/storage/schema/__init__.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 changelog.d/15036.misc (limited to 'synapse') diff --git a/changelog.d/15036.misc b/changelog.d/15036.misc new file mode 100644 index 0000000000..b0adc9c9d1 --- /dev/null +++ b/changelog.d/15036.misc @@ -0,0 +1 @@ +Prepare for future database schema changes. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 19dbf2da7f..d3103a6c7a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 73 # remember to update the list below when updating +SCHEMA_VERSION = 74 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -78,7 +78,7 @@ Changes in SCHEMA_VERSION = 72: - Unused column application_services_state.last_txn is dropped - Cache invalidation stream id sequence now begins at 2 to match code expectation. -Changes in SCHEMA_VERSION = 73; +Changes in SCHEMA_VERSION = 73: - thread_id column is added to event_push_actions, event_push_actions_staging event_push_summary, receipts_linearized, and receipts_graph. - Add table `event_failed_pull_attempts` to keep track when we fail to pull @@ -86,6 +86,11 @@ Changes in SCHEMA_VERSION = 73; - Add indexes to various tables (`event_failed_pull_attempts`, `insertion_events`, `batch_events`) to make it easy to delete all associated rows when purging a room. - `inserted_ts` column is added to `event_push_actions_staging` table. + +Changes in SCHEMA_VERSION = 74: + - A query on `event_stream_ordering` column has now been disambiguated (i.e. the + codebase can handle the `current_state_events`, `local_current_memberships` and + `room_memberships` tables having an `event_stream_ordering` column). """ -- cgit 1.5.1 From 8a6e0434889ea94893119775b6f56904cbc575c2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Feb 2023 10:56:02 -0500 Subject: Avoid mutating cached room aliases. (#15038) This might cause incorrect data in other callers which are not expecting the canonical alias to be added into the response. --- changelog.d/15038.bugfix | 1 + synapse/handlers/directory.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/15038.bugfix (limited to 'synapse') diff --git a/changelog.d/15038.bugfix b/changelog.d/15038.bugfix new file mode 100644 index 0000000000..4695a09756 --- /dev/null +++ b/changelog.d/15038.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the room aliases returned could be corrupted. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 2ea52257cb..d31b0fbb17 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -485,7 +485,8 @@ class DirectoryHandler: ) ) if canonical_alias: - room_aliases.append(canonical_alias) + # Ensure we do not mutate room_aliases. + room_aliases = room_aliases + [canonical_alias] if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases -- cgit 1.5.1 From d22c1c862c8259465a8e95c41eb1f00d0367a640 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Feb 2023 13:04:24 -0500 Subject: Respond correctly to unknown methods on known endpoints (#14605) Respond with a 405 error if a request is received on a known endpoint, but to an unknown method, per MSC3743. --- changelog.d/14605.bugfix | 1 + docs/admin_api/media_admin_api.md | 10 +++++++- docs/upgrade.md | 10 ++++++++ synapse/http/server.py | 40 ++++++++++++-------------------- synapse/rest/admin/media.py | 18 +++++++++++---- synapse/rest/client/room_keys.py | 48 ++++++++++++++++++++++++++------------- synapse/rest/client/tags.py | 4 +++- tests/rest/admin/test_media.py | 9 +++++--- 8 files changed, 89 insertions(+), 51 deletions(-) create mode 100644 changelog.d/14605.bugfix (limited to 'synapse') diff --git a/changelog.d/14605.bugfix b/changelog.d/14605.bugfix new file mode 100644 index 0000000000..cb95a87d92 --- /dev/null +++ b/changelog.d/14605.bugfix @@ -0,0 +1 @@ +Return spec-compliant JSON errors when unknown endpoints are requested. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 7f8c8e22c1..30833f3109 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -235,6 +235,14 @@ The following fields are returned in the JSON response body: Request: +``` +POST /_synapse/admin/v1/media/delete?before_ts= + +{} +``` + +*Deprecated in Synapse v1.78.0:* This API is available at the deprecated endpoint: + ``` POST /_synapse/admin/v1/media//delete?before_ts= @@ -243,7 +251,7 @@ POST /_synapse/admin/v1/media//delete?before_ts= URL Parameters -* `server_name`: string - The name of your local server (e.g `matrix.org`). +* `server_name`: string - The name of your local server (e.g `matrix.org`). *Deprecated in Synapse v1.78.0.* * `before_ts`: string representing a positive integer - Unix timestamp in milliseconds. Files that were last used before this timestamp will be deleted. It is the timestamp of last access, not the timestamp when the file was created. diff --git a/docs/upgrade.md b/docs/upgrade.md index bc143444be..15167b8c58 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.78.0 + +## Deprecate the `/_synapse/admin/v1/media//delete` admin API + +Synapse 1.78.0 replaces the `/_synapse/admin/v1/media//delete` +admin API with an identical endpoint at `/_synapse/admin/v1/media/delete`. Please +update your tooling to use the new endpoint. The deprecated version will be removed +in a future release. + # Upgrading to v1.76.0 ## Faster joins are enabled by default @@ -137,6 +146,7 @@ and then do `pip install matrix-synapse[user-search]` for a PyPI install. Docker images and Debian packages need nothing specific as they already include or specify ICU as an explicit dependency. + # Upgrading to v1.73.0 ## Legacy Prometheus metric names have now been removed diff --git a/synapse/http/server.py b/synapse/http/server.py index 2563858f3c..9314454af1 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -30,7 +30,6 @@ from typing import ( Iterable, Iterator, List, - NoReturn, Optional, Pattern, Tuple, @@ -340,7 +339,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - return _unrecognised_request_handler(request) + # A request with an unknown method (for a known endpoint) was received. + raise UnrecognizedRequestError(code=405) @abc.abstractmethod def _send_response( @@ -396,7 +396,6 @@ class DirectServeJsonResource(_AsyncResource): @attr.s(slots=True, frozen=True, auto_attribs=True) class _PathEntry: - pattern: Pattern callback: ServletCallback servlet_classname: str @@ -425,13 +424,14 @@ class JsonResource(DirectServeJsonResource): ): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() - self.path_regexs: Dict[bytes, List[_PathEntry]] = {} + # Map of path regex -> method -> callback. + self._routes: Dict[Pattern[str], Dict[bytes, _PathEntry]] = {} self.hs = hs def register_paths( self, method: str, - path_patterns: Iterable[Pattern], + path_patterns: Iterable[Pattern[str]], callback: ServletCallback, servlet_classname: str, ) -> None: @@ -455,8 +455,8 @@ class JsonResource(DirectServeJsonResource): for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) - self.path_regexs.setdefault(method_bytes, []).append( - _PathEntry(path_pattern, callback, servlet_classname) + self._routes.setdefault(path_pattern, {})[method_bytes] = _PathEntry( + callback, servlet_classname ) def _get_handler_for_request( @@ -478,14 +478,17 @@ class JsonResource(DirectServeJsonResource): # Loop through all the registered callbacks to check if the method # and path regex match - for path_entry in self.path_regexs.get(request_method, []): - m = path_entry.pattern.match(request_path) + for path_pattern, methods in self._routes.items(): + m = path_pattern.match(request_path) if m: - # We found a match! + # We found a matching path! + path_entry = methods.get(request_method) + if not path_entry: + raise UnrecognizedRequestError(code=405) return path_entry.callback, path_entry.servlet_classname, m.groupdict() - # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - return _unrecognised_request_handler, "unrecognised_request_handler", {} + # Huh. No one wanted to handle that? Fiiiiiine. + raise UnrecognizedRequestError(code=404) async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) @@ -567,19 +570,6 @@ class StaticResource(File): return super().render_GET(request) -def _unrecognised_request_handler(request: Request) -> NoReturn: - """Request handler for unrecognised requests - - This is a request handler suitable for return from - _get_handler_for_request. It actually just raises an - UnrecognizedRequestError. - - Args: - request: Unused, but passed in to match the signature of ServletCallback. - """ - raise UnrecognizedRequestError(code=404) - - class UnrecognizedRequestResource(resource.Resource): """ Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 0d072c42a7..c134ccfb3d 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -15,7 +15,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -285,7 +285,12 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") + PATTERNS = [ + *admin_patterns("/media/delete$"), + # This URL kept around for legacy reasons, it is undesirable since it + # overlaps with the DeleteMediaByID servlet. + *admin_patterns("/media/(?P[^/]*)/delete$"), + ] def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main @@ -294,7 +299,7 @@ class DeleteMediaByDateSize(RestServlet): self.media_repository = hs.get_media_repository() async def on_POST( - self, request: SynapseRequest, server_name: str + self, request: SynapseRequest, server_name: Optional[str] = None ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -322,7 +327,8 @@ class DeleteMediaByDateSize(RestServlet): errcode=Codes.INVALID_PARAM, ) - if self.server_name != server_name: + # This check is useless, we keep it for the legacy endpoint only. + if server_name is not None and self.server_name != server_name: raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") logging.info( @@ -489,6 +495,8 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) ProtectMediaByID(hs).register(http_server) UnprotectMediaByID(hs).register(http_server) ListMediaInRoom(hs).register(http_server) - DeleteMediaByID(hs).register(http_server) + # XXX DeleteMediaByDateSize must be registered before DeleteMediaByID as + # their URL routes overlap. DeleteMediaByDateSize(hs).register(http_server) + DeleteMediaByID(hs).register(http_server) UserMediaRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py index f7081f638e..4e7ffdb555 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -259,6 +259,32 @@ class RoomKeysNewVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + """ + Retrieve the version information about the most current backup version (if any) + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + Returns 404 if the given version does not exist. + + GET /room_keys/version HTTP/1.1 + { + "version": "12345", + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + try: + info = await self.e2e_room_keys_handler.get_version_info(user_id) + except SynapseError as e: + if e.code == 404: + raise SynapseError(404, "No backup found", Codes.NOT_FOUND) + return 200, info + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """ Create a new backup version for this user's room_keys with the given @@ -301,7 +327,7 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") + PATTERNS = client_patterns("/room_keys/version/(?P[^/]+)$") def __init__(self, hs: "HomeServer"): super().__init__() @@ -309,12 +335,11 @@ class RoomKeysVersionServlet(RestServlet): self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() async def on_GET( - self, request: SynapseRequest, version: Optional[str] + self, request: SynapseRequest, version: str ) -> Tuple[int, JsonDict]: """ Retrieve the version information about a given version of the user's - room_keys backup. If the version part is missing, returns info about the - most current backup version (if any) + room_keys backup. It takes out an exclusive lock on this user's room_key backups, to ensure clients only upload to the current backup. @@ -339,20 +364,16 @@ class RoomKeysVersionServlet(RestServlet): return 200, info async def on_DELETE( - self, request: SynapseRequest, version: Optional[str] + self, request: SynapseRequest, version: str ) -> Tuple[int, JsonDict]: """ Delete the information about a given version of the user's - room_keys backup. If the version part is missing, deletes the most - current backup version (if any). Doesn't delete the actual room data. + room_keys backup. Doesn't delete the actual room data. DELETE /room_keys/version/12345 HTTP/1.1 HTTP/1.1 200 OK {} """ - if version is None: - raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() @@ -360,7 +381,7 @@ class RoomKeysVersionServlet(RestServlet): return 200, {} async def on_PUT( - self, request: SynapseRequest, version: Optional[str] + self, request: SynapseRequest, version: str ) -> Tuple[int, JsonDict]: """ Update the information about a given version of the user's room_keys backup. @@ -386,11 +407,6 @@ class RoomKeysVersionServlet(RestServlet): user_id = requester.user.to_string() info = parse_json_object_from_request(request) - if version is None: - raise SynapseError( - 400, "No version specified to update", Codes.MISSING_PARAM - ) - await self.e2e_room_keys_handler.update_version(user_id, version, info) return 200, {} diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index ca638755c7..dde08417a4 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -34,7 +34,9 @@ class TagListServlet(RestServlet): GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") + PATTERNS = client_patterns( + "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags$" + ) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index aadb31ca83..db77a45ae3 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -213,7 +213,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.admin_user_tok = self.login("admin", "pass") self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + self.url = "/_synapse/admin/v1/media/delete" + self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name # Move clock up to somewhat realistic time self.reactor.advance(1000000000) @@ -332,11 +333,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_delete_media_never_accessed(self) -> None: + @parameterized.expand([(True,), (False,)]) + def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None: """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` """ + url = self.legacy_url if use_legacy_url else self.url # upload and do not access server_and_media_id = self._create_media() @@ -351,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): now_ms = self.clock.time_msec() channel = self.make_request( "POST", - self.url + "?before_ts=" + str(now_ms), + url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) -- cgit 1.5.1 From c1d2ce2901ab1c7cfaeebb4683af05a2ebf19fa6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 9 Feb 2023 19:57:01 +0000 Subject: Do not always start a db txn on Postgres (#14840) --- changelog.d/14840.misc | 1 + synapse/storage/prepare_database.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14840.misc (limited to 'synapse') diff --git a/changelog.d/14840.misc b/changelog.d/14840.misc new file mode 100644 index 0000000000..ff6084284a --- /dev/null +++ b/changelog.d/14840.misc @@ -0,0 +1 @@ +Prevent "WARNING: there is already a transaction in progress" lines appearing in PostgreSQL's logs on some occasions. \ No newline at end of file diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 3acdb39da7..6c335a9315 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -23,7 +23,7 @@ from typing_extensions import Counter as CounterType from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION from synapse.storage.types import Cursor @@ -108,9 +108,14 @@ def prepare_database( # so we start one before running anything. This ensures that any upgrades # are either applied completely, or not at all. # - # (psycopg2 automatically starts a transaction as soon as we run any statements - # at all, so this is redundant but harmless there.) - cur.execute("BEGIN TRANSACTION") + # psycopg2 does not automatically start transactions when in autocommit mode. + # While it is technically harmless to nest transactions in postgres, doing so + # results in a warning in Postgres' logs per query. And we'd rather like to + # avoid doing that. + if isinstance(database_engine, Sqlite3Engine) or ( + isinstance(database_engine, PostgresEngine) and db_conn.autocommit + ): + cur.execute("BEGIN TRANSACTION") logger.info("%r: Checking existing schema version", databases) version_info = _get_or_create_schema_state(cur, database_engine) -- cgit 1.5.1 From 03bccd542bcffe3ea12cd35108740a7d62dd38ab Mon Sep 17 00:00:00 2001 From: Shay Date: Thu, 9 Feb 2023 13:05:02 -0800 Subject: Add a class UnpersistedEventContext to allow for the batching up of storing state groups (#14675) * add class UnpersistedEventContext * modify create new client event to create unpersistedeventcontexts * persist event contexts after creation * fix tests to persist unpersisted event contexts * cleanup * misc lints + cleanup * changelog + fix comments * lints * fix batch insertion? * reduce redundant calculation * add unpersisted event classes * rework compute_event_context, split into function that returns unpersisted event context and then persists it * use calculate_context_info to create unpersisted event contexts * update typing * $%#^&* * black * fix comments and consolidate classes, use attr.s for class * requested changes * lint * requested changes * requested changes * refactor to be stupidly explicit * clearer renaming and flow * make partial state non-optional * update docstrings --------- Co-authored-by: Erik Johnston --- changelog.d/14675.misc | 1 + synapse/events/snapshot.py | 174 ++++++++++++++++++++++++++++++++- synapse/events/third_party_rules.py | 6 +- synapse/handlers/federation.py | 59 ++++++++---- synapse/handlers/federation_event.py | 6 +- synapse/handlers/message.py | 42 +++++--- synapse/state/__init__.py | 176 ++++++++++++++-------------------- tests/handlers/test_user_directory.py | 4 +- tests/rest/admin/test_user.py | 4 +- tests/storage/test_redaction.py | 24 +++-- tests/storage/test_state.py | 4 +- tests/test_utils/event_injection.py | 7 +- tests/test_visibility.py | 9 +- tests/utils.py | 5 +- 14 files changed, 359 insertions(+), 162 deletions(-) create mode 100644 changelog.d/14675.misc (limited to 'synapse') diff --git a/changelog.d/14675.misc b/changelog.d/14675.misc new file mode 100644 index 0000000000..bc1ac1c82a --- /dev/null +++ b/changelog.d/14675.misc @@ -0,0 +1 @@ +Add a class UnpersistedEventContext to allow for the batching up of storing state groups. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6eaef8b57a..e0d82ad81c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional, Tuple import attr @@ -26,8 +27,51 @@ if TYPE_CHECKING: from synapse.types.state import StateFilter +class UnpersistedEventContextBase(ABC): + """ + This is a base class for EventContext and UnpersistedEventContext, objects which + hold information relevant to storing an associated event. Note that an + UnpersistedEventContexts must be converted into an EventContext before it is + suitable to send to the db with its associated event. + + Attributes: + _storage: storage controllers for interfacing with the database + app_service: If the associated event is being sent by a (local) application service, that + app service. + """ + + def __init__(self, storage_controller: "StorageControllers"): + self._storage: "StorageControllers" = storage_controller + self.app_service: Optional[ApplicationService] = None + + @abstractmethod + async def persist( + self, + event: EventBase, + ) -> "EventContext": + """ + A method to convert an UnpersistedEventContext to an EventContext, suitable for + sending to the database with the associated event. + """ + pass + + @abstractmethod + async def get_prev_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> StateMap[str]: + """ + Gets the room state at the event (ie not including the event if the event is a + state event). + + Args: + state_filter: specifies the type of state event to fetch from DB, example: + EventTypes.JoinRules + """ + pass + + @attr.s(slots=True, auto_attribs=True) -class EventContext: +class EventContext(UnpersistedEventContextBase): """ Holds information relevant to persisting an event @@ -77,9 +121,6 @@ class EventContext: delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` and ``state_group``. - app_service: If this event is being sent by a (local) application service, that - app service. - partial_state: if True, we may be storing this event with a temporary, incomplete state. """ @@ -122,6 +163,9 @@ class EventContext: """Return an EventContext instance suitable for persisting an outlier event""" return EventContext(storage=storage) + async def persist(self, event: EventBase) -> "EventContext": + return self + async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -254,6 +298,128 @@ class EventContext: ) +@attr.s(slots=True, auto_attribs=True) +class UnpersistedEventContext(UnpersistedEventContextBase): + """ + The event context holds information about the state groups for an event. It is important + to remember that an event technically has two state groups: the state group before the + event, and the state group after the event. If the event is not a state event, the state + group will not change (ie the state group before the event will be the same as the state + group after the event), but if it is a state event the state group before the event + will differ from the state group after the event. + This is a version of an EventContext before the new state group (if any) has been + computed and stored. It contains information about the state before the event (which + also may be the information after the event, if the event is not a state event). The + UnpersistedEventContext must be converted into an EventContext by calling the method + 'persist' on it before it is suitable to be sent to the DB for processing. + + state_group_after_event: + The state group after the event. This will always be None until it is persisted. + If the event is not a state event, this will be the same as + state_group_before_event. + + state_group_before_event: + The ID of the state group representing the state of the room before this event. + + state_delta_due_to_event: + If the event is a state event, then this is the delta of the state between + `state_group` and `state_group_before_event` + + prev_group_for_state_group_before_event: + If it is known, ``state_group_before_event``'s previous state group. + + delta_ids_to_state_group_before_event: + If ``prev_group_for_state_group_before_event`` is not None, the state delta + between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``. + + partial_state: + Whether the event has partial state. + + state_map_before_event: + A map of the state before the event, i.e. the state at `state_group_before_event` + """ + + _storage: "StorageControllers" + state_group_before_event: Optional[int] + state_group_after_event: Optional[int] + state_delta_due_to_event: Optional[dict] + prev_group_for_state_group_before_event: Optional[int] + delta_ids_to_state_group_before_event: Optional[StateMap[str]] + partial_state: bool + state_map_before_event: Optional[StateMap[str]] = None + + async def get_prev_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> StateMap[str]: + """ + Gets the room state map, excluding this event. + + Args: + state_filter: specifies the type of state event to fetch from DB + + Returns: + Maps a (type, state_key) to the event ID of the state event matching + this tuple. + """ + if self.state_map_before_event: + return self.state_map_before_event + + assert self.state_group_before_event is not None + return await self._storage.state.get_state_ids_for_group( + self.state_group_before_event, state_filter + ) + + async def persist(self, event: EventBase) -> EventContext: + """ + Creates a full `EventContext` for the event, persisting any referenced state that + has not yet been persisted. + + Args: + event: event that the EventContext is associated with. + + Returns: An EventContext suitable for sending to the database with the event + for persisting + """ + assert self.partial_state is not None + + # If we have a full set of state for before the event but don't have a state + # group for that state, we need to get one + if self.state_group_before_event is None: + assert self.state_map_before_event + state_group_before_event = await self._storage.state.store_state_group( + event.event_id, + event.room_id, + prev_group=self.prev_group_for_state_group_before_event, + delta_ids=self.delta_ids_to_state_group_before_event, + current_state_ids=self.state_map_before_event, + ) + self.state_group_before_event = state_group_before_event + + # if the event isn't a state event the state group doesn't change + if not self.state_delta_due_to_event: + state_group_after_event = self.state_group_before_event + + # otherwise if it is a state event we need to get a state group for it + else: + state_group_after_event = await self._storage.state.store_state_group( + event.event_id, + event.room_id, + prev_group=self.state_group_before_event, + delta_ids=self.state_delta_due_to_event, + current_state_ids=None, + ) + + return EventContext.with_state( + storage=self._storage, + state_group=state_group_after_event, + state_group_before_event=self.state_group_before_event, + state_delta_due_to_event=self.state_delta_due_to_event, + partial_state=self.partial_state, + prev_group=self.state_group_before_event, + delta_ids=self.state_delta_due_to_event, + ) + + def _encode_state_dict( state_dict: Optional[StateMap[str]], ) -> Optional[List[Tuple[str, str, str]]]: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 72ab696898..97c61cc258 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError from synapse.api.errors import ModuleFailedException, SynapseError from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import UnpersistedEventContextBase from synapse.storage.roommember import ProfileInfo from synapse.types import Requester, StateMap from synapse.util.async_helpers import delay_cancellation, maybe_awaitable @@ -231,7 +231,9 @@ class ThirdPartyEventRules: self._on_threepid_bind_callbacks.append(on_threepid_bind) async def check_event_allowed( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: UnpersistedEventContextBase, ) -> Tuple[bool, Optional[dict]]: """Check if a provided event should be allowed in the given context. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 7f64130e0a..43ed4a3dd1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -56,7 +56,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError from synapse.http.servlet import assert_params_in_dict @@ -990,7 +990,10 @@ class FederationHandler: ) try: - event, context = await self.event_creation_handler.create_new_client_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event( builder=builder ) except SynapseError as e: @@ -998,7 +1001,9 @@ class FederationHandler: raise # Ensure the user can even join the room. - await self._federation_event_handler.check_join_restrictions(context, event) + await self._federation_event_handler.check_join_restrictions( + unpersisted_context, event + ) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` @@ -1178,7 +1183,7 @@ class FederationHandler: }, ) - event, context = await self.event_creation_handler.create_new_client_event( + event, _ = await self.event_creation_handler.create_new_client_event( builder=builder ) @@ -1228,12 +1233,13 @@ class FederationHandler: }, ) - event, context = await self.event_creation_handler.create_new_client_event( - builder=builder - ) + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event(builder=builder) event_allowed, _ = await self.third_party_event_rules.check_event_allowed( - event, context + event, unpersisted_context ) if not event_allowed: logger.warning("Creation of knock %s forbidden by third-party rules", event) @@ -1406,15 +1412,20 @@ class FederationHandler: try: ( event, - context, + unpersisted_context, ) = await self.event_creation_handler.create_new_client_event( builder=builder ) - event, context = await self.add_display_name_to_third_party_invite( - room_version_obj, event_dict, event, context + ( + event, + unpersisted_context, + ) = await self.add_display_name_to_third_party_invite( + room_version_obj, event_dict, event, unpersisted_context ) + context = await unpersisted_context.persist(event) + EventValidator().validate_new(event, self.config) # We need to tell the transaction queue to send this out, even @@ -1483,14 +1494,19 @@ class FederationHandler: try: ( event, - context, + unpersisted_context, ) = await self.event_creation_handler.create_new_client_event( builder=builder ) - event, context = await self.add_display_name_to_third_party_invite( - room_version_obj, event_dict, event, context + ( + event, + unpersisted_context, + ) = await self.add_display_name_to_third_party_invite( + room_version_obj, event_dict, event, unpersisted_context ) + context = await unpersisted_context.persist(event) + try: validate_event_for_room_version(event) await self._event_auth_handler.check_auth_rules_from_context(event) @@ -1522,8 +1538,8 @@ class FederationHandler: room_version_obj: RoomVersion, event_dict: JsonDict, event: EventBase, - context: EventContext, - ) -> Tuple[EventBase, EventContext]: + context: UnpersistedEventContextBase, + ) -> Tuple[EventBase, UnpersistedEventContextBase]: key = ( EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"], @@ -1557,11 +1573,14 @@ class FederationHandler: room_version_obj, event_dict ) EventValidator().validate_builder(builder) - event, context = await self.event_creation_handler.create_new_client_event( - builder=builder - ) + + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event(builder=builder) + EventValidator().validate_new(event, self.config) - return event, context + return event, unpersisted_context async def _check_signature(self, event: EventBase, context: EventContext) -> None: """ diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index e037acbca2..3561f2f1de 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -58,7 +58,7 @@ from synapse.event_auth import ( validate_event_for_room_version, ) from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import ( @@ -426,7 +426,9 @@ class FederationEventHandler: return event, context async def check_join_restrictions( - self, context: EventContext, event: EventBase + self, + context: UnpersistedEventContextBase, + event: EventBase, ) -> None: """Check that restrictions in restricted join rules are matched diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5f6da2943f..3e30f52e4d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -48,7 +48,7 @@ from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.utils import maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler @@ -708,7 +708,7 @@ class EventCreationHandler: builder.internal_metadata.historical = historical - event, context = await self.create_new_client_event( + event, unpersisted_context = await self.create_new_client_event( builder=builder, requester=requester, allow_no_prev_events=allow_no_prev_events, @@ -721,6 +721,8 @@ class EventCreationHandler: current_state_group=current_state_group, ) + context = await unpersisted_context.persist(event) + # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new @@ -1083,13 +1085,14 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """Create a new event for a local client. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for the event using the parameters state_map and current_state_group, thus these parameters must be provided in this case if for_batch is True. The subsequently created event and context are suitable for being batched up and bulk persisted to the database - with other similarly created events. + with other similarly created events. Note that this returns an UnpersistedEventContext, + which must be converted to an EventContext before it can be sent to the DB. Args: builder: @@ -1131,7 +1134,7 @@ class EventCreationHandler: batch persisting Returns: - Tuple of created event, context + Tuple of created event, UnpersistedEventContext """ # Strip down the state_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender @@ -1192,9 +1195,16 @@ class EventCreationHandler: event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth ) - context = await self.state.compute_event_context_for_batched( - event, state_map, current_state_group + + context: UnpersistedEventContextBase = ( + await self.state.calculate_context_info( + event, + state_ids_before_event=state_map, + partial_state=False, + state_group_before_event=current_state_group, + ) ) + else: event = await builder.build( prev_event_ids=prev_event_ids, @@ -1244,16 +1254,17 @@ class EventCreationHandler: state_map_for_event[(data.event_type, data.state_key)] = state_id - context = await self.state.compute_event_context( + # TODO(faster_joins): check how MSC2716 works and whether we can have + # partial state here + # https://github.com/matrix-org/synapse/issues/13003 + context = await self.state.calculate_context_info( event, state_ids_before_event=state_map_for_event, - # TODO(faster_joins): check how MSC2716 works and whether we can have - # partial state here - # https://github.com/matrix-org/synapse/issues/13003 partial_state=False, ) + else: - context = await self.state.compute_event_context(event) + context = await self.state.calculate_context_info(event) if requester: context.app_service = requester.app_service @@ -2082,9 +2093,9 @@ class EventCreationHandler: async def _rebuild_event_after_third_party_rules( self, third_party_result: dict, original_event: EventBase - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: # the third_party_event_rules want to replace the event. - # we do some basic checks, and then return the replacement event and context. + # we do some basic checks, and then return the replacement event. # Construct a new EventBuilder and validate it, which helps with the # rest of these checks. @@ -2138,5 +2149,6 @@ class EventCreationHandler: # we rebuild the event context, to be on the safe side. If nothing else, # delta_ids might need an update. - context = await self.state.compute_event_context(event) + context = await self.state.calculate_context_info(event) + return event, context diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fdfb46ab82..e877e6f1a1 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import ( + EventContext, + UnpersistedEventContext, + UnpersistedEventContextBase, +) from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 @@ -262,31 +266,31 @@ class StateHandler: state = await entry.get_state(self._state_storage_controller, StateFilter.all()) return await self.store.get_joined_hosts(room_id, state, entry) - async def compute_event_context( + async def calculate_context_info( self, event: EventBase, state_ids_before_event: Optional[StateMap[str]] = None, partial_state: Optional[bool] = None, - ) -> EventContext: - """Build an EventContext structure for a non-outlier event. - - (for an outlier, call EventContext.for_outlier directly) - - This works out what the current state should be for the event, and - generates a new state group if necessary. - - Args: - event: - state_ids_before_event: The event ids of the state before the event if - it can't be calculated from existing events. This is normally - only specified when receiving an event from federation where we - don't have the prev events, e.g. when backfilling. - partial_state: - `True` if `state_ids_before_event` is partial and omits non-critical - membership events. - `False` if `state_ids_before_event` is the full state. - `None` when `state_ids_before_event` is not provided. In this case, the - flag will be calculated based on `event`'s prev events. + state_group_before_event: Optional[int] = None, + ) -> UnpersistedEventContextBase: + """ + Calulates the contents of an unpersisted event context, other than the current + state group (which is either provided or calculated when the event context is persisted) + + state_ids_before_event: + The event ids of the full state before the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events, e.g. when backfilling or when the event + is being created for batch persisting. + partial_state: + `True` if `state_ids_before_event` is partial and omits non-critical + membership events. + `False` if `state_ids_before_event` is the full state. + `None` when `state_ids_before_event` is not provided. In this case, the + flag will be calculated based on `event`'s prev events. + state_group_before_event: + the current state group at the time of event, if known Returns: The event context. @@ -294,7 +298,6 @@ class StateHandler: RuntimeError if `state_ids_before_event` is not provided and one or more prev events are missing or outliers. """ - assert not event.internal_metadata.is_outlier() # @@ -306,17 +309,6 @@ class StateHandler: state_group_before_event_prev_group = None deltas_to_state_group_before_event = None - # .. though we need to get a state group for it. - state_group_before_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=None, - delta_ids=None, - current_state_ids=state_ids_before_event, - ) - ) - # the partial_state flag must be provided assert partial_state is not None else: @@ -345,6 +337,7 @@ class StateHandler: logger.debug("calling resolve_state_groups from compute_event_context") # we've already taken into account partial state, so no need to wait for # complete state here. + entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids(), @@ -383,18 +376,19 @@ class StateHandler: # if not event.is_state(): - return EventContext.with_state( + return UnpersistedEventContext( storage=self._storage_controllers, state_group_before_event=state_group_before_event, - state_group=state_group_before_event, + state_group_after_event=state_group_before_event, state_delta_due_to_event={}, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, + prev_group_for_state_group_before_event=state_group_before_event_prev_group, + delta_ids_to_state_group_before_event=deltas_to_state_group_before_event, partial_state=partial_state, + state_map_before_event=state_ids_before_event, ) # - # otherwise, we'll need to create a new state group for after the event + # otherwise, we'll need to set up creating a new state group for after the event # key = (event.type, event.state_key) @@ -412,88 +406,60 @@ class StateHandler: delta_ids = {key: event.event_id} - state_group_after_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=None, - ) - ) - - return EventContext.with_state( + return UnpersistedEventContext( storage=self._storage_controllers, - state_group=state_group_after_event, state_group_before_event=state_group_before_event, + state_group_after_event=None, state_delta_due_to_event=delta_ids, - prev_group=state_group_before_event, - delta_ids=delta_ids, + prev_group_for_state_group_before_event=state_group_before_event_prev_group, + delta_ids_to_state_group_before_event=deltas_to_state_group_before_event, partial_state=partial_state, + state_map_before_event=state_ids_before_event, ) - async def compute_event_context_for_batched( + async def compute_event_context( self, event: EventBase, - state_ids_before_event: StateMap[str], - current_state_group: int, + state_ids_before_event: Optional[StateMap[str]] = None, + partial_state: Optional[bool] = None, ) -> EventContext: - """ - Generate an event context for an event that has not yet been persisted to the - database. Intended for use with events that are created to be persisted in a batch. - Args: - event: the event the context is being computed for - state_ids_before_event: a state map consisting of the state ids of the events - created prior to this event. - current_state_group: the current state group before the event. - """ - state_group_before_event_prev_group = None - deltas_to_state_group_before_event = None - - state_group_before_event = current_state_group - - # if the event is not state, we are set - if not event.is_state(): - return EventContext.with_state( - storage=self._storage_controllers, - state_group_before_event=state_group_before_event, - state_group=state_group_before_event, - state_delta_due_to_event={}, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - partial_state=False, - ) + """Build an EventContext structure for a non-outlier event. - # otherwise, we'll need to create a new state group for after the event - key = (event.type, event.state_key) + (for an outlier, call EventContext.for_outlier directly) - if state_ids_before_event is not None: - replaces = state_ids_before_event.get(key) + This works out what the current state should be for the event, and + generates a new state group if necessary. - if replaces and replaces != event.event_id: - event.unsigned["replaces_state"] = replaces + Args: + event: + state_ids_before_event: The event ids of the state before the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events, e.g. when backfilling. + partial_state: + `True` if `state_ids_before_event` is partial and omits non-critical + membership events. + `False` if `state_ids_before_event` is the full state. + `None` when `state_ids_before_event` is not provided. In this case, the + flag will be calculated based on `event`'s prev events. + entry: + A state cache entry for the resolved state across the prev events. We may + have already calculated this, so if it's available pass it in + Returns: + The event context. - delta_ids = {key: event.event_id} + Raises: + RuntimeError if `state_ids_before_event` is not provided and one or more + prev events are missing or outliers. + """ - state_group_after_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=None, - ) + unpersisted_context = await self.calculate_context_info( + event=event, + state_ids_before_event=state_ids_before_event, + partial_state=partial_state, ) - return EventContext.with_state( - storage=self._storage_controllers, - state_group=state_group_after_event, - state_group_before_event=state_group_before_event, - state_delta_due_to_event=delta_ids, - prev_group=state_group_before_event, - delta_ids=delta_ids, - partial_state=False, - ) + return await unpersisted_context.persist(event) @measure_func() async def resolve_state_groups_for_events( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 75fc5a17a4..e9be5fb504 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -949,10 +949,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success( self.hs.get_storage_controllers().persistence.persist_event(event, context) ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 5c1ced355f..b50406e129 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2934,10 +2934,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(storage_controllers.persistence.persist_event(event, context)) # Now get rooms diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index df4740f9d9..0100f7da14 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def internal_metadata(self) -> _EventInternalMetadata: return self._base_builder.internal_metadata - event_1, context_1 = self.get_success( + event_1, unpersisted_context_1 = self.get_success( self.event_creation_handler.create_new_client_event( cast( EventBuilder, @@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) + context_1 = self.get_success(unpersisted_context_1.persist(event_1)) + self.get_success(self._persistence.persist_event(event_1, context_1)) - event_2, context_2 = self.get_success( + event_2, unpersisted_context_2 = self.get_success( self.event_creation_handler.create_new_client_event( cast( EventBuilder, @@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) + + context_2 = self.get_success(unpersisted_context_2.persist(event_2)) self.get_success(self._persistence.persist_event(event_2, context_2)) # fetch one of the redactions @@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - redaction_event, context = self.get_success( + redaction_event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(redaction_event)) + self.get_success(self._persistence.persist_event(redaction_event, context)) # Now lets jump to the future where we have censored the redaction event diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index bad7f0bc60..f730b888f7 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + assert self.storage.persistence is not None self.get_success(self.storage.persistence.persist_event(event, context)) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 1a50c2acf1..a6330ed840 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -92,8 +92,13 @@ async def create_event( builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs ) - event, context = await hs.get_event_creation_handler().create_new_client_event( + ( + event, + unpersisted_context, + ) = await hs.get_event_creation_handler().create_new_client_event( builder, prev_event_ids=prev_event_ids ) + context = await unpersisted_context.persist(event) + return event, context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 875e37988f..36d6b37aa4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -175,9 +175,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self._storage_controllers.persistence.persist_event(event, context) ) @@ -202,9 +203,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self._storage_controllers.persistence.persist_event(event, context) @@ -226,9 +228,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self._storage_controllers.persistence.persist_event(event, context) diff --git a/tests/utils.py b/tests/utils.py index d76bf9716a..15fabbc2d0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -335,6 +335,9 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: }, ) - event, context = await event_creation_handler.create_new_client_event(builder) + event, unpersisted_context = await event_creation_handler.create_new_client_event( + builder + ) + context = await unpersisted_context.persist(event) await persistence_store.persist_event(event, context) -- cgit 1.5.1 From a5a799722db0c33dc61fb2c6c7282ff7e82eb2e9 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 9 Feb 2023 22:33:39 +0000 Subject: Tag federation request spans with the worker name (#15042) * Systematically include worker name as process info * Changelog * don't bother with inner setdefault --- changelog.d/15042.feature | 1 + synapse/api/auth.py | 7 ------- synapse/logging/opentracing.py | 10 +++++++++- 3 files changed, 10 insertions(+), 8 deletions(-) create mode 100644 changelog.d/15042.feature (limited to 'synapse') diff --git a/changelog.d/15042.feature b/changelog.d/15042.feature new file mode 100644 index 0000000000..7a4de89f00 --- /dev/null +++ b/changelog.d/15042.feature @@ -0,0 +1 @@ +Tag opentracing spans for federation requests with the name of the worker serving the request. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3d7f986ac7..66e869bc2d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -32,7 +32,6 @@ from synapse.appservice import ApplicationService from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging.opentracing import ( - SynapseTags, active_span, force_tracing, start_active_span, @@ -162,12 +161,6 @@ class Auth: parent_span.set_tag( "authenticated_entity", requester.authenticated_entity ) - # We tag the Synapse instance name so that it's an easy jumping - # off point into the logs. Can also be used to filter for an - # instance that is under load. - parent_span.set_tag( - SynapseTags.INSTANCE_NAME, self.hs.get_instance_name() - ) parent_span.set_tag("user_id", requester.user.to_string()) if requester.device_id is not None: parent_span.set_tag("device_id", requester.device_id) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8ef9a0dda8..6c7cf1b294 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -466,8 +466,16 @@ def init_tracer(hs: "HomeServer") -> None: STRIP_INSTANCE_NUMBER_SUFFIX_REGEX, "", hs.get_instance_name() ) + jaeger_config = hs.config.tracing.jaeger_config + tags = jaeger_config.setdefault("tags", {}) + + # tag the Synapse instance name so that it's an easy jumping + # off point into the logs. Can also be used to filter for an + # instance that is under load. + tags[SynapseTags.INSTANCE_NAME] = hs.get_instance_name() + config = JaegerConfig( - config=hs.config.tracing.jaeger_config, + config=jaeger_config, service_name=f"{hs.config.server.server_name} {instance_name_by_type}", scope_manager=LogContextScopeManager(), metrics_factory=PrometheusMetricsFactory(), -- cgit 1.5.1 From fd296b7343f2e557519f1ec81325ad836bcbdbf9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 10 Feb 2023 10:52:35 +0100 Subject: Fix exception on start up about device lists (#15041) Fixes #15010. --- changelog.d/15041.misc | 1 + synapse/storage/databases/main/devices.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/15041.misc (limited to 'synapse') diff --git a/changelog.d/15041.misc b/changelog.d/15041.misc new file mode 100644 index 0000000000..d602b0043a --- /dev/null +++ b/changelog.d/15041.misc @@ -0,0 +1 @@ +Fix a rare exception in logs on start up. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e8b6cc6b80..766c2052fb 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -100,6 +100,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ("device_lists_outbound_pokes", "stream_id"), ("device_lists_changes_in_room", "stream_id"), ("device_lists_remote_pending", "stream_id"), + ("device_lists_changes_converted_stream_position", "stream_id"), ], is_writer=hs.config.worker.worker_app is None, ) -- cgit 1.5.1 From a481fb9f98ad10e5e129bdc7664c59498a7332f6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 10 Feb 2023 08:09:47 -0500 Subject: Refactor get_user_devices_from_cache to avoid mutating cached values. (#15040) The previous version of the code could mutate a cached value, but only if the input requested all devices of a user *and* a specific device. To avoid this nonsensical situation we no longer fetch a specific device ID if all of a user's devices are returned. --- changelog.d/15040.misc | 1 + synapse/handlers/e2e_keys.py | 11 +++++++---- synapse/storage/databases/main/devices.py | 31 +++++++++++++++++-------------- 3 files changed, 25 insertions(+), 18 deletions(-) create mode 100644 changelog.d/15040.misc (limited to 'synapse') diff --git a/changelog.d/15040.misc b/changelog.d/15040.misc new file mode 100644 index 0000000000..ca129b64af --- /dev/null +++ b/changelog.d/15040.misc @@ -0,0 +1 @@ +Avoid mutating a cached value in `get_user_devices_from_cache`. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d2188ca08f..43cbece21b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -159,19 +159,22 @@ class E2eKeysHandler: # A map of destination -> user ID -> device IDs. remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {} if remote_queries: - query_list: List[Tuple[str, Optional[str]]] = [] + user_ids = set() + user_and_device_ids: List[Tuple[str, str]] = [] for user_id, device_ids in remote_queries.items(): if device_ids: - query_list.extend( + user_and_device_ids.extend( (user_id, device_id) for device_id in device_ids ) else: - query_list.append((user_id, None)) + user_ids.add(user_id) ( user_ids_not_in_cache, remote_results, - ) = await self.store.get_user_devices_from_cache(query_list) + ) = await self.store.get_user_devices_from_cache( + user_ids, user_and_device_ids + ) # Check that the homeserver still shares a room with all cached users. # Note that this check may be slightly racy when a remote user leaves a diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 766c2052fb..85c1778a81 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -746,42 +746,45 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @trace @cancellable async def get_user_devices_from_cache( - self, query_list: List[Tuple[str, Optional[str]]] + self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: - query_list: List of (user_id, device_ids), if device_ids is - falsey then return all device ids for that user. + user_ids: users which should have all device IDs returned + user_and_device_ids: List of (user_id, device_ids) Returns: A tuple of (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info. """ - user_ids = {user_id for user_id, _ in query_list} - user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids} + user_map = await self.get_device_list_last_stream_id_for_remotes( + list(unique_user_ids) + ) # We go and check if any of the users need to have their device lists # resynced. If they do then we remove them from the cached list. users_needing_resync = await self.get_user_ids_requiring_device_list_resync( - user_ids + unique_user_ids ) user_ids_in_cache = { user_id for user_id, stream_id in user_map.items() if stream_id } - users_needing_resync - user_ids_not_in_cache = user_ids - user_ids_in_cache + user_ids_not_in_cache = unique_user_ids - user_ids_in_cache + # First fetch all the users which all devices are to be returned. results: Dict[str, Dict[str, JsonDict]] = {} - for user_id, device_id in query_list: - if user_id not in user_ids_in_cache: - continue - - if device_id: + for user_id in user_ids: + if user_id in user_ids_in_cache: + results[user_id] = await self.get_cached_devices_for_user(user_id) + # Then fetch all device-specific requests, but skip users we've already + # fetched all devices for. + for user_id, device_id in user_and_device_ids: + if user_id in user_ids_in_cache and user_id not in user_ids: device = await self._get_cached_user_device(user_id, device_id) results.setdefault(user_id, {})[device_id] = device - else: - results[user_id] = await self.get_cached_devices_for_user(user_id) set_tag("in_cache", str(results)) set_tag("not_in_cache", str(user_ids_not_in_cache)) -- cgit 1.5.1 From b95407908dfde97e483952722b6fa7a533ff5093 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 10 Feb 2023 13:11:20 +0000 Subject: Avoid mutating cached values in `_generate_sync_entry_for_account_data` (#15047) --- changelog.d/15047.misc | 1 + synapse/handlers/sync.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/15047.misc (limited to 'synapse') diff --git a/changelog.d/15047.misc b/changelog.d/15047.misc new file mode 100644 index 0000000000..561dc874de --- /dev/null +++ b/changelog.d/15047.misc @@ -0,0 +1 @@ +Avoid mutating cached values in `_generate_sync_entry_for_account_data`. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3566537894..202b35eee6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1753,6 +1753,7 @@ class SyncHandler: ) if push_rules_changed: + global_account_data = dict(global_account_data) global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) @@ -1763,6 +1764,7 @@ class SyncHandler: account_data_by_room, ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) + global_account_data = dict(global_account_data) global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) -- cgit 1.5.1 From cf5233b783273efc84b991e7242fb4761ccc201a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 10 Feb 2023 09:22:16 -0500 Subject: Avoid fetching unused account data in sync. (#14973) The per-room account data is no longer unconditionally fetched, even if all rooms will be filtered out. Global account data will not be fetched if it will all be filtered out. --- changelog.d/14973.misc | 1 + synapse/api/filtering.py | 30 +++++- synapse/handlers/account_data.py | 10 +- synapse/handlers/initial_sync.py | 5 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/sync.py | 88 +++++++++-------- synapse/rest/admin/users.py | 3 +- synapse/storage/databases/main/account_data.py | 127 ++++++++++++++++++------- 8 files changed, 176 insertions(+), 90 deletions(-) create mode 100644 changelog.d/14973.misc (limited to 'synapse') diff --git a/changelog.d/14973.misc b/changelog.d/14973.misc new file mode 100644 index 0000000000..3657623602 --- /dev/null +++ b/changelog.d/14973.misc @@ -0,0 +1 @@ +Improve performance of `/sync` in a few situations. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 83c42fc25a..b9f432cc23 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -219,9 +219,13 @@ class FilterCollection: self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {})) self._room_state_filter = Filter(hs, room_filter_json.get("state", {})) self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {})) - self._room_account_data = Filter(hs, room_filter_json.get("account_data", {})) + self._room_account_data_filter = Filter( + hs, room_filter_json.get("account_data", {}) + ) self._presence_filter = Filter(hs, filter_json.get("presence", {})) - self._account_data = Filter(hs, filter_json.get("account_data", {})) + self._global_account_data_filter = Filter( + hs, filter_json.get("account_data", {}) + ) self.include_leave = filter_json.get("room", {}).get("include_leave", False) self.event_fields = filter_json.get("event_fields", []) @@ -256,8 +260,10 @@ class FilterCollection: ) -> List[UserPresenceState]: return await self._presence_filter.filter(presence_states) - async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: - return await self._account_data.filter(events) + async def filter_global_account_data( + self, events: Iterable[JsonDict] + ) -> List[JsonDict]: + return await self._global_account_data_filter.filter(events) async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: return await self._room_state_filter.filter( @@ -279,7 +285,7 @@ class FilterCollection: async def filter_room_account_data( self, events: Iterable[JsonDict] ) -> List[JsonDict]: - return await self._room_account_data.filter( + return await self._room_account_data_filter.filter( await self._room_filter.filter(events) ) @@ -292,6 +298,13 @@ class FilterCollection: or self._presence_filter.filters_all_senders() ) + def blocks_all_global_account_data(self) -> bool: + """True if all global acount data will be filtered out.""" + return ( + self._global_account_data_filter.filters_all_types() + or self._global_account_data_filter.filters_all_senders() + ) + def blocks_all_room_ephemeral(self) -> bool: return ( self._room_ephemeral_filter.filters_all_types() @@ -299,6 +312,13 @@ class FilterCollection: or self._room_ephemeral_filter.filters_all_rooms() ) + def blocks_all_room_account_data(self) -> bool: + return ( + self._room_account_data_filter.filters_all_types() + or self._room_account_data_filter.filters_all_senders() + or self._room_account_data_filter.filters_all_rooms() + ) + def blocks_all_room_timeline(self) -> bool: return ( self._room_timeline_filter.filters_all_types() diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 67e789eef7..797de46dbc 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -343,10 +343,12 @@ class AccountDataEventSource(EventSource[int, JsonDict]): } ) - ( - account_data, - room_account_data, - ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id) + account_data = await self.store.get_updated_global_account_data_for_user( + user_id, last_stream_id + ) + room_account_data = await self.store.get_updated_room_account_data_for_user( + user_id, last_stream_id + ) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 191529bd8e..1a29abde98 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -154,9 +154,8 @@ class InitialSyncHandler: tags_by_room = await self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = await self.store.get_account_data_for_user( - user_id - ) + account_data = await self.store.get_global_account_data_for_user(user_id) + account_data_by_room = await self.store.get_room_account_data_for_user(user_id) public_room_ids = await self.store.get_public_room_ids() diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d236cc09b5..6e7141d2ef 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -484,7 +484,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): user_id: The user's ID. """ # Retrieve user account data for predecessor room - user_account_data, _ = await self.store.get_account_data_for_user(user_id) + user_account_data = await self.store.get_global_account_data_for_user(user_id) # Copy direct message state if applicable direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {}) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 202b35eee6..399685e5b7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1444,9 +1444,9 @@ class SyncHandler: logger.debug("Fetching account data") - account_data_by_room = await self._generate_sync_entry_for_account_data( - sync_result_builder - ) + # Global account data is included if it is not filtered out. + if not sync_config.filter_collection.blocks_all_global_account_data(): + await self._generate_sync_entry_for_account_data(sync_result_builder) # Presence data is included if the server has it enabled and not filtered out. include_presence_data = bool( @@ -1472,9 +1472,7 @@ class SyncHandler: ( newly_joined_rooms, newly_left_rooms, - ) = await self._generate_sync_entry_for_rooms( - sync_result_builder, account_data_by_room - ) + ) = await self._generate_sync_entry_for_rooms(sync_result_builder) # Work out which users have joined or left rooms we're in. We use this # to build the presence and device_list parts of the sync response in @@ -1717,35 +1715,29 @@ class SyncHandler: async def _generate_sync_entry_for_account_data( self, sync_result_builder: "SyncResultBuilder" - ) -> Dict[str, Dict[str, JsonDict]]: - """Generates the account data portion of the sync response. + ) -> None: + """Generates the global account data portion of the sync response. Account data (called "Client Config" in the spec) can be set either globally or for a specific room. Account data consists of a list of events which accumulate state, much like a room. - This function retrieves global and per-room account data. The former is written - to the given `sync_result_builder`. The latter is returned directly, to be - later written to the `sync_result_builder` on a room-by-room basis. + This function retrieves global account data and writes it to the given + `sync_result_builder`. See `_generate_sync_entry_for_rooms` for handling + of per-room account data. Args: sync_result_builder - - Returns: - A dictionary whose keys (room ids) map to the per room account data for that - room. """ sync_config = sync_result_builder.sync_config user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - # TODO Do not fetch room account data if it will be unused. - ( - global_account_data, - account_data_by_room, - ) = await self.store.get_updated_account_data_for_user( - user_id, since_token.account_data_key + global_account_data = ( + await self.store.get_updated_global_account_data_for_user( + user_id, since_token.account_data_key + ) ) push_rules_changed = await self.store.have_push_rules_changed_for_user( @@ -1758,28 +1750,26 @@ class SyncHandler: sync_config.user ) else: - # TODO Do not fetch room account data if it will be unused. - ( - global_account_data, - account_data_by_room, - ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) + all_global_account_data = await self.store.get_global_account_data_for_user( + user_id + ) - global_account_data = dict(global_account_data) + global_account_data = dict(all_global_account_data) global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) - account_data_for_user = await sync_config.filter_collection.filter_account_data( - [ - {"type": account_data_type, "content": content} - for account_data_type, content in global_account_data.items() - ] + account_data_for_user = ( + await sync_config.filter_collection.filter_global_account_data( + [ + {"type": account_data_type, "content": content} + for account_data_type, content in global_account_data.items() + ] + ) ) sync_result_builder.account_data = account_data_for_user - return account_data_by_room - async def _generate_sync_entry_for_presence( self, sync_result_builder: "SyncResultBuilder", @@ -1839,9 +1829,7 @@ class SyncHandler: sync_result_builder.presence = presence async def _generate_sync_entry_for_rooms( - self, - sync_result_builder: "SyncResultBuilder", - account_data_by_room: Dict[str, Dict[str, JsonDict]], + self, sync_result_builder: "SyncResultBuilder" ) -> Tuple[AbstractSet[str], AbstractSet[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1852,7 +1840,6 @@ class SyncHandler: Args: sync_result_builder - account_data_by_room: Dictionary of per room account data Returns: Returns a 2-tuple describing rooms the user has joined or left. @@ -1865,9 +1852,30 @@ class SyncHandler: since_token = sync_result_builder.since_token user_id = sync_result_builder.sync_config.user.to_string() + blocks_all_rooms = ( + sync_result_builder.sync_config.filter_collection.blocks_all_rooms() + ) + + # 0. Start by fetching room account data (if required). + if ( + blocks_all_rooms + or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data() + ): + account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {} + elif since_token and not sync_result_builder.full_state: + account_data_by_room = ( + await self.store.get_updated_room_account_data_for_user( + user_id, since_token.account_data_key + ) + ) + else: + account_data_by_room = await self.store.get_room_account_data_for_user( + user_id + ) + # 1. Start by fetching all ephemeral events in rooms we've joined (if required). block_all_room_ephemeral = ( - sync_result_builder.sync_config.filter_collection.blocks_all_rooms() + blocks_all_rooms or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) if block_all_room_ephemeral: @@ -2294,7 +2302,7 @@ class SyncHandler: room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], tags: Optional[Dict[str, Dict[str, Any]]], - account_data: Dict[str, JsonDict], + account_data: Mapping[str, JsonDict], always_include: bool = False, ) -> None: """Populates the `joined` and `archived` section of `sync_result_builder` diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index b9dca8ef3a..0c0bf540b9 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1192,7 +1192,8 @@ class AccountDataRestServlet(RestServlet): if not await self._store.get_user_by_id(user_id): raise NotFoundError("User not found") - global_data, by_room_data = await self._store.get_account_data_for_user(user_id) + global_data = await self._store.get_global_account_data_for_user(user_id) + by_room_data = await self._store.get_room_account_data_for_user(user_id) return HTTPStatus.OK, { "account_data": { "global": global_data, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 8a359d7eb8..2d6f02c14f 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -21,6 +21,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Tuple, cast, @@ -122,25 +123,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return self._account_data_id_gen.get_current_token() @cached() - async def get_account_data_for_user( + async def get_global_account_data_for_user( self, user_id: str - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: + ) -> Mapping[str, JsonDict]: """ - Get all the client account_data for a user. + Get all the global client account_data for a user. If experimental MSC3391 support is enabled, any entries with an empty content body are excluded; as this means they have been deleted. Args: user_id: The user to get the account_data for. + Returns: - A 2-tuple of a dict of global account_data and a dict mapping from - room_id string to per room account_data dicts. + The global account_data. """ - def get_account_data_for_user_txn( + def get_global_account_data_for_user( txn: LoggingTransaction, - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: + ) -> Dict[str, JsonDict]: # The 'content != '{}' condition below prevents us from using # `simple_select_list_txn` here, as it doesn't support conditions # other than 'equals'. @@ -158,10 +159,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn.execute(sql, (user_id,)) rows = self.db_pool.cursor_to_dict(txn) - global_account_data = { + return { row["account_data_type"]: db_to_json(row["content"]) for row in rows } + return await self.db_pool.runInteraction( + "get_global_account_data_for_user", get_global_account_data_for_user + ) + + @cached() + async def get_room_account_data_for_user( + self, user_id: str + ) -> Mapping[str, Mapping[str, JsonDict]]: + """ + Get all of the per-room client account_data for a user. + + If experimental MSC3391 support is enabled, any entries with an empty + content body are excluded; as this means they have been deleted. + + Args: + user_id: The user to get the account_data for. + + Returns: + A dict mapping from room_id string to per-room account_data dicts. + """ + + def get_room_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Dict[str, Dict[str, JsonDict]]: # The 'content != '{}' condition below prevents us from using # `simple_select_list_txn` here, as it doesn't support conditions # other than 'equals'. @@ -185,10 +210,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) room_data[row["account_data_type"]] = db_to_json(row["content"]) - return global_account_data, by_room + return by_room return await self.db_pool.runInteraction( - "get_account_data_for_user", get_account_data_for_user_txn + "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn ) @cached(num_args=2, max_entries=5000, tree=True) @@ -342,36 +367,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) "get_updated_room_account_data", get_updated_room_account_data_txn ) - async def get_updated_account_data_for_user( + async def get_updated_global_account_data_for_user( self, user_id: str, stream_id: int - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - """Get all the client account_data for a that's changed for a user + ) -> Dict[str, JsonDict]: + """Get all the global account_data that's changed for a user. Args: user_id: The user to get the account_data for. stream_id: The point in the stream since which to get updates + Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. + A dict of global account_data. """ - def get_updated_account_data_for_user_txn( + def get_updated_global_account_data_for_user( txn: LoggingTransaction, - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - sql = ( - "SELECT account_data_type, content FROM account_data" - " WHERE user_id = ? AND stream_id > ?" - ) - + ) -> Dict[str, JsonDict]: + sql = """ + SELECT account_data_type, content FROM account_data + WHERE user_id = ? AND stream_id > ? + """ txn.execute(sql, (user_id, stream_id)) - global_account_data = {row[0]: db_to_json(row[1]) for row in txn} + return {row[0]: db_to_json(row[1]) for row in txn} - sql = ( - "SELECT room_id, account_data_type, content FROM room_account_data" - " WHERE user_id = ? AND stream_id > ?" - ) + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(stream_id) + ) + if not changed: + return {} + + return await self.db_pool.runInteraction( + "get_updated_global_account_data_for_user", + get_updated_global_account_data_for_user, + ) + + async def get_updated_room_account_data_for_user( + self, user_id: str, stream_id: int + ) -> Dict[str, Dict[str, JsonDict]]: + """Get all the room account_data that's changed for a user. + Args: + user_id: The user to get the account_data for. + stream_id: The point in the stream since which to get updates + + Returns: + A dict mapping from room_id string to per room account_data dicts. + """ + + def get_updated_room_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Dict[str, Dict[str, JsonDict]]: + sql = """ + SELECT room_id, account_data_type, content FROM room_account_data + WHERE user_id = ? AND stream_id > ? + """ txn.execute(sql, (user_id, stream_id)) account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} @@ -379,16 +429,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) - return global_account_data, account_data_by_room + return account_data_by_room changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id) ) if not changed: - return {}, {} + return {} return await self.db_pool.runInteraction( - "get_updated_account_data_for_user", get_updated_account_data_for_user_txn + "get_updated_room_account_data_for_user", + get_updated_room_account_data_for_user_txn, ) @cached(max_entries=5000, iterable=True) @@ -444,7 +495,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self.get_global_account_data_by_type_for_user.invalidate( (row.user_id, row.data_type) ) - self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_global_account_data_for_user.invalidate((row.user_id,)) + self.get_room_account_data_for_user.invalidate((row.user_id,)) self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) self.get_account_data_for_room_and_type.invalidate( (row.user_id, row.room_id, row.data_type) @@ -492,7 +544,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) + self.get_room_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_room.invalidate((user_id, room_id)) self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type), content @@ -558,7 +610,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return None self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) + self.get_room_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_room.invalidate((user_id, room_id)) self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type), {} @@ -593,7 +645,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.invalidate( (user_id, account_data_type) ) @@ -761,7 +813,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return None self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.prefill( (user_id, account_data_type), {} ) @@ -822,7 +874,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn, self.get_account_data_for_room_and_type, (user_id,) ) self._invalidate_cache_and_stream( - txn, self.get_account_data_for_user, (user_id,) + txn, self.get_global_account_data_for_user, (user_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_room_account_data_for_user, (user_id,) ) self._invalidate_cache_and_stream( txn, self.get_global_account_data_by_type_for_user, (user_id,) -- cgit 1.5.1 From 14be78d492fc31e743e9e5855ddb8b4c9520985a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 10 Feb 2023 12:37:07 -0500 Subject: Support for MSC3758: exact_event_match push condition (#14964) This specifies to search for an exact value match, instead of string globbing. It only works across non-compound JSON values (null, boolean, integer, and strings). --- changelog.d/14964.feature | 1 + rust/benches/evaluator.rs | 65 +++++++++++--- rust/src/push/evaluator.rs | 69 +++++++++++---- rust/src/push/mod.rs | 83 +++++++++++++++++ stubs/synapse/synapse_rust/push.pyi | 7 +- synapse/config/experimental.py | 5 ++ synapse/push/bulk_push_rule_evaluator.py | 18 ++-- synapse/types/__init__.py | 2 + tests/push/test_push_rule_evaluator.py | 147 ++++++++++++++++++++++++++++++- 9 files changed, 356 insertions(+), 41 deletions(-) create mode 100644 changelog.d/14964.feature (limited to 'synapse') diff --git a/changelog.d/14964.feature b/changelog.d/14964.feature new file mode 100644 index 0000000000..13c0bc193b --- /dev/null +++ b/changelog.d/14964.feature @@ -0,0 +1 @@ +Implement the experimental `exact_event_match` push rule condition from [MSC3758](https://github.com/matrix-org/matrix-spec-proposals/pull/3758). diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 35f7a50bce..229553ebf8 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -16,6 +16,7 @@ use std::collections::BTreeSet; use synapse::push::{ evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, + SimpleJsonValue, }; use test::Bencher; @@ -24,9 +25,18 @@ extern crate test; #[bench] fn bench_match_exact(b: &mut Bencher) { let flattened_keys = [ - ("type".to_string(), "m.text".to_string()), - ("room_id".to_string(), "!room:server".to_string()), - ("content.body".to_string(), "test message".to_string()), + ( + "type".to_string(), + SimpleJsonValue::Str("m.text".to_string()), + ), + ( + "room_id".to_string(), + SimpleJsonValue::Str("!room:server".to_string()), + ), + ( + "content.body".to_string(), + SimpleJsonValue::Str("test message".to_string()), + ), ] .into_iter() .collect(); @@ -43,6 +53,7 @@ fn bench_match_exact(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -63,9 +74,18 @@ fn bench_match_exact(b: &mut Bencher) { #[bench] fn bench_match_word(b: &mut Bencher) { let flattened_keys = [ - ("type".to_string(), "m.text".to_string()), - ("room_id".to_string(), "!room:server".to_string()), - ("content.body".to_string(), "test message".to_string()), + ( + "type".to_string(), + SimpleJsonValue::Str("m.text".to_string()), + ), + ( + "room_id".to_string(), + SimpleJsonValue::Str("!room:server".to_string()), + ), + ( + "content.body".to_string(), + SimpleJsonValue::Str("test message".to_string()), + ), ] .into_iter() .collect(); @@ -82,6 +102,7 @@ fn bench_match_word(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -102,9 +123,18 @@ fn bench_match_word(b: &mut Bencher) { #[bench] fn bench_match_word_miss(b: &mut Bencher) { let flattened_keys = [ - ("type".to_string(), "m.text".to_string()), - ("room_id".to_string(), "!room:server".to_string()), - ("content.body".to_string(), "test message".to_string()), + ( + "type".to_string(), + SimpleJsonValue::Str("m.text".to_string()), + ), + ( + "room_id".to_string(), + SimpleJsonValue::Str("!room:server".to_string()), + ), + ( + "content.body".to_string(), + SimpleJsonValue::Str("test message".to_string()), + ), ] .into_iter() .collect(); @@ -121,6 +151,7 @@ fn bench_match_word_miss(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -141,9 +172,18 @@ fn bench_match_word_miss(b: &mut Bencher) { #[bench] fn bench_eval_message(b: &mut Bencher) { let flattened_keys = [ - ("type".to_string(), "m.text".to_string()), - ("room_id".to_string(), "!room:server".to_string()), - ("content.body".to_string(), "test message".to_string()), + ( + "type".to_string(), + SimpleJsonValue::Str("m.text".to_string()), + ), + ( + "room_id".to_string(), + SimpleJsonValue::Str("!room:server".to_string()), + ), + ( + "content.body".to_string(), + SimpleJsonValue::Str("test message".to_string()), + ), ] .into_iter() .collect(); @@ -160,6 +200,7 @@ fn bench_eval_message(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index ec7a8c4453..dd6b4343ec 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -22,8 +22,8 @@ use regex::Regex; use super::{ utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, - Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, - RelatedEventMatchCondition, + Action, Condition, EventMatchCondition, ExactEventMatchCondition, FilteredPushRules, + KnownCondition, RelatedEventMatchCondition, SimpleJsonValue, }; lazy_static! { @@ -61,9 +61,9 @@ impl RoomVersionFeatures { /// Allows running a set of push rules against a particular event. #[pyclass] pub struct PushRuleEvaluator { - /// A mapping of "flattened" keys to string values in the event, e.g. + /// A mapping of "flattened" keys to simple JSON values in the event, e.g. /// includes things like "type" and "content.msgtype". - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, /// The "content.body", if any. body: String, @@ -87,7 +87,7 @@ pub struct PushRuleEvaluator { /// The related events, indexed by relation type. Flattened in the same manner as /// `flattened_keys`. - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, /// If msc3664, push rules for related events, is enabled. related_event_match_enabled: bool, @@ -98,6 +98,9 @@ pub struct PushRuleEvaluator { /// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same /// flag as MSC1767 (extensible events core). msc3931_enabled: bool, + + /// If MSC3758 (exact_event_match push rule condition) is enabled. + msc3758_exact_event_match: bool, } #[pymethods] @@ -106,22 +109,23 @@ impl PushRuleEvaluator { #[allow(clippy::too_many_arguments)] #[new] pub fn py_new( - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, has_mentions: bool, user_mentions: BTreeSet, room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, related_event_match_enabled: bool, room_version_feature_flags: Vec, msc3931_enabled: bool, + msc3758_exact_event_match: bool, ) -> Result { - let body = flattened_keys - .get("content.body") - .cloned() - .unwrap_or_default(); + let body = match flattened_keys.get("content.body") { + Some(SimpleJsonValue::Str(s)) => s.clone(), + _ => String::new(), + }; Ok(PushRuleEvaluator { flattened_keys, @@ -136,6 +140,7 @@ impl PushRuleEvaluator { related_event_match_enabled, room_version_feature_flags, msc3931_enabled, + msc3758_exact_event_match, }) } @@ -252,6 +257,9 @@ impl PushRuleEvaluator { KnownCondition::EventMatch(event_match) => { self.match_event_match(event_match, user_id)? } + KnownCondition::ExactEventMatch(exact_event_match) => { + self.match_exact_event_match(exact_event_match)? + } KnownCondition::RelatedEventMatch(event_match) => { self.match_related_event_match(event_match, user_id)? } @@ -337,7 +345,9 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) { + let haystack = if let Some(SimpleJsonValue::Str(haystack)) = + self.flattened_keys.get(&*event_match.key) + { haystack } else { return Ok(false); @@ -355,6 +365,27 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `exact_event_match` condition. (MSC3758) + fn match_exact_event_match( + &self, + exact_event_match: &ExactEventMatchCondition, + ) -> Result { + // First check if the feature is enabled. + if !self.msc3758_exact_event_match { + return Ok(false); + } + + let value = &exact_event_match.value; + + let haystack = if let Some(haystack) = self.flattened_keys.get(&*exact_event_match.key) { + haystack + } else { + return Ok(false); + }; + + Ok(haystack == &**value) + } + /// Evaluates a `related_event_match` condition. (MSC3664) fn match_related_event_match( &self, @@ -410,7 +441,7 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(haystack) = event.get(&**key) { + let haystack = if let Some(SimpleJsonValue::Str(haystack)) = event.get(&**key) { haystack } else { return Ok(false); @@ -455,7 +486,10 @@ impl PushRuleEvaluator { #[test] fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); - flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + flattened_keys.insert( + "content.body".to_string(), + SimpleJsonValue::Str("foo bar bob hello".to_string()), + ); let evaluator = PushRuleEvaluator::py_new( flattened_keys, false, @@ -468,6 +502,7 @@ fn push_rule_evaluator() { true, vec![], true, + true, ) .unwrap(); @@ -482,7 +517,10 @@ fn test_requires_room_version_supports_condition() { use crate::push::{PushRule, PushRules}; let mut flattened_keys = BTreeMap::new(); - flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + flattened_keys.insert( + "content.body".to_string(), + SimpleJsonValue::Str("foo bar bob hello".to_string()), + ); let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( flattened_keys, @@ -496,6 +534,7 @@ fn test_requires_room_version_supports_condition() { false, flags, true, + true, ) .unwrap(); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 3c4f876cab..79e519fe11 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -56,7 +56,9 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use anyhow::{Context, Error}; use log::warn; +use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; +use pyo3::types::{PyBool, PyLong, PyString}; use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; @@ -248,6 +250,36 @@ impl<'de> Deserialize<'de> for Action { } } +/// A simple JSON values (string, int, boolean, or null). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum SimpleJsonValue { + Str(String), + Int(i64), + Bool(bool), + Null, +} + +impl<'source> FromPyObject<'source> for SimpleJsonValue { + fn extract(ob: &'source PyAny) -> PyResult { + if let Ok(s) = ::try_from(ob) { + Ok(SimpleJsonValue::Str(s.to_string())) + // A bool *is* an int, ensure we try bool first. + } else if let Ok(b) = ::try_from(ob) { + Ok(SimpleJsonValue::Bool(b.extract()?)) + } else if let Ok(i) = ::try_from(ob) { + Ok(SimpleJsonValue::Int(i.extract()?)) + } else if ob.is_none() { + Ok(SimpleJsonValue::Null) + } else { + Err(PyTypeError::new_err(format!( + "Can't convert from {} to SimpleJsonValue", + ob.get_type().name()? + ))) + } + } +} + /// A condition used in push rules to match against an event. /// /// We need this split as `serde` doesn't give us the ability to have a @@ -267,6 +299,8 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), + #[serde(rename = "com.beeper.msc3758.exact_event_match")] + ExactEventMatch(ExactEventMatchCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), #[serde(rename = "org.matrix.msc3952.is_user_mention")] @@ -309,6 +343,13 @@ pub struct EventMatchCondition { pub pattern_type: Option>, } +/// The body of a [`Condition::ExactEventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ExactEventMatchCondition { + pub key: Cow<'static, str>, + pub value: Cow<'static, SimpleJsonValue>, +} + /// The body of a [`Condition::RelatedEventMatch`] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct RelatedEventMatchCondition { @@ -542,6 +583,48 @@ fn test_deserialize_unstable_msc3931_condition() { )); } +#[test] +fn test_deserialize_unstable_msc3758_condition() { + // A string condition should work. + let json = + r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":"foo"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::ExactEventMatch(_)) + )); + + // A boolean condition should work. + let json = + r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":true}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::ExactEventMatch(_)) + )); + + // An integer condition should work. + let json = r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":1}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::ExactEventMatch(_)) + )); + + // A null condition should work + let json = + r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":null}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::ExactEventMatch(_)) + )); +} + #[test] fn test_deserialize_unstable_msc3952_user_condition() { let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 754acab2f9..328f681a29 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -14,7 +14,7 @@ from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union -from synapse.types import JsonDict +from synapse.types import JsonDict, SimpleJsonValue class PushRule: @property @@ -56,17 +56,18 @@ def get_base_rule_ids() -> Collection[str]: ... class PushRuleEvaluator: def __init__( self, - flattened_keys: Mapping[str, str], + flattened_keys: Mapping[str, SimpleJsonValue], has_mentions: bool, user_mentions: Set[str], room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - related_events_flattened: Mapping[str, Mapping[str, str]], + related_events_flattened: Mapping[str, Mapping[str, SimpleJsonValue]], related_event_match_enabled: bool, room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, + msc3758_exact_event_match: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 5e3a889081..6ac2f0c10d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -169,6 +169,11 @@ class ExperimentalConfig(Config): # MSC3925: do not replace events with their edits self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) + # MSC3758: exact_event_match push rule condition + self.msc3758_exact_event_match = experimental.get( + "msc3758_exact_event_match", False + ) + # MSC3873: Disambiguate event_match keys. self.msc3783_escape_event_match_key = experimental.get( "msc3783_escape_event_match_key", False diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 39d2f88f03..8568aca528 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -43,6 +43,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator +from synapse.types import SimpleJsonValue from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func @@ -256,13 +257,15 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + async def _related_events( + self, event: EventBase + ) -> Dict[str, Dict[str, SimpleJsonValue]]: """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation Returns: Mapping of relation type to flattened events. """ - related_events: Dict[str, Dict[str, str]] = {} + related_events: Dict[str, Dict[str, SimpleJsonValue]] = {} if self._related_event_match_enabled: related_event_id = event.content.get("m.relates_to", {}).get("event_id") relation_type = event.content.get("m.relates_to", {}).get("rel_type") @@ -425,6 +428,7 @@ class BulkPushRuleEvaluator: self._related_event_match_enabled, event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag + self.hs.config.experimental.msc3758_exact_event_match, ) users = rules_by_user.keys() @@ -501,15 +505,15 @@ StateGroup = Union[object, int] def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, - result: Optional[Dict[str, str]] = None, + result: Optional[Dict[str, SimpleJsonValue]] = None, *, msc3783_escape_event_match_key: bool = False, -) -> Dict[str, str]: +) -> Dict[str, SimpleJsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, flatten it into a single layer dictionary by combining the keys & sub-keys. - Any (non-dictionary), non-string value is dropped. + String, integer, boolean, and null values are kept. All others are dropped. Transforms: @@ -538,8 +542,8 @@ def _flatten_dict( # nested fields. key = key.replace("\\", "\\\\").replace(".", "\\.") - if isinstance(value, str): - result[".".join(prefix + [key])] = value.lower() + if isinstance(value, (bool, str)) or type(value) is int or value is None: + result[".".join(prefix + [key])] = value elif isinstance(value, Mapping): # do not set `room_version` due to recursion considerations below _flatten_dict( diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index f82d1cfc29..52e366c8ae 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -69,6 +69,8 @@ StateMap = Mapping[StateKey, T] MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. +# A "simple" (canonical) JSON value. +SimpleJsonValue = Optional[Union[str, int, bool]] # A JSON-serialisable dict. JsonDict = Dict[str, Any] # A JSON-serialisable mapping; roughly speaking an immutable JSONDict. diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 516b65cc3c..6603447341 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -57,7 +57,7 @@ class FlattenDictTestCase(unittest.TestCase): ) def test_non_string(self) -> None: - """Non-string items are dropped.""" + """Booleans, ints, and nulls should be kept while other items are dropped.""" input: Dict[str, Any] = { "woo": "woo", "foo": True, @@ -66,7 +66,9 @@ class FlattenDictTestCase(unittest.TestCase): "fuzz": [], "boo": {}, } - self.assertEqual({"woo": "woo"}, _flatten_dict(input)) + self.assertEqual( + {"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input) + ) def test_event(self) -> None: """Events can also be flattened.""" @@ -86,9 +88,9 @@ class FlattenDictTestCase(unittest.TestCase): ) expected = { "content.msgtype": "m.text", - "content.body": "hello world!", + "content.body": "Hello world!", "content.format": "org.matrix.custom.html", - "content.formatted_body": "

hello world!

", + "content.formatted_body": "

Hello world!

", "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", @@ -166,6 +168,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): related_event_match_enabled=True, room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, + msc3758_exact_event_match=True, ) def test_display_name(self) -> None: @@ -410,6 +413,142 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should not match before a newline", ) + def test_exact_event_match_string(self) -> None: + """Check that exact_event_match conditions work as expected for strings.""" + + # Test against a string value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": "foobaz", + } + self._assert_matches( + condition, + {"value": "foobaz"}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": "FoobaZ"}, + "values should match and be case-sensitive", + ) + self._assert_not_matches( + condition, + {"value": "test foobaz test"}, + "values must exactly match", + ) + value: Any + for value in (True, False, 1, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + # it should work on frozendicts too + self._assert_matches( + condition, + frozendict.frozendict({"value": "foobaz"}), + "values should match on frozendicts", + ) + + def test_exact_event_match_boolean(self) -> None: + """Check that exact_event_match conditions work as expected for booleans.""" + + # Test against a True boolean value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": True, + } + self._assert_matches( + condition, + {"value": True}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": False}, + "incorrect values should not match", + ) + for value in ("foobaz", 1, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + # Test against a False boolean value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": False, + } + self._assert_matches( + condition, + {"value": False}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": True}, + "incorrect values should not match", + ) + # Choose false-y values to ensure there's no type coercion. + for value in ("", 0, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + def test_exact_event_match_null(self) -> None: + """Check that exact_event_match conditions work as expected for null.""" + + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": None, + } + self._assert_matches( + condition, + {"value": None}, + "exact value should match", + ) + for value in ("foobaz", True, False, 1, 1.1, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + def test_exact_event_match_integer(self) -> None: + """Check that exact_event_match conditions work as expected for integers.""" + + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": 1, + } + self._assert_matches( + condition, + {"value": 1}, + "exact value should match", + ) + value: Any + for value in (1.1, -1, 0): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect values should not match", + ) + for value in ("1", True, False, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) -- cgit 1.5.1 From d0c713cc85f094c323b2ba3f02d8ac411a7f0705 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 10 Feb 2023 23:29:00 +0000 Subject: Return read-only collections from `@cached` methods (#13755) It's important that collections returned from `@cached` methods are not modified, otherwise future retrievals from the cache will return the modified collection. This applies to the return values from `@cached` methods and the values inside the dictionaries returned by `@cachedList` methods. It's not necessary for the dictionaries returned by `@cachedList` methods themselves to be read-only. Signed-off-by: Sean Quah Co-authored-by: David Robertson --- changelog.d/13755.misc | 1 + synapse/app/phone_stats_home.py | 4 ++-- synapse/config/room_directory.py | 6 +++--- synapse/events/builder.py | 6 +++--- synapse/federation/federation_server.py | 3 ++- synapse/handlers/directory.py | 6 +++--- synapse/handlers/receipts.py | 4 ++-- synapse/handlers/room.py | 2 +- synapse/handlers/sync.py | 4 ++-- synapse/push/bulk_push_rule_evaluator.py | 4 ++-- synapse/state/__init__.py | 2 +- synapse/storage/controllers/state.py | 6 +++--- synapse/storage/databases/main/account_data.py | 2 +- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/devices.py | 17 +++++++++------ synapse/storage/databases/main/directory.py | 4 ++-- synapse/storage/databases/main/end_to_end_keys.py | 25 +++++++++++++--------- synapse/storage/databases/main/event_federation.py | 11 ++++++---- .../storage/databases/main/monthly_active_users.py | 4 ++-- synapse/storage/databases/main/receipts.py | 10 +++++---- synapse/storage/databases/main/registration.py | 4 ++-- synapse/storage/databases/main/relations.py | 7 ++++-- synapse/storage/databases/main/roommember.py | 19 ++++++++-------- synapse/storage/databases/main/signatures.py | 6 +++--- synapse/storage/databases/main/tags.py | 8 ++++--- synapse/storage/databases/main/user_directory.py | 4 ++-- tests/rest/admin/test_server_notice.py | 4 ++-- 27 files changed, 98 insertions(+), 77 deletions(-) create mode 100644 changelog.d/13755.misc (limited to 'synapse') diff --git a/changelog.d/13755.misc b/changelog.d/13755.misc new file mode 100644 index 0000000000..662ee00e99 --- /dev/null +++ b/changelog.d/13755.misc @@ -0,0 +1 @@ +Re-type hint some collections as read-only. diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 53db1e85b3..897dd3edac 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,7 +15,7 @@ import logging import math import resource import sys -from typing import TYPE_CHECKING, List, Sized, Tuple +from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple from prometheus_client import Gauge @@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None: @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 - current_mau_count_by_service = {} + current_mau_count_by_service: Mapping[str, int] = {} reserved_users: Sized = () store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 3ed236217f..8666c22f01 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, Collection from matrix_common.regex import glob_to_regex @@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config): return False def is_publishing_room_allowed( - self, user_id: str, room_id: str, aliases: List[str] + self, user_id: str, room_id: str, aliases: Collection[str] ) -> bool: """Checks if the given user is allowed to publish the room @@ -122,7 +122,7 @@ class _RoomDirectoryRule: except Exception as e: raise ConfigError("Failed to parse glob into regex") from e - def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: + def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool: """Tests if this rule matches the given user_id, room_id and aliases. Args: diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 94dd1298e1..c82745275f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union import attr from signedjson.types import SigningKey @@ -103,7 +103,7 @@ class EventBuilder: async def build( self, - prev_event_ids: List[str], + prev_event_ids: Collection[str], auth_event_ids: Optional[List[str]], depth: Optional[int] = None, ) -> EventBase: @@ -136,7 +136,7 @@ class EventBuilder: format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. - prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.ROOM_V1_V2: auth_events = await self._store.add_event_hashes(auth_event_ids) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8d36172484..6addc0bb65 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -23,6 +23,7 @@ from typing import ( Collection, Dict, List, + Mapping, Optional, Tuple, Union, @@ -1512,7 +1513,7 @@ class FederationHandlerRegistry: def _get_event_ids_for_partial_state_join( join_event: EventBase, prev_state_ids: StateMap[str], - summary: Dict[str, MemberSummary], + summary: Mapping[str, MemberSummary], ) -> Collection[str]: """Calculate state to be returned in a partial_state send_join diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index d31b0fbb17..a5798e9483 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -14,7 +14,7 @@ import logging import string -from typing import TYPE_CHECKING, Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence from typing_extensions import Literal @@ -486,7 +486,7 @@ class DirectoryHandler: ) if canonical_alias: # Ensure we do not mutate room_aliases. - room_aliases = room_aliases + [canonical_alias] + room_aliases = list(room_aliases) + [canonical_alias] if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases @@ -529,7 +529,7 @@ class DirectoryHandler: async def get_aliases_for_room( self, requester: Requester, room_id: str - ) -> List[str]: + ) -> Sequence[str]: """ Get a list of the aliases that currently point to this room on this server """ diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 04c61ae3dd..2bacdebfb5 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple from synapse.api.constants import EduTypes, ReceiptTypes from synapse.appservice import ApplicationService @@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): @staticmethod def filter_out_private_receipts( - rooms: List[JsonDict], user_id: str + rooms: Sequence[JsonDict], user_id: str ) -> List[JsonDict]: """ Filters a list of serialized receipts (as returned by /sync and /initialSync) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0e759b8a5d..060bbcb181 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1928,6 +1928,6 @@ class RoomShutdownHandler: return { "kicked_users": kicked_users, "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, + "local_aliases": list(aliases_for_room), "new_room_id": new_room_id, } diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 399685e5b7..4bae46158a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1519,7 +1519,7 @@ class SyncHandler: one_time_keys_count = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_key_types = ( + unused_fallback_key_types = list( await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) @@ -2301,7 +2301,7 @@ class SyncHandler: sync_result_builder: "SyncResultBuilder", room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], - tags: Optional[Dict[str, Dict[str, Any]]], + tags: Optional[Mapping[str, Mapping[str, Any]]], account_data: Mapping[str, JsonDict], always_include: bool = False, ) -> None: diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8568aca528..f6a5bffb0f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -22,6 +22,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -149,7 +150,7 @@ class BulkPushRuleEvaluator: # little, we can skip fetching a huge number of push rules in large rooms. # This helps make joins and leaves faster. if event.type == EventTypes.Member: - local_users = [] + local_users: Sequence[str] = [] # We never notify a user about their own actions. This is enforced in # `_action_for_event_by_user` in the loop over `rules_by_user`, but we # do the same check here to avoid unnecessary DB queries. @@ -184,7 +185,6 @@ class BulkPushRuleEvaluator: if event.type == EventTypes.Member and event.membership == Membership.INVITE: invited = event.state_key if invited and self.hs.is_mine_id(invited) and invited not in local_users: - local_users = list(local_users) local_users.append(invited) if not local_users: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index e877e6f1a1..4dc25df67e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -226,7 +226,7 @@ class StateHandler: return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: List[str] + self, room_id: str, latest_event_ids: Collection[str] ) -> Set[str]: """ Get the users IDs who are currently in a room. diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 52efd4a171..9d7a8a792f 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -14,6 +14,7 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Callable, @@ -23,7 +24,6 @@ from typing import ( List, Mapping, Optional, - Set, Tuple, ) @@ -527,7 +527,7 @@ class StateStorageController: ) return state_map.get(key) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: """Get current hosts in room based on current state. Blocks until we have full state for the given room. This only happens for rooms @@ -584,7 +584,7 @@ class StateStorageController: async def get_users_in_room_with_profiles( self, room_id: str - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """ Get the current users in the room with their profiles. If the room is currently partial-stated, this will block until the room has diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 2d6f02c14f..95567826f2 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -240,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonDict]: """Get all the client account_data for a user for a room. Args: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 5fb152c4ff..484db175d0 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): room_id: str, app_service: "ApplicationService", cache_context: _CacheContext, - ) -> List[str]: + ) -> Sequence[str]: """ Get all users in a room that the appservice controls. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 85c1778a81..1ca66d57d4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -21,6 +21,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -202,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: + async def count_devices_by_users( + self, user_ids: Optional[Collection[str]] = None + ) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. @@ -213,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): """ def count_devices_by_users_txn( - txn: LoggingTransaction, user_ids: List[str] + txn: LoggingTransaction, user_ids: Collection[str] ) -> int: sql = """ SELECT count(*) @@ -747,7 +750,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @cancellable async def get_user_devices_from_cache( self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] - ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: + ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: @@ -775,16 +778,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_ids_not_in_cache = unique_user_ids - user_ids_in_cache # First fetch all the users which all devices are to be returned. - results: Dict[str, Dict[str, JsonDict]] = {} + results: Dict[str, Mapping[str, JsonDict]] = {} for user_id in user_ids: if user_id in user_ids_in_cache: results[user_id] = await self.get_cached_devices_for_user(user_id) # Then fetch all device-specific requests, but skip users we've already # fetched all devices for. + device_specific_results: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_id in user_and_device_ids: if user_id in user_ids_in_cache and user_id not in user_ids: device = await self._get_cached_user_device(user_id, device_id) - results.setdefault(user_id, {})[device_id] = device + device_specific_results.setdefault(user_id, {})[device_id] = device + results.update(device_specific_results) set_tag("in_cache", str(results)) set_tag("not_in_cache", str(user_ids_not_in_cache)) @@ -802,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return db_to_json(content) @cached() - async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 5903fdaf00..44aa181174 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Sequence, Tuple import attr @@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): ) @cached(max_entries=5000) - async def get_aliases_for_room(self, room_id: str) -> List[str]: + async def get_aliases_for_room(self, room_id: str) -> Sequence[str]: return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c4ac6c33ba..752dc16e17 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -20,7 +20,9 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, + Sequence, Tuple, Union, cast, @@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cached(max_entries=10000) async def get_e2e_unused_fallback_key_types( self, user_id: str, device_id: str - ) -> List[str]: + ) -> Sequence[str]: """Returns the fallback key types that have an unused key. Args: @@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: + def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[Dict[str, JsonDict]]]: + ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) # The `Optional` comes from the `@cachedList` decorator. - return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) + return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) def _get_bare_e2e_cross_signing_keys_bulk_txn( self, @@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Dict[str, JsonDict]]]: + ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = await self.db_pool.runInteraction( - "get_e2e_cross_signing_signatures", - self._get_e2e_cross_signing_signatures_txn, - result, - from_user_id, + result = cast( + Dict[str, Optional[Mapping[str, JsonDict]]], + await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ), ) return result diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index bbee02ab18..ca780cca36 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -22,6 +22,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, cast, @@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) - async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: + async def get_max_depth_of( + self, event_ids: Collection[str] + ) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs Args: @@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, @@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @cancellable async def get_forward_extremities_for_room_at_stream_ordering( self, room_id: str, stream_ordering: int - ) -> List[str]: + ) -> Sequence[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @cached(max_entries=5000, num_args=2) async def _get_forward_extremeties_for_room( self, room_id: str, stream_ordering: int - ) -> List[str]: + ) -> Sequence[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index db9a24db5e..4b1061e6d7 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): return await self.db_pool.runInteraction("count_users", _count_users) @cached(num_args=0) - async def get_monthly_active_count_by_service(self) -> Dict[str, int]: + async def get_monthly_active_count_by_service(self) -> Mapping[str, int]: """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 29972d5204..dddf49c2d5 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -21,7 +21,9 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, + Sequence, Tuple, cast, ) @@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> Sequence[JsonDict]: """Get receipts for a single room for sending to clients. Args: @@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[JsonDict]: + ) -> Sequence[JsonDict]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: @@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None - ) -> Dict[str, List[JsonDict]]: + ) -> Dict[str, Sequence[JsonDict]]: if not room_ids: return {} @@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonDict]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 31f0f2bd3d..9a55e17624 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast import attr @@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) @cached() - async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: """Deprecated: use get_userinfo_by_id instead""" def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0018d6f7ab..fa3266c081 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -22,6 +22,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: + ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore): return result is not None @cached() - async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + async def get_aggregation_groups_for_event( + self, event_id: str + ) -> Sequence[JsonDict]: raise NotImplementedError() @cachedList( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ea6a5e2f34..694a5b802c 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -24,6 +24,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self._known_servers_count @cached(max_entries=100000, iterable=True) - async def get_users_in_room(self, room_id: str) -> List[str]: + async def get_users_in_room(self, room_id: str) -> Sequence[str]: """Returns a list of users in the room. Will return inaccurate results for rooms with partial state, since the state for @@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached() - def get_user_in_room_with_profile( - self, room_id: str, user_id: str - ) -> Dict[str, ProfileInfo]: + def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo: raise NotImplementedError() @cachedList( @@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) async def get_users_in_room_with_profiles( self, room_id: str - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """Get a mapping from user ID to profile information for all users in a given room. The profile information comes directly from this room's `m.room.member` @@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached(max_entries=100000) - async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: + async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]: """Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: @@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached() async def get_invited_rooms_for_local_user( self, user_id: str - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Get all the rooms the *local* user is invited to. Args: @@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results @cached(iterable=True) - async def get_local_users_in_room(self, room_id: str) -> List[str]: + async def get_local_users_in_room(self, room_id: str) -> Sequence[str]: """ Retrieves a list of the current roommembers who are local to the server. """ @@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user(user_id) - user_who_share_room = set() + user_who_share_room: Set[str] = set() for room_id in room_ids: user_ids = await self.get_users_in_room(room_id) user_who_share_room.update(user_ids) @@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: """Get current hosts in room based on current state.""" # First we check if we already have `get_users_in_room` in the cache, as diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index 05da15074a..5dcb1fc0b5 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Dict, List, Tuple +from typing import Collection, Dict, List, Mapping, Tuple from unpaddedbase64 import encode_base64 @@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList class SignatureWorkerStore(EventsWorkerStore): @cached() - def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]: + def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]: # This is a dummy function to allow get_event_reference_hashes # to use its cache raise NotImplementedError() @@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore): ) async def get_event_reference_hashes( self, event_ids: Collection[str] - ) -> Dict[str, Dict[str, bytes]]: + ) -> Mapping[str, Mapping[str, bytes]]: """Get all hashes for given events. Args: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index d5500cdd47..c149a9eacb 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Iterable, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.tcp.streams import AccountDataStream @@ -32,7 +32,9 @@ logger = logging.getLogger(__name__) class TagsWorkerStore(AccountDataWorkerStore): @cached() - async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: + async def get_tags_for_user( + self, user_id: str + ) -> Mapping[str, Mapping[str, JsonDict]]: """Get all the tags for a user. @@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_updated_tags( self, user_id: str, stream_id: int - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonDict]]: """Get all the tags for the rooms where the tags have changed since the given version diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 14ef5b040d..f6a6fd4079 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -16,9 +16,9 @@ import logging import re from typing import ( TYPE_CHECKING, - Dict, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]: + async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index a2f347f666..f71ff46d87 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Sequence from twisted.test.proto_helpers import MemoryReactor @@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Check invite and room membership status of a user. Args -- cgit 1.5.1 From 6cddf24e361fe43f086307c833cd814dc03363b6 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Sat, 11 Feb 2023 00:31:05 +0100 Subject: Faster joins: don't stall when a user joins during a fast join (#14606) Fixes #12801. Complement tests are at https://github.com/matrix-org/complement/pull/567. Avoid blocking on full state when handling a subsequent join into a partial state room. Also always perform a remote join into partial state rooms, since we do not know whether the joining user has been banned and want to avoid leaking history to banned users. Signed-off-by: Mathieu Velten Co-authored-by: Sean Quah Co-authored-by: David Robertson --- changelog.d/14606.misc | 1 + synapse/api/errors.py | 22 ++++++ synapse/federation/federation_server.py | 2 +- synapse/handlers/event_auth.py | 16 ++--- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 59 ++++++++++++++-- synapse/handlers/message.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 118 ++++++++++++++++++++++--------- synapse/handlers/room_member_worker.py | 5 +- synapse/storage/databases/main/events.py | 21 +----- tests/handlers/test_federation.py | 40 +++++------ 12 files changed, 196 insertions(+), 94 deletions(-) create mode 100644 changelog.d/14606.misc (limited to 'synapse') diff --git a/changelog.d/14606.misc b/changelog.d/14606.misc new file mode 100644 index 0000000000..e2debc96d8 --- /dev/null +++ b/changelog.d/14606.misc @@ -0,0 +1 @@ +Faster joins: don't stall when another user joins during a fast join resync. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c2c177fd71..9235ce6536 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -751,3 +751,25 @@ class ModuleFailedException(Exception): Raised when a module API callback fails, for example because it raised an exception. """ + + +class PartialStateConflictError(SynapseError): + """An internal error raised when attempting to persist an event with partial state + after the room containing the event has been un-partial stated. + + This error should be handled by recomputing the event context and trying again. + + This error has an HTTP status code so that it can be transported over replication. + It should not be exposed to clients. + """ + + @staticmethod + def message() -> str: + return "Cannot persist partial state event in un-partial stated room" + + def __init__(self) -> None: + super().__init__( + HTTPStatus.CONFLICT, + msg=PartialStateConflictError.message(), + errcode=Codes.UNKNOWN, + ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6addc0bb65..6d99845de5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -48,6 +48,7 @@ from synapse.api.errors import ( FederationError, IncompatibleRoomVersionError, NotFoundError, + PartialStateConflictError, SynapseError, UnsupportedRoomVersionError, ) @@ -81,7 +82,6 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index a23a8ce2a1..46dd63c3f0 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -202,7 +202,7 @@ class EventAuthHandler: state_ids: StateMap[str], room_version: RoomVersion, user_id: str, - prev_member_event: Optional[EventBase], + prev_membership: Optional[str], ) -> None: """ Check whether a user can join a room without an invite due to restricted join rules. @@ -214,15 +214,14 @@ class EventAuthHandler: state_ids: The state of the room as it currently is. room_version: The room version of the room being joined. user_id: The user joining the room. - prev_member_event: The current membership event for this user. + prev_membership: The current membership state for this user. `None` if the + user has never joined the room (equivalent to "leave"). Raises: AuthError if the user cannot join the room. """ # If the member is invited or currently joined, then nothing to do. - if prev_member_event and ( - prev_member_event.membership in (Membership.JOIN, Membership.INVITE) - ): + if prev_membership in (Membership.JOIN, Membership.INVITE): return # This is not a room with a restricted join rule, so we don't need to do the @@ -255,13 +254,14 @@ class EventAuthHandler: ) async def has_restricted_join_rules( - self, state_ids: StateMap[str], room_version: RoomVersion + self, partial_state_ids: StateMap[str], room_version: RoomVersion ) -> bool: """ Return if the room has the proper join rules set for access via rooms. Args: - state_ids: The state of the room as it currently is. + state_ids: The state of the room as it currently is. May be full or partial + state. room_version: The room version of the room to query. Returns: @@ -272,7 +272,7 @@ class EventAuthHandler: return False # If there's no join rule, then it defaults to invite (so this doesn't apply). - join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) + join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None) if not join_rules_event_id: return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 43ed4a3dd1..08727e4857 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,6 +49,7 @@ from synapse.api.errors import ( FederationPullAttemptBackoffError, HttpResponseException, NotFoundError, + PartialStateConflictError, RequestSendFailed, SynapseError, ) @@ -68,7 +69,6 @@ from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, ) -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3561f2f1de..b7136f8d1c 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -47,6 +47,7 @@ from synapse.api.errors import ( FederationError, FederationPullAttemptBackoffError, HttpResponseException, + PartialStateConflictError, RequestSendFailed, SynapseError, ) @@ -74,7 +75,6 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ) from synapse.state import StateResolutionStore -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( PersistedEventPosition, @@ -441,16 +441,17 @@ class FederationEventHandler: # Check if the user is already in the room or invited to the room. user_id = event.state_key prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - prev_member_event = None + prev_membership = None if prev_member_event_id: prev_member_event = await self._store.get_event(prev_member_event_id) + prev_membership = prev_member_event.membership # Check if the member should be allowed access via membership in a space. await self._event_auth_handler.check_restricted_join_rules( prev_state_ids, event.room_version, user_id, - prev_member_event, + prev_membership, ) @trace @@ -526,11 +527,57 @@ class FederationEventHandler: "Peristing join-via-remote %s (partial_state: %s)", event, partial_state ) with nested_logging_context(suffix=event.event_id): + if partial_state: + # When handling a second partial state join into a partial state room, + # the returned state will exclude the membership from the first join. To + # preserve prior memberships, we try to compute the partial state before + # the event ourselves if we know about any of the prev events. + # + # When we don't know about any of the prev events, it's fine to just use + # the returned state, since the new join will create a new forward + # extremity, and leave the forward extremity containing our prior + # memberships alone. + prev_event_ids = set(event.prev_event_ids()) + seen_event_ids = await self._store.have_events_in_timeline( + prev_event_ids + ) + missing_event_ids = prev_event_ids - seen_event_ids + + state_maps_to_resolve: List[StateMap[str]] = [] + + # Fetch the state after the prev events that we know about. + state_maps_to_resolve.extend( + ( + await self._state_storage_controller.get_state_groups_ids( + room_id, seen_event_ids, await_full_state=False + ) + ).values() + ) + + # When there are prev events we do not have the state for, we state + # resolve with the state returned by the remote homeserver. + if missing_event_ids or len(state_maps_to_resolve) == 0: + state_maps_to_resolve.append( + {(e.type, e.state_key): e.event_id for e in state} + ) + + state_ids_before_event = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version.identifier, + state_maps_to_resolve, + event_map=None, + state_res_store=StateResolutionStore(self._store), + ) + ) + else: + state_ids_before_event = { + (e.type, e.state_key): e.event_id for e in state + } + context = await self._state_handler.compute_event_context( event, - state_ids_before_event={ - (e.type, e.state_key): e.event_id for e in state - }, + state_ids_before_event=state_ids_before_event, partial_state=partial_state, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3e30f52e4d..8f5b658d9d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( Codes, ConsentNotGivenError, NotFoundError, + PartialStateConflictError, ShadowBanError, SynapseError, UnstableSpecAuthError, @@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( MutableStateMap, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 060bbcb181..837dabb3b7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -43,6 +43,7 @@ from synapse.api.errors import ( Codes, LimitExceededError, NotFoundError, + PartialStateConflictError, StoreError, SynapseError, ) @@ -54,7 +55,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.streams import EventSource from synapse.types import ( JsonDict, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6e7141d2ef..a965c7ec76 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -26,7 +26,13 @@ from synapse.api.constants import ( GuestAccess, Membership, ) -from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + PartialStateConflictError, + ShadowBanError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase @@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.logging import opentracing from synapse.module_api import NOT_SPAM -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.types import ( JsonDict, Requester, @@ -56,6 +61,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class NoKnownServersError(SynapseError): + """No server already resident to the room was provided to the join/knock operation.""" + + def __init__(self, msg: str = "No known servers"): + super().__init__(404, msg) + + class RoomMemberHandler(metaclass=abc.ABCMeta): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level @@ -185,6 +197,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: Room that we are trying to join user: User who is trying to join content: A dict that should be used as the content of the join event. + + Raises: + NoKnownServersError: if remote_room_hosts does not contain a server joined to + the room. """ raise NotImplementedError() @@ -823,14 +839,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): latest_event_ids = await self.store.get_prev_events_for_room(room_id) - state_before_join = await self.state_handler.compute_state_after_events( - room_id, latest_event_ids + is_partial_state_room = await self.store.is_partial_state_room(room_id) + partial_state_before_join = await self.state_handler.compute_state_after_events( + room_id, latest_event_ids, await_full_state=False ) + # `is_partial_state_room` also indicates whether `partial_state_before_join` is + # partial. # TODO: Refactor into dictionary of explicitly allowed transitions # between old and new state, with specific error messages for some # transitions and generic otherwise - old_state_id = state_before_join.get((EventTypes.Member, target.to_string())) + old_state_id = partial_state_before_join.get( + (EventTypes.Member, target.to_string()) + ) if old_state_id: old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None @@ -881,11 +902,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = await self._is_host_in_room(state_before_join) + is_host_in_room = await self._is_host_in_room(partial_state_before_join) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = await self._can_guest_join(state_before_join) + guest_can_join = await self._can_guest_join(partial_state_before_join) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -927,8 +948,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id, remote_room_hosts, content, + is_partial_state_room, is_host_in_room, - state_before_join, + partial_state_before_join, ) if remote_join: if ratelimit: @@ -1073,8 +1095,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: str, remote_room_hosts: List[str], content: JsonDict, + is_partial_state_room: bool, is_host_in_room: bool, - state_before_join: StateMap[str], + partial_state_before_join: StateMap[str], ) -> Tuple[bool, List[str]]: """ Check whether the server should do a remote join (as opposed to a local @@ -1093,9 +1116,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): remote_room_hosts: A list of remote room hosts. content: The content to use as the event body of the join. This may be modified. - is_host_in_room: True if the host is in the room. - state_before_join: The state before the join event (i.e. the resolution of - the states after its parent events). + is_partial_state_room: `True` if the server currently doesn't hold the full + state of the room. + is_host_in_room: `True` if the host is in the room. + partial_state_before_join: The state before the join event (i.e. the + resolution of the states after its parent events). May be full or + partial state, depending on `is_partial_state_room`. Returns: A tuple of: @@ -1109,6 +1135,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if not is_host_in_room: return True, remote_room_hosts + prev_member_event_id = partial_state_before_join.get( + (EventTypes.Member, user_id), None + ) + previous_membership = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + previous_membership = prev_member_event.membership + + # If we are not fully joined yet, and the target is not already in the room, + # let's do a remote join so another server with the full state can validate + # that the user has not been banned for example. + # We could just accept the join and wait for state res to resolve that later on + # but we would then leak room history to this person until then, which is pretty + # bad. + if is_partial_state_room and previous_membership != Membership.JOIN: + return True, remote_room_hosts + # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. room_version = await self.store.get_room_version(room_id) @@ -1116,21 +1159,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If restricted join rules are not being used, a local join can always # be used. if not await self.event_auth_handler.has_restricted_join_rules( - state_before_join, room_version + partial_state_before_join, room_version ): return False, [] # If the user is invited to the room or already joined, the join # event can always be issued locally. - prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None) - prev_member_event = None - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - if prev_member_event.membership in ( - Membership.JOIN, - Membership.INVITE, - ): - return False, [] + if previous_membership in (Membership.JOIN, Membership.INVITE): + return False, [] + + # All the partial state cases are covered above. We have been given the full + # state of the room. + assert not is_partial_state_room + state_before_join = partial_state_before_join # If the local host has a user who can issue invites, then a local # join can be done. @@ -1154,7 +1195,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Ensure the member should be allowed access via membership in a room. await self.event_auth_handler.check_restricted_join_rules( - state_before_join, room_version, user_id, prev_member_event + state_before_join, room_version, user_id, previous_membership ) # If this is going to be a local join, additional information must @@ -1304,11 +1345,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool: + async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool: """ Returns whether a guest can join a room based on its current state. + + Args: + partial_current_state_ids: The current state of the room. May be full or + partial state. """ - guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) + guest_access_id = partial_current_state_ids.get( + (EventTypes.GuestAccess, ""), None + ) if not guest_access_id: return False @@ -1634,19 +1681,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) return event, stream_id - async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: + async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool: + """Returns whether the homeserver is in the room based on its current state. + + Args: + partial_current_state_ids: The current state of the room. May be full or + partial state. + """ # Have we just created the room, and is this about to be the very # first member event? - create_event_id = current_state_ids.get(("m.room.create", "")) - if len(current_state_ids) == 1 and create_event_id: + create_event_id = partial_current_state_ids.get(("m.room.create", "")) + if len(partial_current_state_ids) == 1 and create_event_id: # We can only get here if we're in the process of creating the room return True - for etype, state_key in current_state_ids: + for etype, state_key in partial_current_state_ids: if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): continue - event_id = current_state_ids[(etype, state_key)] + event_id = partial_current_state_ids[(etype, state_key)] event = await self.store.get_event(event_id, allow_none=True) if not event: continue @@ -1715,8 +1768,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ] if len(remote_room_hosts) == 0: - raise SynapseError( - 404, + raise NoKnownServersError( "Can't join remote room because no servers " "that are in the room have been provided.", ) @@ -1947,7 +1999,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ] if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") + raise NoKnownServersError() return await self.federation_handler.do_knock( remote_room_hosts, room_id, user.to_string(), content=content diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 221552a2a6..ba261702d4 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -15,8 +15,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple -from synapse.api.errors import SynapseError -from synapse.handlers.room_member import RoomMemberHandler +from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler from synapse.replication.http.membership import ( ReplicationRemoteJoinRestServlet as ReplRemoteJoin, ReplicationRemoteKnockRestServlet as ReplRemoteKnock, @@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join""" if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") + raise NoKnownServersError() ret = await self._remote_join_client( requester=requester, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index cb66376fb4..ffe766fd56 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -16,7 +16,6 @@ import itertools import logging from collections import OrderedDict -from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -36,7 +35,7 @@ from prometheus_client import Counter import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -72,24 +71,6 @@ event_counter = Counter( ) -class PartialStateConflictError(SynapseError): - """An internal error raised when attempting to persist an event with partial state - after the room containing the event has been un-partial stated. - - This error should be handled by recomputing the event context and trying again. - - This error has an HTTP status code so that it can be transported over replication. - It should not be exposed to clients. - """ - - def __init__(self) -> None: - super().__init__( - HTTPStatus.CONFLICT, - msg="Cannot persist partial state event in un-partial stated room", - errcode=Codes.UNKNOWN, - ) - - @attr.s(slots=True, auto_attribs=True) class DeltaState: """Deltas to use to update the `current_state_events` table. diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 57675fa407..5868eb2da7 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_client = fed_handler.federation_client room_id = "!room:example.com" - membership_event = make_event_from_dict( - { - "room_id": room_id, - "type": "m.room.member", - "sender": "@alice:test", - "state_key": "@alice:test", - "content": {"membership": "join"}, - }, - RoomVersions.V10, - ) - - mock_make_membership_event = Mock( - return_value=make_awaitable( - ( - "example.com", - membership_event, - RoomVersions.V10, - ) - ) - ) EVENT_CREATE = make_event_from_dict( { @@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): }, room_version=RoomVersions.V10, ) + membership_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.member", + "sender": "@alice:test", + "state_key": "@alice:test", + "content": {"membership": "join"}, + "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id], + }, + RoomVersions.V10, + ) + mock_make_membership_event = Mock( + return_value=make_awaitable( + ( + "example.com", + membership_event, + RoomVersions.V10, + ) + ) + ) mock_send_join = Mock( return_value=make_awaitable( SendJoinResult( -- cgit 1.5.1 From c10e13125057e506381d1be8c2ec1394eee45d62 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 13 Feb 2023 11:49:20 +0000 Subject: Apply logging from hotfixes branch to develop (#15054) * Apply logging from hotfixes branch to develop Part of #4826. Originally added in #11882. * Changelog --- changelog.d/15054.misc | 1 + synapse/rest/client/account.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/15054.misc (limited to 'synapse') diff --git a/changelog.d/15054.misc b/changelog.d/15054.misc new file mode 100644 index 0000000000..d800b107cf --- /dev/null +++ b/changelog.d/15054.misc @@ -0,0 +1 @@ +Merge debug logging from the hotfixes branch. diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 4373c73662..232f3a976d 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -415,6 +415,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): request, MsisdnRequestTokenBody ) msisdn = phone_number_to_msisdn(body.country, body.phone_number) + logger.info("Request #%s to verify ownership of %s", body.send_attempt, msisdn) if not await check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -444,6 +445,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} + logger.info("MSISDN %s is already in use by %s", msisdn, existing_user_id) raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) if not self.hs.config.registration.account_threepid_delegate_msisdn: @@ -468,6 +470,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe( body.send_attempt ) + logger.info("MSISDN %s: got response from identity server: %s", msisdn, ret) return 200, ret -- cgit 1.5.1 From bdccfd24773d7482ae497263634312640dab01d1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 13 Feb 2023 12:12:48 +0000 Subject: Refactor arguments of `try_unbind_threepid(_with_id_server)` from dict to separate args (#15053) --- changelog.d/15053.misc | 1 + synapse/handlers/auth.py | 5 ++-- synapse/handlers/deactivate_account.py | 7 +---- synapse/handlers/identity.py | 47 +++++++++++++++++----------------- synapse/rest/client/account.py | 7 +---- 5 files changed, 28 insertions(+), 39 deletions(-) create mode 100644 changelog.d/15053.misc (limited to 'synapse') diff --git a/changelog.d/15053.misc b/changelog.d/15053.misc new file mode 100644 index 0000000000..c27528f5c6 --- /dev/null +++ b/changelog.d/15053.misc @@ -0,0 +1 @@ +Refactor arguments of `try_unbind_threepid` and `_try_unbind_threepid_with_id_server` to not use dictionaries. \ No newline at end of file diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 30f2d46c3c..57a6854b1e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1593,9 +1593,8 @@ class AuthHandler: if medium == "email": address = canonicalise_email(address) - identity_handler = self.hs.get_identity_handler() - result = await identity_handler.try_unbind_threepid( - user_id, {"medium": medium, "address": address, "id_server": id_server} + result = await self.hs.get_identity_handler().try_unbind_threepid( + user_id, medium, address, id_server ) await self.store.user_delete_threepid(user_id, medium, address) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index d74d135c0c..d24f649382 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -106,12 +106,7 @@ class DeactivateAccountHandler: for threepid in threepids: try: result = await self._identity_handler.try_unbind_threepid( - user_id, - { - "medium": threepid["medium"], - "address": threepid["address"], - "id_server": id_server, - }, + user_id, threepid["medium"], threepid["address"], id_server ) identity_server_supports_unbinding &= result except Exception: diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 848e46eb9b..bf0f7acf80 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -219,28 +219,31 @@ class IdentityHandler: data = json_decoder.decode(e.msg) # XXX WAT? return data - async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool: - """Attempt to remove a 3PID from an identity server, or if one is not provided, all - identity servers we're aware the binding is present on + async def try_unbind_threepid( + self, mxid: str, medium: str, address: str, id_server: Optional[str] + ) -> bool: + """Attempt to remove a 3PID from one or more identity servers. Args: mxid: Matrix user ID of binding to be removed - threepid: Dict with medium & address of binding to be - removed, and an optional id_server. + medium: The medium of the third-party ID. + address: The address of the third-party ID. + id_server: An identity server to attempt to unbind from. If None, + attempt to remove the association from all identity servers + known to potentially have it. Raises: - SynapseError: If we failed to contact the identity server + SynapseError: If we failed to contact one or more identity servers. Returns: - True on success, otherwise False if the identity - server doesn't support unbinding (or no identity server found to - contact). + True on success, otherwise False if the identity server doesn't + support unbinding (or no identity server to contact was found). """ - if threepid.get("id_server"): - id_servers = [threepid["id_server"]] + if id_server: + id_servers = [id_server] else: id_servers = await self.store.get_id_servers_user_bound( - user_id=mxid, medium=threepid["medium"], address=threepid["address"] + mxid, medium, address ) # We don't know where to unbind, so we don't have a choice but to return @@ -249,20 +252,21 @@ class IdentityHandler: changed = True for id_server in id_servers: - changed &= await self.try_unbind_threepid_with_id_server( - mxid, threepid, id_server + changed &= await self._try_unbind_threepid_with_id_server( + mxid, medium, address, id_server ) return changed - async def try_unbind_threepid_with_id_server( - self, mxid: str, threepid: dict, id_server: str + async def _try_unbind_threepid_with_id_server( + self, mxid: str, medium: str, address: str, id_server: str ) -> bool: """Removes a binding from an identity server Args: mxid: Matrix user ID of binding to be removed - threepid: Dict with medium & address of binding to be removed + medium: The medium of the third-party ID + address: The address of the third-party ID id_server: Identity server to unbind from Raises: @@ -286,7 +290,7 @@ class IdentityHandler: content = { "mxid": mxid, - "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, + "threepid": {"medium": medium, "address": address}, } # we abuse the federation http client to sign the request, but we have to send it @@ -319,12 +323,7 @@ class IdentityHandler: except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - await self.store.remove_user_bound_threepid( - user_id=mxid, - medium=threepid["medium"], - address=threepid["address"], - id_server=id_server, - ) + await self.store.remove_user_bound_threepid(mxid, medium, address, id_server) return changed diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 232f3a976d..662f5bf762 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -737,12 +737,7 @@ class ThreepidUnbindRestServlet(RestServlet): # Attempt to unbind the threepid from an identity server. If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past result = await self.identity_handler.try_unbind_threepid( - requester.user.to_string(), - { - "address": body.address, - "medium": body.medium, - "id_server": body.id_server, - }, + requester.user.to_string(), body.medium, body.address, body.id_server ) return 200, {"id_server_unbind_result": "success" if result else "no-support"} -- cgit 1.5.1 From 3d7aead5d62e6da97e006199b3f957325e54b053 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 13 Feb 2023 16:30:58 +0000 Subject: Tweak comment on `_is_local_room_accessible` as part of room visibility in `/hierarchy` to clarify the condition for a room being visible. (#14834) --- changelog.d/14834.misc | 1 + synapse/handlers/room_summary.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14834.misc (limited to 'synapse') diff --git a/changelog.d/14834.misc b/changelog.d/14834.misc new file mode 100644 index 0000000000..e683212dc4 --- /dev/null +++ b/changelog.d/14834.misc @@ -0,0 +1 @@ +Tweak comment on `_is_local_room_accessible` as part of room visibility in `/hierarchy` to clarify the condition for a room being visible. \ No newline at end of file diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4472019fbc..807245160d 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -521,8 +521,8 @@ class RoomSummaryHandler: It should return true if: - * The requester is joined or can join the room (per MSC3173). - * The origin server has any user that is joined or can join the room. + * The requesting user is joined or can join the room (per MSC3173); or + * The origin server has any user that is joined or can join the room; or * The history visibility is set to world readable. Args: -- cgit 1.5.1 From db2b105d69fa331bb3f050df82266314f61577ea Mon Sep 17 00:00:00 2001 From: Harishankar Kumar <31770598+hari01584@users.noreply.github.com> Date: Tue, 14 Feb 2023 15:07:08 +0530 Subject: Change collection[str] to StrCollection in event_auth code (#14929) Signed-off-by: Harishankar Kumar --- changelog.d/14929.misc | 1 + synapse/event_auth.py | 23 +++++++++------------- synapse/events/__init__.py | 6 +++--- synapse/storage/databases/main/events.py | 7 +++---- .../storage/databases/main/events_bg_updates.py | 6 +++--- 5 files changed, 19 insertions(+), 24 deletions(-) create mode 100644 changelog.d/14929.misc (limited to 'synapse') diff --git a/changelog.d/14929.misc b/changelog.d/14929.misc new file mode 100644 index 0000000000..2cc3614dfd --- /dev/null +++ b/changelog.d/14929.misc @@ -0,0 +1 @@ +Use `StrCollection` to avoid potential bugs with `Collection[str]`. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index e0be9f88cc..4d6d1b8ebd 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -16,18 +16,7 @@ import collections.abc import logging import typing -from typing import ( - Any, - Collection, - Dict, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - Union, -) +from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -56,7 +45,13 @@ from synapse.api.room_versions import ( RoomVersions, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id +from synapse.types import ( + MutableStateMap, + StateMap, + StrCollection, + UserID, + get_domain_from_id, +) if typing.TYPE_CHECKING: # conditional imports to avoid import cycle @@ -69,7 +64,7 @@ logger = logging.getLogger(__name__) class _EventSourceStore(Protocol): async def get_events( self, - event_ids: Collection[str], + event_ids: StrCollection, redact_behaviour: EventRedactBehaviour, get_prev_content: bool = False, allow_rejected: bool = False, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 8aca9a3ab9..91118a8d84 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -39,7 +39,7 @@ from unpaddedbase64 import encode_base64 from synapse.api.constants import RelationTypes from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.types import JsonDict, RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken, StrCollection from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze from synapse.util.stringutils import strtobool @@ -413,7 +413,7 @@ class EventBase(metaclass=abc.ABCMeta): """ return [e for e, _ in self._dict["prev_events"]] - def auth_event_ids(self) -> Sequence[str]: + def auth_event_ids(self) -> StrCollection: """Returns the list of auth event IDs. The order matches the order specified in the event, though there is no meaning to it. @@ -558,7 +558,7 @@ class FrozenEventV2(EventBase): """ return self._dict["prev_events"] - def auth_event_ids(self) -> Sequence[str]: + def auth_event_ids(self) -> StrCollection: """Returns the list of auth event IDs. The order matches the order specified in the event, though there is no meaning to it. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ffe766fd56..7996cbb557 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -25,7 +25,6 @@ from typing import ( Iterable, List, Optional, - Sequence, Set, Tuple, ) @@ -51,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator -from synapse.types import JsonDict, StateMap, get_domain_from_id +from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically from synapse.util.stringutils import non_null_str_or_none @@ -552,7 +551,7 @@ class PersistEventsStore: event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], - event_to_auth_chain: Dict[str, Sequence[str]], + event_to_auth_chain: Dict[str, StrCollection], ) -> None: """Calculate the chain cover index for the given events. @@ -846,7 +845,7 @@ class PersistEventsStore: event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], - event_to_auth_chain: Dict[str, Sequence[str]], + event_to_auth_chain: Dict[str, StrCollection], events_to_calc_chain_id_for: Set[str], chain_map: Dict[str, Tuple[int, int]], ) -> Dict[str, Tuple[int, int]]: diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index b9d3c36d60..584536111d 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import attr @@ -29,7 +29,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.types import Cursor -from synapse.types import JsonDict +from synapse.types import JsonDict, StrCollection if TYPE_CHECKING: from synapse.server import HomeServer @@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self.event_chain_id_gen, # type: ignore[attr-defined] event_to_room_id, event_to_types, - cast(Dict[str, Sequence[str]], event_to_auth_chain), + cast(Dict[str, StrCollection], event_to_auth_chain), ) return _CalculateChainCover( -- cgit 1.5.1 From f09db5c9918b6aaeb1f53ab4fac3a7f05f512c5f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 Feb 2023 12:10:29 +0100 Subject: Skip calculating unread push actions in `/sync` when `enable_push` is false. (#14980) --- changelog.d/14980.misc | 1 + synapse/handlers/sync.py | 8 ++++++++ synapse/storage/databases/main/event_push_actions.py | 7 +++++++ 3 files changed, 16 insertions(+) create mode 100644 changelog.d/14980.misc (limited to 'synapse') diff --git a/changelog.d/14980.misc b/changelog.d/14980.misc new file mode 100644 index 0000000000..145f4a788b --- /dev/null +++ b/changelog.d/14980.misc @@ -0,0 +1 @@ +Skip calculating unread push actions in /sync when enable_push is false. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4bae46158a..3a9cddf15a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -269,6 +269,8 @@ class SyncHandler: self._state_storage_controller = self._storage_controllers.state self._device_handler = hs.get_device_handler() + self.should_calculate_push_rules = hs.config.push.enable_push + # TODO: flush cache entries on subsequent sync request. # Once we get the next /sync request (ie, one with the same access token # that sets 'since' to 'next_batch'), we know that device won't need a @@ -1288,6 +1290,12 @@ class SyncHandler: async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig ) -> RoomNotifCounts: + if not self.should_calculate_push_rules: + # If push rules have been universally disabled then we know we won't + # have any unread counts in the DB, so we may as well skip asking + # the DB. + return RoomNotifCounts.empty() + with Measure(self.clock, "unread_notifs_for_room_id"): return await self.store.get_unread_event_push_actions_by_room_for_user( diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3a0c370fde..eeccf5db24 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -203,11 +203,18 @@ class RoomNotifCounts: # Map of thread ID to the notification counts. threads: Dict[str, NotifCounts] + @staticmethod + def empty() -> "RoomNotifCounts": + return _EMPTY_ROOM_NOTIF_COUNTS + def __len__(self) -> int: # To properly account for the amount of space in any caches. return len(self.threads) + 1 +_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {}) + + def _serialize_action( actions: Collection[Union[Mapping, str]], is_highlight: bool ) -> str: -- cgit 1.5.1 From cb262713b701d1abcbca03334d17e2d0f81eee4a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 Feb 2023 12:20:25 +0100 Subject: Fix clashing DB txn name (#15070) * Fix clashing DB txn name * Newsfile --- changelog.d/15070.misc | 1 + synapse/storage/databases/main/end_to_end_keys.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/15070.misc (limited to 'synapse') diff --git a/changelog.d/15070.misc b/changelog.d/15070.misc new file mode 100644 index 0000000000..0f3244de9f --- /dev/null +++ b/changelog.d/15070.misc @@ -0,0 +1 @@ +Fix clashing database transaction name. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 752dc16e17..2c2d145666 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -262,7 +262,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker for batch in batch_iter(signature_query, 50): cross_sigs_result = await self.db_pool.runInteraction( - "get_e2e_cross_signing_signatures", + "get_e2e_cross_signing_signatures_for_devices", self._get_e2e_cross_signing_signatures_for_devices_txn, batch, ) -- cgit 1.5.1 From 463c19ac3648b242c480e299349d2ef90bf38a0b Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 14 Feb 2023 12:32:19 +0000 Subject: Faster joins: Omit device list updates from partial state rooms in /sync (#15069) ...when lazy loading of members is not enabled. It's weird to notify a client that another user's device list has changed when the client doesn't think that they share a room. Note that when a room is un-partial stated, device list updates are emitted for every member in that room over /sync. Signed-off-by: Sean Quah --- changelog.d/15069.misc | 1 + synapse/handlers/sync.py | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 changelog.d/15069.misc (limited to 'synapse') diff --git a/changelog.d/15069.misc b/changelog.d/15069.misc new file mode 100644 index 0000000000..e7a619ad2b --- /dev/null +++ b/changelog.d/15069.misc @@ -0,0 +1 @@ +Faster joins: omit device list updates originating from partial state rooms in /sync responses without lazy loading of members enabled. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3a9cddf15a..4e4595312c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1399,6 +1399,11 @@ class SyncHandler: for room_id, is_partial_state in results.items() if is_partial_state ) + membership_change_events = [ + event + for event in membership_change_events + if not results.get(event.room_id, False) + ] # Incremental eager syncs should additionally include rooms that # - we are joined to -- cgit 1.5.1 From e9b1ff9f31f8ff093e7eaf9c54fa8f40a3b66aa8 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 14 Feb 2023 15:50:59 +0000 Subject: Prevent clients from reporting nonexistent events. (#13779) --- changelog.d/13779.bugfix | 1 + synapse/rest/client/report_event.py | 11 ++++++++++- tests/rest/client/test_report_event.py | 12 ++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13779.bugfix (limited to 'synapse') diff --git a/changelog.d/13779.bugfix b/changelog.d/13779.bugfix new file mode 100644 index 0000000000..a92c722c6e --- /dev/null +++ b/changelog.d/13779.bugfix @@ -0,0 +1 @@ +Prevent clients from reporting nonexistent events. \ No newline at end of file diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index e2b410cf32..9be5860221 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -16,7 +16,7 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -39,6 +39,7 @@ class ReportEventRestServlet(RestServlet): self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastores().main + self._event_handler = self.hs.get_event_handler() async def on_POST( self, request: SynapseRequest, room_id: str, event_id: str @@ -61,6 +62,14 @@ class ReportEventRestServlet(RestServlet): Codes.BAD_JSON, ) + event = await self._event_handler.get_event( + requester.user, room_id, event_id, show_redacted=False + ) + if event is None: + raise NotFoundError( + "Unable to report event: it does not exist or you aren't able to see it." + ) + await self.store.add_event_report( room_id=room_id, event_id=event_id, diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py index 7cb1017a4a..1250685d39 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py @@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase): data = {"reason": None, "score": None} self._assert_status(400, data) + def test_cannot_report_nonexistent_event(self) -> None: + """ + Tests that we don't accept event reports for events which do not exist. + """ + channel = self.make_request( + "POST", + f"rooms/{self.room_id}/report/$nonsenseeventid:test", + {"reason": "i am very sad"}, + access_token=self.other_user_tok, + ) + self.assertEqual(404, channel.code, msg=channel.result["body"]) + def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( "POST", self.report_path, data, access_token=self.other_user_tok -- cgit 1.5.1 From 119e0795a58548fb38fab299e7c362fcbb388d68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Feb 2023 14:02:19 -0500 Subject: Implement MSC3966: Add a push rule condition to search for a value in an array. (#15045) The `exact_event_property_contains` condition can be used to search for a value inside of an array. --- changelog.d/15045.feature | 1 + rust/benches/evaluator.rs | 32 +++++++++------- rust/src/push/evaluator.rs | 65 +++++++++++++++++++++++++------- rust/src/push/mod.rs | 33 +++++++++++++++- stubs/synapse/synapse_rust/push.pyi | 7 ++-- synapse/config/experimental.py | 5 +++ synapse/push/bulk_push_rule_evaluator.py | 21 +++++++---- synapse/types/__init__.py | 1 + tests/push/test_push_rule_evaluator.py | 53 ++++++++++++++++++++++++-- 9 files changed, 176 insertions(+), 42 deletions(-) create mode 100644 changelog.d/15045.feature (limited to 'synapse') diff --git a/changelog.d/15045.feature b/changelog.d/15045.feature new file mode 100644 index 0000000000..87766befda --- /dev/null +++ b/changelog.d/15045.feature @@ -0,0 +1 @@ +Experimental support for [MSC3966](https://github.com/matrix-org/matrix-spec-proposals/pull/3966): the `exact_event_property_contains` push rule condition. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 229553ebf8..8213dfd9ea 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -15,8 +15,8 @@ #![feature(test)] use std::collections::BTreeSet; use synapse::push::{ - evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, - SimpleJsonValue, + evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, JsonValue, + PushRules, SimpleJsonValue, }; use test::Bencher; @@ -27,15 +27,15 @@ fn bench_match_exact(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -54,6 +54,7 @@ fn bench_match_exact(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -76,15 +77,15 @@ fn bench_match_word(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -103,6 +104,7 @@ fn bench_match_word(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -125,15 +127,15 @@ fn bench_match_word_miss(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -152,6 +154,7 @@ fn bench_match_word_miss(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -174,15 +177,15 @@ fn bench_eval_message(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -201,6 +204,7 @@ fn bench_eval_message(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index dd6b4343ec..2eaa06ad76 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -14,6 +14,7 @@ use std::collections::{BTreeMap, BTreeSet}; +use crate::push::JsonValue; use anyhow::{Context, Error}; use lazy_static::lazy_static; use log::warn; @@ -63,7 +64,7 @@ impl RoomVersionFeatures { pub struct PushRuleEvaluator { /// A mapping of "flattened" keys to simple JSON values in the event, e.g. /// includes things like "type" and "content.msgtype". - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, /// The "content.body", if any. body: String, @@ -87,7 +88,7 @@ pub struct PushRuleEvaluator { /// The related events, indexed by relation type. Flattened in the same manner as /// `flattened_keys`. - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, /// If msc3664, push rules for related events, is enabled. related_event_match_enabled: bool, @@ -101,6 +102,9 @@ pub struct PushRuleEvaluator { /// If MSC3758 (exact_event_match push rule condition) is enabled. msc3758_exact_event_match: bool, + + /// If MSC3966 (exact_event_property_contains push rule condition) is enabled. + msc3966_exact_event_property_contains: bool, } #[pymethods] @@ -109,21 +113,22 @@ impl PushRuleEvaluator { #[allow(clippy::too_many_arguments)] #[new] pub fn py_new( - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, has_mentions: bool, user_mentions: BTreeSet, room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, related_event_match_enabled: bool, room_version_feature_flags: Vec, msc3931_enabled: bool, msc3758_exact_event_match: bool, + msc3966_exact_event_property_contains: bool, ) -> Result { let body = match flattened_keys.get("content.body") { - Some(SimpleJsonValue::Str(s)) => s.clone(), + Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone(), _ => String::new(), }; @@ -141,6 +146,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc3758_exact_event_match, + msc3966_exact_event_property_contains, }) } @@ -263,6 +269,9 @@ impl PushRuleEvaluator { KnownCondition::RelatedEventMatch(event_match) => { self.match_related_event_match(event_match, user_id)? } + KnownCondition::ExactEventPropertyContains(exact_event_match) => { + self.match_exact_event_property_contains(exact_event_match)? + } KnownCondition::IsUserMention => { if let Some(uid) = user_id { self.user_mentions.contains(uid) @@ -345,7 +354,7 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(SimpleJsonValue::Str(haystack)) = + let haystack = if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = self.flattened_keys.get(&*event_match.key) { haystack @@ -377,7 +386,9 @@ impl PushRuleEvaluator { let value = &exact_event_match.value; - let haystack = if let Some(haystack) = self.flattened_keys.get(&*exact_event_match.key) { + let haystack = if let Some(JsonValue::Value(haystack)) = + self.flattened_keys.get(&*exact_event_match.key) + { haystack } else { return Ok(false); @@ -441,11 +452,12 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(SimpleJsonValue::Str(haystack)) = event.get(&**key) { - haystack - } else { - return Ok(false); - }; + let haystack = + if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = event.get(&**key) { + haystack + } else { + return Ok(false); + }; // For the content.body we match against "words", but for everything // else we match against the entire value. @@ -459,6 +471,29 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `exact_event_property_contains` condition. (MSC3758) + fn match_exact_event_property_contains( + &self, + exact_event_match: &ExactEventMatchCondition, + ) -> Result { + // First check if the feature is enabled. + if !self.msc3966_exact_event_property_contains { + return Ok(false); + } + + let value = &exact_event_match.value; + + let haystack = if let Some(JsonValue::Array(haystack)) = + self.flattened_keys.get(&*exact_event_match.key) + { + haystack + } else { + return Ok(false); + }; + + Ok(haystack.contains(&**value)) + } + /// Match the member count against an 'is' condition /// The `is` condition can be things like '>2', '==3' or even just '4'. fn match_member_count(&self, is: &str) -> Result { @@ -488,7 +523,7 @@ fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert( "content.body".to_string(), - SimpleJsonValue::Str("foo bar bob hello".to_string()), + JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())), ); let evaluator = PushRuleEvaluator::py_new( flattened_keys, @@ -503,6 +538,7 @@ fn push_rule_evaluator() { vec![], true, true, + true, ) .unwrap(); @@ -519,7 +555,7 @@ fn test_requires_room_version_supports_condition() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert( "content.body".to_string(), - SimpleJsonValue::Str("foo bar bob hello".to_string()), + JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())), ); let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( @@ -535,6 +571,7 @@ fn test_requires_room_version_supports_condition() { flags, true, true, + true, ) .unwrap(); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 79e519fe11..253b5f367c 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -58,7 +58,7 @@ use anyhow::{Context, Error}; use log::warn; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyLong, PyString}; +use pyo3::types::{PyBool, PyList, PyLong, PyString}; use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; @@ -280,6 +280,35 @@ impl<'source> FromPyObject<'source> for SimpleJsonValue { } } +/// A JSON values (list, string, int, boolean, or null). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum JsonValue { + Array(Vec), + Value(SimpleJsonValue), +} + +impl<'source> FromPyObject<'source> for JsonValue { + fn extract(ob: &'source PyAny) -> PyResult { + if let Ok(l) = ::try_from(ob) { + match l.iter().map(SimpleJsonValue::extract).collect() { + Ok(a) => Ok(JsonValue::Array(a)), + Err(e) => Err(PyTypeError::new_err(format!( + "Can't convert to JsonValue::Array: {}", + e + ))), + } + } else if let Ok(v) = SimpleJsonValue::extract(ob) { + Ok(JsonValue::Value(v)) + } else { + Err(PyTypeError::new_err(format!( + "Can't convert from {} to JsonValue", + ob.get_type().name()? + ))) + } + } +} + /// A condition used in push rules to match against an event. /// /// We need this split as `serde` doesn't give us the ability to have a @@ -303,6 +332,8 @@ pub enum KnownCondition { ExactEventMatch(ExactEventMatchCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), + #[serde(rename = "org.matrix.msc3966.exact_event_property_contains")] + ExactEventPropertyContains(ExactEventMatchCondition), #[serde(rename = "org.matrix.msc3952.is_user_mention")] IsUserMention, #[serde(rename = "org.matrix.msc3952.is_room_mention")] diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 328f681a29..7b33c30cc9 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -14,7 +14,7 @@ from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union -from synapse.types import JsonDict, SimpleJsonValue +from synapse.types import JsonDict, JsonValue class PushRule: @property @@ -56,18 +56,19 @@ def get_base_rule_ids() -> Collection[str]: ... class PushRuleEvaluator: def __init__( self, - flattened_keys: Mapping[str, SimpleJsonValue], + flattened_keys: Mapping[str, JsonValue], has_mentions: bool, user_mentions: Set[str], room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - related_events_flattened: Mapping[str, Mapping[str, SimpleJsonValue]], + related_events_flattened: Mapping[str, Mapping[str, JsonValue]], related_event_match_enabled: bool, room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, msc3758_exact_event_match: bool, + msc3966_exact_event_property_contains: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 6ac2f0c10d..1d294f8798 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -188,3 +188,8 @@ class ExperimentalConfig(Config): self.msc3958_supress_edit_notifs = experimental.get( "msc3958_supress_edit_notifs", False ) + + # MSC3966: exact_event_property_contains push rule condition. + self.msc3966_exact_event_property_contains = experimental.get( + "msc3966_exact_event_property_contains", False + ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f6a5bffb0f..2e917c90c4 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -44,7 +44,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator -from synapse.types import SimpleJsonValue +from synapse.types import JsonValue from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func @@ -259,13 +259,13 @@ class BulkPushRuleEvaluator: async def _related_events( self, event: EventBase - ) -> Dict[str, Dict[str, SimpleJsonValue]]: + ) -> Dict[str, Dict[str, JsonValue]]: """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation Returns: Mapping of relation type to flattened events. """ - related_events: Dict[str, Dict[str, SimpleJsonValue]] = {} + related_events: Dict[str, Dict[str, JsonValue]] = {} if self._related_event_match_enabled: related_event_id = event.content.get("m.relates_to", {}).get("event_id") relation_type = event.content.get("m.relates_to", {}).get("rel_type") @@ -429,6 +429,7 @@ class BulkPushRuleEvaluator: event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag self.hs.config.experimental.msc3758_exact_event_match, + self.hs.config.experimental.msc3966_exact_event_property_contains, ) users = rules_by_user.keys() @@ -502,18 +503,22 @@ RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] +def _is_simple_value(value: Any) -> bool: + return isinstance(value, (bool, str)) or type(value) is int or value is None + + def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, - result: Optional[Dict[str, SimpleJsonValue]] = None, + result: Optional[Dict[str, JsonValue]] = None, *, msc3783_escape_event_match_key: bool = False, -) -> Dict[str, SimpleJsonValue]: +) -> Dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, flatten it into a single layer dictionary by combining the keys & sub-keys. - String, integer, boolean, and null values are kept. All others are dropped. + String, integer, boolean, null or lists of those values are kept. All others are dropped. Transforms: @@ -542,8 +547,10 @@ def _flatten_dict( # nested fields. key = key.replace("\\", "\\\\").replace(".", "\\.") - if isinstance(value, (bool, str)) or type(value) is int or value is None: + if _is_simple_value(value): result[".".join(prefix + [key])] = value + elif isinstance(value, (list, tuple)): + result[".".join(prefix + [key])] = [v for v in value if _is_simple_value(v)] elif isinstance(value, Mapping): # do not set `room_version` due to recursion considerations below _flatten_dict( diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 52e366c8ae..33363867c4 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -71,6 +71,7 @@ MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. # A "simple" (canonical) JSON value. SimpleJsonValue = Optional[Union[str, int, bool]] +JsonValue = Union[List[SimpleJsonValue], Tuple[SimpleJsonValue, ...], SimpleJsonValue] # A JSON-serialisable dict. JsonDict = Dict[str, Any] # A JSON-serialisable mapping; roughly speaking an immutable JSONDict. diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 6603447341..0554d247bc 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.synapse_rust.push import PushRuleEvaluator from synapse.types import JsonDict, JsonMapping, UserID from synapse.util import Clock +from synapse.util.frozenutils import freeze from tests import unittest from tests.test_utils.event_injection import create_event, inject_member_event @@ -57,17 +58,24 @@ class FlattenDictTestCase(unittest.TestCase): ) def test_non_string(self) -> None: - """Booleans, ints, and nulls should be kept while other items are dropped.""" + """String, booleans, ints, nulls and list of those should be kept while other items are dropped.""" input: Dict[str, Any] = { "woo": "woo", "foo": True, "bar": 1, "baz": None, - "fuzz": [], + "fuzz": ["woo", True, 1, None, [], {}], "boo": {}, } self.assertEqual( - {"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input) + { + "woo": "woo", + "foo": True, + "bar": 1, + "baz": None, + "fuzz": ["woo", True, 1, None], + }, + _flatten_dict(input), ) def test_event(self) -> None: @@ -117,6 +125,7 @@ class FlattenDictTestCase(unittest.TestCase): "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", + "content.org.matrix.msc1767.markup": [], } self.assertEqual(expected, _flatten_dict(event)) @@ -128,6 +137,7 @@ class FlattenDictTestCase(unittest.TestCase): "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", + "content.org.matrix.msc1767.markup": [], } self.assertEqual(expected, _flatten_dict(event)) @@ -169,6 +179,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, msc3758_exact_event_match=True, + msc3966_exact_event_property_contains=True, ) def test_display_name(self) -> None: @@ -549,6 +560,42 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "incorrect types should not match", ) + def test_exact_event_property_contains(self) -> None: + """Check that exact_event_property_contains conditions work as expected.""" + + condition = { + "kind": "org.matrix.msc3966.exact_event_property_contains", + "key": "content.value", + "value": "foobaz", + } + self._assert_matches( + condition, + {"value": ["foobaz"]}, + "exact value should match", + ) + self._assert_matches( + condition, + {"value": ["foobaz", "bugz"]}, + "extra values should match", + ) + self._assert_not_matches( + condition, + {"value": ["FoobaZ"]}, + "values should match and be case-sensitive", + ) + self._assert_not_matches( + condition, + {"value": "foobaz"}, + "does not search in a string", + ) + + # it should work on frozendicts too + self._assert_matches( + condition, + freeze({"value": ["foobaz"]}), + "values should match on frozendicts", + ) + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) -- cgit 1.5.1 From 06ba71083eefbe1fd9a8eeed10e541dd7b52796f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 14 Feb 2023 23:42:29 +0000 Subject: Fix order of partial state tables when purging (#15068) * Fix order of partial state tables when purging `partial_state_rooms` has an FK on `events` pointing to the join event we get from `/send_join`, so we must delete from that table before deleting from `events`. **NB:** It would be nice to cancel any resync processes for the room being purged. We do not do this at present. To do so reliably we'd need an internal HTTP "replication" endpoint, because the worker doing the resync process may be different to that handling the purge request. The first time the resync process tries to write data after the deletion it will fail because we have deleted necessary data e.g. auth events. AFAICS it will not retry the resync, so the only downside to not cancelling the resync is a scary-looking traceback. (This is presumably extremely race-sensitive.) * Changelog * admist(?) -> between * Warn about a race * Fix typo, thanks Sean Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --------- Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/15068.bugfix | 1 + synapse/handlers/federation.py | 5 +++++ synapse/storage/databases/main/purge_events.py | 6 ++++-- 3 files changed, 10 insertions(+), 2 deletions(-) create mode 100644 changelog.d/15068.bugfix (limited to 'synapse') diff --git a/changelog.d/15068.bugfix b/changelog.d/15068.bugfix new file mode 100644 index 0000000000..f09ffa2877 --- /dev/null +++ b/changelog.d/15068.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.76.0 where partially-joined rooms could not be deleted using the [purge room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api). diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 08727e4857..1d0f6bcd6f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1880,6 +1880,11 @@ class FederationHandler: logger.info("Updating current state for %s", room_id) # TODO(faster_joins): notify workers in notify_room_un_partial_stated # https://github.com/matrix-org/synapse/issues/12994 + # + # NB: there's a potential race here. If room is purged just before we + # call this, we _might_ end up inserting rows into current_state_events. + # (The logic is hard to chase through.) We think this is fine, but if + # not the HS admin should purge the room again. await self.state_handler.update_current_state(room_id) logger.info("Handling any pending device list updates") diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 9213ce0b5a..9c41d01e13 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -420,12 +420,14 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_push_actions", "event_search", "event_failed_pull_attempts", + # Note: the partial state tables have foreign keys between each other, and to + # `events` and `rooms`. We need to delete from them in the right order. "partial_state_events", + "partial_state_rooms_servers", + "partial_state_rooms", "events", "federation_inbound_events_staging", "local_current_membership", - "partial_state_rooms_servers", - "partial_state_rooms", "receipts_graph", "receipts_linearized", "room_aliases", -- cgit 1.5.1 From 5febf88b6c5194582f427142dc0850625547c0d9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 15 Feb 2023 11:47:57 +0000 Subject: Update the error code for duplicate annotation (#15075) --- changelog.d/15075.feature | 2 ++ synapse/api/errors.py | 4 ++++ synapse/handlers/message.py | 6 +++++- 3 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 changelog.d/15075.feature (limited to 'synapse') diff --git a/changelog.d/15075.feature b/changelog.d/15075.feature new file mode 100644 index 0000000000..d25a7567a4 --- /dev/null +++ b/changelog.d/15075.feature @@ -0,0 +1,2 @@ +Update the error code returned when user sends a duplicate annotation. + diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 9235ce6536..e1737de59b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -108,6 +108,10 @@ class Codes(str, Enum): USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL" + # Attempt to send a second annotation with the same event type & annotation key + # MSC2677 + DUPLICATE_ANNOTATION = "M_DUPLICATE_ANNOTATION" + class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f5b658d9d..aa90d0000d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1337,7 +1337,11 @@ class EventCreationHandler: relation.parent_id, event.type, aggregation_key, event.sender ) if already_exists: - raise SynapseError(400, "Can't send same reaction twice") + raise SynapseError( + 400, + "Can't send same reaction twice", + errcode=Codes.DUPLICATE_ANNOTATION, + ) # Don't attempt to start a thread if the parent event is a relation. elif relation.rel_type == RelationTypes.THREAD: -- cgit 1.5.1 From 27a3a72a50cb24f25e48fad1e6e79aba2cd1bea2 Mon Sep 17 00:00:00 2001 From: 999lakhisidhu <42063995+999lakhisidhu@users.noreply.github.com> Date: Wed, 15 Feb 2023 16:39:31 +0400 Subject: Support for selecting the Redis logical database. (#15034) Note that this is only used for key-value store (cached values) and not for the pub/sub replication used by Synapse. --- changelog.d/15034.feature | 1 + contrib/docker_compose_workers/README.md | 1 + docs/usage/configuration/config_documentation.md | 4 ++++ synapse/config/redis.py | 1 + synapse/server.py | 1 + 5 files changed, 8 insertions(+) create mode 100644 changelog.d/15034.feature (limited to 'synapse') diff --git a/changelog.d/15034.feature b/changelog.d/15034.feature new file mode 100644 index 0000000000..34f320da92 --- /dev/null +++ b/changelog.d/15034.feature @@ -0,0 +1 @@ +Allow Synapse to use a specific Redis [logical database](https://redis.io/commands/select/) in worker-mode deployments. diff --git a/contrib/docker_compose_workers/README.md b/contrib/docker_compose_workers/README.md index bdd3dd32e0..d3cdfe5614 100644 --- a/contrib/docker_compose_workers/README.md +++ b/contrib/docker_compose_workers/README.md @@ -68,6 +68,7 @@ redis: enabled: true host: redis port: 6379 + # dbid: # password: ``` diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 2883f76a26..75483bfb12 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3927,6 +3927,9 @@ This setting has the following sub-options: * `host` and `port`: Optional host and port to use to connect to redis. Defaults to localhost and 6379 * `password`: Optional password if configured on the Redis instance. +* `dbid`: Optional redis dbid if needs to connect to specific redis logical db. + + _Added in Synapse 1.78.0._ Example configuration: ```yaml @@ -3935,6 +3938,7 @@ redis: host: localhost port: 6379 password: + dbid: ``` --- ## Individual worker configuration diff --git a/synapse/config/redis.py b/synapse/config/redis.py index b42dd2e93a..e6a75be434 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -33,4 +33,5 @@ class RedisConfig(Config): self.redis_host = redis_config.get("host", "localhost") self.redis_port = redis_config.get("port", 6379) + self.redis_dbid = redis_config.get("dbid", None) self.redis_password = redis_config.get("password") diff --git a/synapse/server.py b/synapse/server.py index efc6b5f895..e5a3475247 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -827,6 +827,7 @@ class HomeServer(metaclass=abc.ABCMeta): hs=self, host=self.config.redis.redis_host, port=self.config.redis.redis_port, + dbid=self.config.redis.redis_dbid, password=self.config.redis.redis_password, reconnect=True, ) -- cgit 1.5.1 From 3ad817bfe561e0b7ddcd8398a76a4a4d3d789138 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 15 Feb 2023 13:59:06 +0000 Subject: Fix federated joins when the first server in the list is not in the room (#15074) Previously we would give up upon receiving a 404 from the first server, instead of trying the rest of the servers in the list. Signed-off-by: Sean Quah --- changelog.d/15074.bugfix | 1 + synapse/federation/federation_client.py | 11 +++++------ 2 files changed, 6 insertions(+), 6 deletions(-) create mode 100644 changelog.d/15074.bugfix (limited to 'synapse') diff --git a/changelog.d/15074.bugfix b/changelog.d/15074.bugfix new file mode 100644 index 0000000000..d1ceb4f4c8 --- /dev/null +++ b/changelog.d/15074.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where federated joins would fail if the first server in the list of servers to try is not in the room. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 0ac85a3be7..7d04560dca 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -884,7 +884,7 @@ class FederationClient(FederationBase): if 500 <= e.code < 600: failover = True - elif e.code == 400 and synapse_error.errcode in failover_errcodes: + elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes: failover = True elif failover_on_unknown_endpoint and self._is_unknown_endpoint( @@ -999,14 +999,13 @@ class FederationClient(FederationBase): return destination, ev, room_version + failover_errcodes = {Codes.NOT_FOUND} # MSC3083 defines additional error codes for room joins. Unfortunately # we do not yet know the room version, assume these will only be returned # by valid room versions. - failover_errcodes = ( - (Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN) - if membership == Membership.JOIN - else None - ) + if membership == Membership.JOIN: + failover_errcodes.add(Codes.UNABLE_AUTHORISE_JOIN) + failover_errcodes.add(Codes.UNABLE_TO_GRANT_JOIN) return await self._try_destination_list( "make_" + membership, -- cgit 1.5.1 From 979f237b282cbdaab8d74cc4c7473117093d63d9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 09:51:22 -0500 Subject: Update intentional mentions (MSC3952) to depend on `exact_event_match` (MSC3758). (#15037) This replaces the specific `is_room_mention` push rule condition used in MSC3952 with the generic `exact_event_match` push rule condition from MSC3758. No functionality changes due to this. --- changelog.d/15037.misc | 1 + rust/benches/evaluator.rs | 4 ---- rust/src/push/base_rules.rs | 7 +++++-- rust/src/push/evaluator.rs | 7 ------- rust/src/push/mod.rs | 13 ------------- stubs/synapse/synapse_rust/push.pyi | 1 - synapse/config/experimental.py | 7 ++++--- synapse/push/bulk_push_rule_evaluator.py | 4 ---- tests/push/test_bulk_push_rule_evaluator.py | 18 ++++++++++++++++-- tests/push/test_push_rule_evaluator.py | 23 ----------------------- 10 files changed, 26 insertions(+), 59 deletions(-) create mode 100644 changelog.d/15037.misc (limited to 'synapse') diff --git a/changelog.d/15037.misc b/changelog.d/15037.misc new file mode 100644 index 0000000000..fabfe77d35 --- /dev/null +++ b/changelog.d/15037.misc @@ -0,0 +1 @@ +Update [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952) support based on changes to the MSC. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 8213dfd9ea..efd19a2165 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -45,7 +45,6 @@ fn bench_match_exact(b: &mut Bencher) { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -95,7 +94,6 @@ fn bench_match_word(b: &mut Bencher) { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -145,7 +143,6 @@ fn bench_match_word_miss(b: &mut Bencher) { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -195,7 +192,6 @@ fn bench_eval_message(b: &mut Bencher) { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), Default::default(), diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index dcbca340fe..4a62b9696f 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -21,13 +21,13 @@ use lazy_static::lazy_static; use serde_json::Value; use super::KnownCondition; -use crate::push::Action; use crate::push::Condition; use crate::push::EventMatchCondition; use crate::push::PushRule; use crate::push::RelatedEventMatchCondition; use crate::push::SetTweak; use crate::push::TweakValue; +use crate::push::{Action, ExactEventMatchCondition, SimpleJsonValue}; const HIGHLIGHT_ACTION: Action = Action::SetTweak(SetTweak { set_tweak: Cow::Borrowed("highlight"), @@ -168,7 +168,10 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mention"), priority_class: 5, conditions: Cow::Borrowed(&[ - Condition::Known(KnownCondition::IsRoomMention), + Condition::Known(KnownCondition::ExactEventMatch(ExactEventMatchCondition { + key: Cow::Borrowed("content.org.matrix.msc3952.mentions.room"), + value: Cow::Borrowed(&SimpleJsonValue::Bool(true)), + })), Condition::Known(KnownCondition::SenderNotificationPermission { key: Cow::Borrowed("room"), }), diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 2eaa06ad76..55551ecb56 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -73,8 +73,6 @@ pub struct PushRuleEvaluator { has_mentions: bool, /// The user mentions that were part of the message. user_mentions: BTreeSet, - /// True if the message is a room message. - room_mention: bool, /// The number of users in the room. room_member_count: u64, @@ -116,7 +114,6 @@ impl PushRuleEvaluator { flattened_keys: BTreeMap, has_mentions: bool, user_mentions: BTreeSet, - room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, @@ -137,7 +134,6 @@ impl PushRuleEvaluator { body, has_mentions, user_mentions, - room_mention, room_member_count, notification_power_levels, sender_power_level, @@ -279,7 +275,6 @@ impl PushRuleEvaluator { false } } - KnownCondition::IsRoomMention => self.room_mention, KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -529,7 +524,6 @@ fn push_rule_evaluator() { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), BTreeMap::new(), @@ -562,7 +556,6 @@ fn test_requires_room_version_supports_condition() { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), BTreeMap::new(), diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 253b5f367c..fdd2b2c143 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -336,8 +336,6 @@ pub enum KnownCondition { ExactEventPropertyContains(ExactEventMatchCondition), #[serde(rename = "org.matrix.msc3952.is_user_mention")] IsUserMention, - #[serde(rename = "org.matrix.msc3952.is_room_mention")] - IsRoomMention, ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -667,17 +665,6 @@ fn test_deserialize_unstable_msc3952_user_condition() { )); } -#[test] -fn test_deserialize_unstable_msc3952_room_condition() { - let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#; - - let condition: Condition = serde_json::from_str(json).unwrap(); - assert!(matches!( - condition, - Condition::Known(KnownCondition::IsRoomMention) - )); -} - #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 7b33c30cc9..a8f0ed2435 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -59,7 +59,6 @@ class PushRuleEvaluator: flattened_keys: Mapping[str, JsonValue], has_mentions: bool, user_mentions: Set[str], - room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1d294f8798..54c91953e1 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -179,9 +179,10 @@ class ExperimentalConfig(Config): "msc3783_escape_event_match_key", False ) - # MSC3952: Intentional mentions - self.msc3952_intentional_mentions = experimental.get( - "msc3952_intentional_mentions", False + # MSC3952: Intentional mentions, this depends on MSC3758. + self.msc3952_intentional_mentions = ( + experimental.get("msc3952_intentional_mentions", False) + and self.msc3758_exact_event_match ) # MSC3959: Do not generate notifications for edits. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 2e917c90c4..5fc38431ba 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -400,7 +400,6 @@ class BulkPushRuleEvaluator: mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict) user_mentions: Set[str] = set() - room_mention = False if has_mentions: # mypy seems to have lost the type even though it must be a dict here. assert isinstance(mentions, dict) @@ -410,8 +409,6 @@ class BulkPushRuleEvaluator: user_mentions = set( filter(lambda item: isinstance(item, str), user_mentions_raw) ) - # Room mention is only true if the value is exactly true. - room_mention = mentions.get("room") is True evaluator = PushRuleEvaluator( _flatten_dict( @@ -420,7 +417,6 @@ class BulkPushRuleEvaluator: ), has_mentions, user_mentions, - room_mention, room_member_count, sender_power_level, notification_levels, diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 7567756135..199e3d7b70 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -227,7 +227,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) return len(result) > 0 - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + @override_config( + { + "experimental_features": { + "msc3758_exact_event_match": True, + "msc3952_intentional_mentions": True, + } + } + ) def test_user_mentions(self) -> None: """Test the behavior of an event which includes invalid user mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) @@ -323,7 +330,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) ) - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + @override_config( + { + "experimental_features": { + "msc3758_exact_event_match": True, + "msc3952_intentional_mentions": True, + } + } + ) def test_room_mentions(self) -> None: """Test the behavior of an event which includes invalid room mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 0554d247bc..d320a12f96 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -149,7 +149,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): *, has_mentions: bool = False, user_mentions: Optional[Set[str]] = None, - room_mention: bool = False, related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( @@ -170,7 +169,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): _flatten_dict(event), has_mentions, user_mentions or set(), - room_mention, room_member_count, sender_power_level, cast(Dict[str, int], power_levels.get("notifications", {})), @@ -232,27 +230,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions # since the BulkPushRuleEvaluator is what handles data sanitisation. - def test_room_mentions(self) -> None: - """Check for room mentions.""" - condition = {"kind": "org.matrix.msc3952.is_room_mention"} - - # No room mention shouldn't match. - evaluator = self._get_evaluator({}, has_mentions=True) - self.assertFalse(evaluator.matches(condition, None, None)) - - # Room mention should match. - evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True) - self.assertTrue(evaluator.matches(condition, None, None)) - - # A room mention and user mention is valid. - evaluator = self._get_evaluator( - {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True - ) - self.assertTrue(evaluator.matches(condition, None, None)) - - # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions - # since the BulkPushRuleEvaluator is what handles data sanitisation. - def _assert_matches( self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None ) -> None: -- cgit 1.5.1 From ffc2ee521d26f5b842df7902ade5de7a538e602d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 16 Feb 2023 16:09:11 +0000 Subject: Use mypy 1.0 (#15052) * Update mypy and mypy-zope * Remove unused ignores These used to suppress ``` synapse/storage/engines/__init__.py:28: error: "__new__" must return a class instance (got "NoReturn") [misc] ``` and ``` synapse/http/matrixfederationclient.py:1270: error: "BaseException" has no attribute "reasons" [attr-defined] ``` (note that we check `hasattr(e, "reasons")` above) * Avoid empty body warnings, sometimes by marking methods as abstract E.g. ``` tests/handlers/test_register.py:58: error: Missing return statement [empty-body] tests/handlers/test_register.py:108: error: Missing return statement [empty-body] ``` * Suppress false positive about `JaegerConfig` Complaint was ``` synapse/logging/opentracing.py:450: error: Function "Type[Config]" could always be true in boolean context [truthy-function] ``` * Fix not calling `is_state()` Oops! ``` tests/rest/client/test_third_party_rules.py:428: error: Function "Callable[[], bool]" could always be true in boolean context [truthy-function] ``` * Suppress false positives from ParamSpecs ```` synapse/logging/opentracing.py:971: error: Argument 2 to "_custom_sync_async_decorator" has incompatible type "Callable[[Arg(Callable[P, R], 'func'), **P], _GeneratorContextManager[None]]"; expected "Callable[[Callable[P, R], **P], _GeneratorContextManager[None]]" [arg-type] synapse/logging/opentracing.py:1017: error: Argument 2 to "_custom_sync_async_decorator" has incompatible type "Callable[[Arg(Callable[P, R], 'func'), **P], _GeneratorContextManager[None]]"; expected "Callable[[Callable[P, R], **P], _GeneratorContextManager[None]]" [arg-type] ```` * Drive-by improvement to `wrapping_logic` annotation * Workaround false "unreachable" positives See https://github.com/Shoobx/mypy-zope/issues/91 ``` tests/http/test_proxyagent.py:626: error: Statement is unreachable [unreachable] tests/http/test_proxyagent.py:762: error: Statement is unreachable [unreachable] tests/http/test_proxyagent.py:826: error: Statement is unreachable [unreachable] tests/http/test_proxyagent.py:838: error: Statement is unreachable [unreachable] tests/http/test_proxyagent.py:845: error: Statement is unreachable [unreachable] tests/http/federation/test_matrix_federation_agent.py:151: error: Statement is unreachable [unreachable] tests/http/federation/test_matrix_federation_agent.py:452: error: Statement is unreachable [unreachable] tests/logging/test_remote_handler.py:60: error: Statement is unreachable [unreachable] tests/logging/test_remote_handler.py:93: error: Statement is unreachable [unreachable] tests/logging/test_remote_handler.py:127: error: Statement is unreachable [unreachable] tests/logging/test_remote_handler.py:152: error: Statement is unreachable [unreachable] ``` * Changelog * Tweak DBAPI2 Protocol to be accepted by mypy 1.0 Some extra context in: - https://github.com/matrix-org/python-canonicaljson/pull/57 - https://github.com/python/mypy/issues/6002 - https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected * Pull in updated canonicaljson lib so the protocol check just works * Improve comments in opentracing I tried to workaround the ignores but found it too much trouble. I think the corresponding issue is https://github.com/python/mypy/issues/12909. The mypy repo has a PR claiming to fix this (https://github.com/python/mypy/pull/14677) which might mean this gets resolved soon? * Better annotation for INTERACTIVE_AUTH_CHECKERS * Drive-by AUTH_TYPE annotation, to remove an ignore --- changelog.d/15052.misc | 1 + poetry.lock | 69 ++++++++++---------- synapse/handlers/auth.py | 2 +- synapse/handlers/ui_auth/checkers.py | 18 ++++-- synapse/http/matrixfederationclient.py | 2 +- synapse/logging/opentracing.py | 24 +++++-- synapse/rest/media/v1/_base.py | 9 ++- synapse/storage/engines/__init__.py | 4 +- synapse/storage/types.py | 74 ++++++++++++++++++---- synapse/streams/__init__.py | 7 +- tests/handlers/test_register.py | 4 +- .../federation/test_matrix_federation_agent.py | 11 ++-- tests/http/test_proxyagent.py | 40 ++++++------ tests/logging/test_remote_handler.py | 17 ++--- tests/rest/client/test_auth.py | 3 + tests/rest/client/test_third_party_rules.py | 2 +- tests/utils.py | 26 +++++++- 17 files changed, 209 insertions(+), 104 deletions(-) create mode 100644 changelog.d/15052.misc (limited to 'synapse') diff --git a/changelog.d/15052.misc b/changelog.d/15052.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/15052.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/poetry.lock b/poetry.lock index e534b30d2b..eb1e3d797b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -146,14 +146,14 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] [[package]] name = "canonicaljson" -version = "1.6.4" +version = "1.6.5" description = "Canonical JSON" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "canonicaljson-1.6.4-py3-none-any.whl", hash = "sha256:55d282853b4245dbcd953fe54c39b91571813d7c44e1dbf66e3c4f97ff134a48"}, - {file = "canonicaljson-1.6.4.tar.gz", hash = "sha256:6c09b2119511f30eb1126cfcd973a10824e20f1cfd25039cde3d1218dd9c8d8f"}, + {file = "canonicaljson-1.6.5-py3-none-any.whl", hash = "sha256:806ea6f2cbb7405d20259e1c36dd1214ba5c242fa9165f5bd0bf2081f82c23fb"}, + {file = "canonicaljson-1.6.5.tar.gz", hash = "sha256:68dfc157b011e07d94bf74b5d4ccc01958584ed942d9dfd5fdd706609e81cd4b"}, ] [package.dependencies] @@ -1146,36 +1146,38 @@ files = [ [[package]] name = "mypy" -version = "0.981" +version = "1.0.0" description = "Optional static typing for Python" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, - {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, - {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"}, - {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"}, - {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"}, - {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"}, - {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"}, - {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"}, - {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"}, - {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"}, - {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"}, - {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"}, - {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"}, - {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"}, - {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"}, - {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"}, - {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"}, - {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"}, - {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"}, - {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"}, - {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"}, - {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"}, - {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"}, - {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"}, + {file = "mypy-1.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0626db16705ab9f7fa6c249c017c887baf20738ce7f9129da162bb3075fc1af"}, + {file = "mypy-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1ace23f6bb4aec4604b86c4843276e8fa548d667dbbd0cb83a3ae14b18b2db6c"}, + {file = "mypy-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87edfaf344c9401942883fad030909116aa77b0fa7e6e8e1c5407e14549afe9a"}, + {file = "mypy-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0ab090d9240d6b4e99e1fa998c2d0aa5b29fc0fb06bd30e7ad6183c95fa07593"}, + {file = "mypy-1.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:7cc2c01dfc5a3cbddfa6c13f530ef3b95292f926329929001d45e124342cd6b7"}, + {file = "mypy-1.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14d776869a3e6c89c17eb943100f7868f677703c8a4e00b3803918f86aafbc52"}, + {file = "mypy-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb2782a036d9eb6b5a6efcdda0986774bf798beef86a62da86cb73e2a10b423d"}, + {file = "mypy-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cfca124f0ac6707747544c127880893ad72a656e136adc935c8600740b21ff5"}, + {file = "mypy-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8845125d0b7c57838a10fd8925b0f5f709d0e08568ce587cc862aacce453e3dd"}, + {file = "mypy-1.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b1b9e1ed40544ef486fa8ac022232ccc57109f379611633ede8e71630d07d2"}, + {file = "mypy-1.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c7cf862aef988b5fbaa17764ad1d21b4831436701c7d2b653156a9497d92c83c"}, + {file = "mypy-1.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd187d92b6939617f1168a4fe68f68add749902c010e66fe574c165c742ed88"}, + {file = "mypy-1.0.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4e5175026618c178dfba6188228b845b64131034ab3ba52acaffa8f6c361f805"}, + {file = "mypy-1.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2f6ac8c87e046dc18c7d1d7f6653a66787a4555085b056fe2d599f1f1a2a2d21"}, + {file = "mypy-1.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7306edca1c6f1b5fa0bc9aa645e6ac8393014fa82d0fa180d0ebc990ebe15964"}, + {file = "mypy-1.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3cfad08f16a9c6611e6143485a93de0e1e13f48cfb90bcad7d5fde1c0cec3d36"}, + {file = "mypy-1.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67cced7f15654710386e5c10b96608f1ee3d5c94ca1da5a2aad5889793a824c1"}, + {file = "mypy-1.0.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a86b794e8a56ada65c573183756eac8ac5b8d3d59daf9d5ebd72ecdbb7867a43"}, + {file = "mypy-1.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:50979d5efff8d4135d9db293c6cb2c42260e70fb010cbc697b1311a4d7a39ddb"}, + {file = "mypy-1.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3ae4c7a99e5153496243146a3baf33b9beff714464ca386b5f62daad601d87af"}, + {file = "mypy-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e398652d005a198a7f3c132426b33c6b85d98aa7dc852137a2a3be8890c4072"}, + {file = "mypy-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be78077064d016bc1b639c2cbcc5be945b47b4261a4f4b7d8923f6c69c5c9457"}, + {file = "mypy-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92024447a339400ea00ac228369cd242e988dd775640755fa4ac0c126e49bb74"}, + {file = "mypy-1.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fe523fcbd52c05040c7bee370d66fee8373c5972171e4fbc323153433198592d"}, + {file = "mypy-1.0.0-py3-none-any.whl", hash = "sha256:2efa963bdddb27cb4a0d42545cd137a8d2b883bd181bbc4525b568ef6eca258f"}, + {file = "mypy-1.0.0.tar.gz", hash = "sha256:f34495079c8d9da05b183f9f7daec2878280c2ad7cc81da686ef0b484cea2ecf"}, ] [package.dependencies] @@ -1186,6 +1188,7 @@ typing-extensions = ">=3.10" [package.extras] dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] python2 = ["typed-ast (>=1.4.0,<2)"] reports = ["lxml"] @@ -1203,18 +1206,18 @@ files = [ [[package]] name = "mypy-zope" -version = "0.3.11" +version = "0.9.0" description = "Plugin for mypy to support zope interfaces" category = "dev" optional = false python-versions = "*" files = [ - {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"}, - {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"}, + {file = "mypy-zope-0.9.0.tar.gz", hash = "sha256:88bf6cd056e38b338e6956055958a7805b4ff84404ccd99e29883a3647a1aeb3"}, + {file = "mypy_zope-0.9.0-py3-none-any.whl", hash = "sha256:e1bb4b57084f76ff8a154a3e07880a1af2ac6536c491dad4b143d529f72c5d15"}, ] [package.dependencies] -mypy = "0.981" +mypy = "1.0.0" "zope.interface" = "*" "zope.schema" = "*" @@ -1705,7 +1708,7 @@ files = [ cffi = ">=1.4.1" [package.extras] -docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] +docs = ["sphinx (>=1.6.5)", "sphinx_rtd_theme"] tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] [[package]] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 57a6854b1e..cf12b55d21 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -201,7 +201,7 @@ class AuthHandler: for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): - self.checkers[inst.AUTH_TYPE] = inst # type: ignore + self.checkers[inst.AUTH_TYPE] = inst self.bcrypt_rounds = hs.config.registration.bcrypt_rounds diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 332edcca24..78a75bfed6 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -13,7 +13,8 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type from twisted.web.client import PartialDownloadError @@ -27,19 +28,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class UserInteractiveAuthChecker: +class UserInteractiveAuthChecker(ABC): """Abstract base class for an interactive auth checker""" - def __init__(self, hs: "HomeServer"): + # This should really be an "abstract class property", i.e. it should + # be an error to instantiate a subclass that doesn't specify an AUTH_TYPE. + # But calling this a `ClassVar` is simpler than a decorator stack of + # @property @abstractmethod and @classmethod (if that's even the right order). + AUTH_TYPE: ClassVar[str] + + def __init__(self, hs: "HomeServer"): # noqa: B027 pass + @abstractmethod def is_enabled(self) -> bool: """Check if the configuration of the homeserver allows this checker to work Returns: True if this login type is enabled. """ + raise NotImplementedError() + @abstractmethod async def check_auth(self, authdict: dict, clientip: str) -> Any: """Given the authentication dict from the client, attempt to check this step @@ -304,7 +314,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): ) -INTERACTIVE_AUTH_CHECKERS = [ +INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b92f1d3d1a..312aab4dcc 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -1267,7 +1267,7 @@ class MatrixFederationHttpClient: def _flatten_response_never_received(e: BaseException) -> str: if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined] + _flatten_response_never_received(f.value) for f in e.reasons ) return "%s:[%s]" % (type(e).__name__, reasons) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 6c7cf1b294..5aed71262f 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -188,7 +188,7 @@ from typing import ( ) import attr -from typing_extensions import ParamSpec +from typing_extensions import Concatenate, ParamSpec from twisted.internet import defer from twisted.web.http import Request @@ -445,7 +445,7 @@ def init_tracer(hs: "HomeServer") -> None: opentracing = None # type: ignore[assignment] return - if not opentracing or not JaegerConfig: + if opentracing is None or JaegerConfig is None: raise ConfigError( "The server has been configured to use opentracing but opentracing is not " "installed." @@ -872,7 +872,7 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte def _custom_sync_async_decorator( func: Callable[P, R], - wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]], + wrapping_logic: Callable[Concatenate[Callable[P, R], P], ContextManager[None]], ) -> Callable[P, R]: """ Decorates a function that is sync or async (coroutines), or that returns a Twisted @@ -902,10 +902,14 @@ def _custom_sync_async_decorator( """ if inspect.iscoroutinefunction(func): - + # In this branch, R = Awaitable[RInner], for some other type RInner @wraps(func) - async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + async def _wrapper( + *args: P.args, **kwargs: P.kwargs + ) -> Any: # Return type is RInner with wrapping_logic(func, *args, **kwargs): + # type-ignore: func() returns R, but mypy doesn't know that R is + # Awaitable here. return await func(*args, **kwargs) # type: ignore[misc] else: @@ -972,7 +976,11 @@ def trace_with_opname( if not opentracing: return func - return _custom_sync_async_decorator(func, _wrapping_logic) + # type-ignore: mypy seems to be confused by the ParamSpecs here. + # I think the problem is https://github.com/python/mypy/issues/12909 + return _custom_sync_async_decorator( + func, _wrapping_logic # type: ignore[arg-type] + ) return _decorator @@ -1018,7 +1026,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield - return _custom_sync_async_decorator(func, _wrapping_logic) + # type-ignore: mypy seems to be confused by the ParamSpecs here. + # I think the problem is https://github.com/python/mypy/issues/12909 + return _custom_sync_async_decorator(func, _wrapping_logic) # type: ignore[arg-type] @contextlib.contextmanager diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index d30878f704..6e035afcce 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -16,6 +16,7 @@ import logging import os import urllib +from abc import ABC, abstractmethod from types import TracebackType from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type @@ -284,13 +285,14 @@ async def respond_with_responder( finish_request(request) -class Responder: +class Responder(ABC): """Represents a response that can be streamed to the requester. Responder is a context manager which *must* be used, so that any resources held can be cleaned up. """ + @abstractmethod def write_to_consumer(self, consumer: IConsumer) -> Awaitable: """Stream response into consumer @@ -300,11 +302,12 @@ class Responder: Returns: Resolves once the response has finished being written """ + raise NotImplementedError() - def __enter__(self) -> None: + def __enter__(self) -> None: # noqa: B027 pass - def __exit__( + def __exit__( # noqa: B027 self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index a182e8a098..d1ccb7390a 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -25,7 +25,7 @@ try: except ImportError: class PostgresEngine(BaseDatabaseEngine): # type: ignore[no-redef] - def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc] + def __new__(cls, *args: object, **kwargs: object) -> NoReturn: raise RuntimeError( f"Cannot create {cls.__name__} -- psycopg2 module is not installed" ) @@ -36,7 +36,7 @@ try: except ImportError: class Sqlite3Engine(BaseDatabaseEngine): # type: ignore[no-redef] - def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc] + def __new__(cls, *args: object, **kwargs: object) -> NoReturn: raise RuntimeError( f"Cannot create {cls.__name__} -- sqlite3 module is not installed" ) diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 0031df1e06..56a0048539 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import TracebackType -from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Callable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) from typing_extensions import Protocol @@ -112,15 +123,35 @@ class DBAPI2Module(Protocol): # extends from this hierarchy. See # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions # https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE - Warning: Type[Exception] - Error: Type[Exception] + # + # Note: rather than + # x: T + # we write + # @property + # def x(self) -> T: ... + # which expresses that the protocol attribute `x` is read-only. The mypy docs + # https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected + # explain why this is necessary for safety. TL;DR: we shouldn't be able to write + # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 . + @property + def Warning(self) -> Type[Exception]: + ... + + @property + def Error(self) -> Type[Exception]: + ... # Errors are divided into `InterfaceError`s (something went wrong in the database # driver) and `DatabaseError`s (something went wrong in the database). These are # both subclasses of `Error`, but we can't currently express this in type # annotations due to https://github.com/python/mypy/issues/8397 - InterfaceError: Type[Exception] - DatabaseError: Type[Exception] + @property + def InterfaceError(self) -> Type[Exception]: + ... + + @property + def DatabaseError(self) -> Type[Exception]: + ... # Everything below is a subclass of `DatabaseError`. @@ -128,7 +159,9 @@ class DBAPI2Module(Protocol): # - An integer was too big for its data type. # - An invalid date time was provided. # - A string contained a null code point. - DataError: Type[Exception] + @property + def DataError(self) -> Type[Exception]: + ... # Roughly: something went wrong in the database, but it's not within the application # programmer's control. Examples: @@ -138,28 +171,45 @@ class DBAPI2Module(Protocol): # - A serialisation failure occurred. # - The database ran out of resources, such as storage, memory, connections, etc. # - The database encountered an error from the operating system. - OperationalError: Type[Exception] + @property + def OperationalError(self) -> Type[Exception]: + ... # Roughly: we've given the database data which breaks a rule we asked it to enforce. # Examples: # - Stop, criminal scum! You violated the foreign key constraint # - Also check constraints, non-null constraints, etc. - IntegrityError: Type[Exception] + @property + def IntegrityError(self) -> Type[Exception]: + ... # Roughly: something went wrong within the database server itself. - InternalError: Type[Exception] + @property + def InternalError(self) -> Type[Exception]: + ... # Roughly: the application did something silly that needs to be fixed. Examples: # - We don't have permissions to do something. # - We tried to create a table with duplicate column names. # - We tried to use a reserved name. # - We referred to a column that doesn't exist. - ProgrammingError: Type[Exception] + @property + def ProgrammingError(self) -> Type[Exception]: + ... # Roughly: we've tried to do something that this database doesn't support. - NotSupportedError: Type[Exception] + @property + def NotSupportedError(self) -> Type[Exception]: + ... - def connect(self, **parameters: object) -> Connection: + # We originally wrote + # def connect(self, *args, **kwargs) -> Connection: ... + # But mypy doesn't seem to like that because sqlite3.connect takes a mandatory + # positional argument. We can't make that part of the signature though, because + # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use + # the following slightly unusual workaround. + @property + def connect(self) -> Callable[..., Connection]: ... diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index c6c8a0315c..8a48ffc48d 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from abc import ABC, abstractmethod from typing import Generic, List, Optional, Tuple, TypeVar from synapse.types import StrCollection, UserID @@ -22,7 +22,8 @@ K = TypeVar("K") R = TypeVar("R") -class EventSource(Generic[K, R]): +class EventSource(ABC, Generic[K, R]): + @abstractmethod async def get_new_events( self, user: UserID, @@ -32,4 +33,4 @@ class EventSource(Generic[K, R]): is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[R], K]: - ... + raise NotImplementedError() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 782ef09cf4..1db99b3c00 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -62,7 +62,7 @@ class TestSpamChecker: request_info: Collection[Tuple[str, str]], auth_provider_id: Optional[str], ) -> RegistrationBehaviour: - pass + return RegistrationBehaviour.ALLOW class DenyAll(TestSpamChecker): @@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker: username: Optional[str], request_info: Collection[Tuple[str, str]], ) -> RegistrationBehaviour: - pass + return RegistrationBehaviour.ALLOW class LegacyAllowAll(TestLegacyRegistrationSpamChecker): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index acfdcd3bca..d27422515c 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -63,7 +63,7 @@ from tests.http import ( get_test_ca_cert_file, ) from tests.server import FakeTransport, ThreadedMemoryReactorClock -from tests.utils import default_config +from tests.utils import checked_cast, default_config logger = logging.getLogger(__name__) @@ -146,8 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase): # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(dummy_address) - assert isinstance(client_protocol, _WrappingProtocol) + # NB: we use a checked_cast here to workaround https://github.com/Shoobx/mypy-zope/issues/91) + client_protocol = checked_cast( + _WrappingProtocol, client_factory.buildProtocol(dummy_address) + ) client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -446,7 +448,6 @@ class MatrixFederationAgentTests(unittest.TestCase): server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() ).buildProtocol(dummy_address) - assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport @@ -1529,7 +1530,7 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None: def _wrap_server_factory_for_tls( factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None -) -> IProtocolFactory: +) -> TLSMemoryBIOFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate signed by our test CA, valid for the domains in `sanlist` diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index a817940730..22fdc7f5f2 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -43,6 +43,7 @@ from tests.http import ( ) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase +from tests.utils import checked_cast logger = logging.getLogger(__name__) @@ -620,7 +621,6 @@ class MatrixFederationAgentTests(TestCase): server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() ).buildProtocol(dummy_address) - assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport @@ -757,12 +757,14 @@ class MatrixFederationAgentTests(TestCase): assert isinstance(proxy_server, HTTPChannel) # fish the transports back out so that we can do the old switcheroo - s2c_transport = proxy_server.transport - assert isinstance(s2c_transport, FakeTransport) - client_protocol = s2c_transport.other - assert isinstance(client_protocol, _WrappingProtocol) - c2s_transport = client_protocol.transport - assert isinstance(c2s_transport, FakeTransport) + # To help mypy out with the various Protocols and wrappers and mocks, we do + # some explicit casting. Without the casts, we hit the bug I reported at + # https://github.com/Shoobx/mypy-zope/issues/91 . + # We also double-checked these casts at runtime (test-time) because I found it + # quite confusing to deduce these types in the first place! + s2c_transport = checked_cast(FakeTransport, proxy_server.transport) + client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -822,9 +824,9 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) def test_proxy_with_no_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) - self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") - self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) + proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._hostStr, "proxy.com") + self.assertEqual(proxy_ep._port, 8888) @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) def test_proxy_with_unsupported_scheme(self) -> None: @@ -834,25 +836,21 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) def test_proxy_with_http_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) - self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") - self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) + proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._hostStr, "proxy.com") + self.assertEqual(proxy_ep._port, 8888) @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) def test_proxy_with_https_scheme(self) -> None: https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) - self.assertEqual( - https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com" - ) - self.assertEqual( - https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888 - ) + proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com") + self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888) def _wrap_server_factory_for_tls( factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None -) -> IProtocolFactory: +) -> TLSMemoryBIOFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py index c08954d887..5191e31a8a 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py @@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler from tests.logging import LoggerCleanupMixin from tests.server import FakeTransport, get_clock from tests.unittest import TestCase +from tests.utils import checked_cast def connect_logging_client( @@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): client, server = connect_logging_client(self.reactor, 0) # Trigger data being sent - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # One log message, with a single trailing newline logs = server.data.decode("utf8").splitlines() @@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # Only the 7 infos made it through, the debugs were elided logs = server.data.splitlines() @@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # The 10 warnings made it through, the debugs and infos were elided logs = server.data.splitlines() @@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # The first five and last five warnings made it through, the debugs and # infos were elided diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 208ec44829..f4e1e7de43 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -43,6 +43,9 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): super().__init__(hs) self.recaptcha_attempts: List[Tuple[dict, str]] = [] + def is_enabled(self) -> bool: + return True + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 3325d43a2f..5fa3440691 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -425,7 +425,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): async def test_fn( event: EventBase, state_events: StateMap[EventBase] ) -> Tuple[bool, Optional[JsonDict]]: - if event.is_state and event.type == EventTypes.PowerLevels: + if event.is_state() and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { "room_id": event.room_id, diff --git a/tests/utils.py b/tests/utils.py index 15fabbc2d0..a0ac11bc5c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,7 +15,7 @@ import atexit import os -from typing import Any, Callable, Dict, List, Tuple, Union, overload +from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload import attr from typing_extensions import Literal, ParamSpec @@ -341,3 +341,27 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: context = await unpersisted_context.persist(event) await persistence_store.persist_event(event, context) + + +T = TypeVar("T") + + +def checked_cast(type: Type[T], x: object) -> T: + """A version of typing.cast that is checked at runtime. + + We have our own function for this for two reasons: + + 1. typing.cast itself is deliberately a no-op at runtime, see + https://docs.python.org/3/library/typing.html#typing.cast + 2. To help workaround a mypy-zope bug https://github.com/Shoobx/mypy-zope/issues/91 + where mypy would erroneously consider `isinstance(x, type)` to be false in all + circumstances. + + For this to make sense, `T` needs to be something that `isinstance` can check; see + https://docs.python.org/3/library/functions.html?highlight=isinstance#isinstance + https://docs.python.org/3/glossary.html#term-abstract-base-class + https://docs.python.org/3/library/typing.html#typing.runtime_checkable + for more details. + """ + assert isinstance(x, type) + return x -- cgit 1.5.1 From 4f4f27e57fdab1d7cc6e275b8acabc785952205e Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 17 Feb 2023 09:40:32 +0000 Subject: Mitigate a race where /make_join could 403 for restricted rooms (#15080) Previously, when creating a join event in /make_join, we would decide whether to include additional fields to satisfy restricted room checks based on the current state of the room. Then, when building the event, we would capture the forward extremities of the room to use as prev events. This is subject to race conditions. For example, when leaving and rejoining a room, the following sequence of events leads to a misleading 403 response: 1. /make_join reads the current state of the room and sees that the user is still in the room. It decides to omit the field required for restricted room joins. 2. The leave event is persisted and the room's forward extremities are updated. 3. /make_join builds the event, using the post-leave forward extremities. The event then fails the restricted room checks. To mitigate the race, we move the read of the forward extremities closer to the read of the current state. Ideally, we would compute the state based off the chosen prev events, but that can involve state resolution, which is expensive. Signed-off-by: Sean Quah --- changelog.d/15080.bugfix | 1 + synapse/handlers/federation.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 changelog.d/15080.bugfix (limited to 'synapse') diff --git a/changelog.d/15080.bugfix b/changelog.d/15080.bugfix new file mode 100644 index 0000000000..965d0b921e --- /dev/null +++ b/changelog.d/15080.bugfix @@ -0,0 +1 @@ +Reduce the likelihood of a rare race condition where rejoining a restricted room over federation would fail. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1d0f6bcd6f..5f2057269d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -952,7 +952,20 @@ class FederationHandler: # # Note that this requires the /send_join request to come back to the # same server. + prev_event_ids = None if room_version.msc3083_join_rules: + # Note that the room's state can change out from under us and render our + # nice join rules-conformant event non-conformant by the time we build the + # event. When this happens, our validation at the end fails and we respond + # to the requesting server with a 403, which is misleading — it indicates + # that the user is not allowed to join the room and the joining server + # should not bother retrying via this homeserver or any others, when + # in fact we've just messed up with building the event. + # + # To reduce the likelihood of this race, we capture the forward extremities + # of the room (prev_event_ids) just before fetching the current state, and + # hope that the state we fetch corresponds to the prev events we chose. + prev_event_ids = await self.store.get_prev_events_for_room(room_id) state_ids = await self._state_storage_controller.get_current_state_ids( room_id ) @@ -994,7 +1007,8 @@ class FederationHandler: event, unpersisted_context, ) = await self.event_creation_handler.create_new_client_event( - builder=builder + builder=builder, + prev_event_ids=prev_event_ids, ) except SynapseError as e: logger.warning("Failed to create join to %s because %s", room_id, e) -- cgit 1.5.1 From 61bfcd669ae596a8df940f434e3e2335059100b1 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 17 Feb 2023 14:54:55 +0100 Subject: Add account data to export command (#14969) * Add account data to to export command * newsfile * remove not needed function * update newsfile * adopt #14973 --- changelog.d/14969.feature | 1 + docs/usage/administration/admin_faq.md | 3 +++ synapse/app/admin_cmd.py | 15 ++++++++++- synapse/handlers/admin.py | 49 +++++++++++++++++++++++----------- tests/handlers/test_admin.py | 27 +++++++++++++++++++ 5 files changed, 79 insertions(+), 16 deletions(-) create mode 100644 changelog.d/14969.feature (limited to 'synapse') diff --git a/changelog.d/14969.feature b/changelog.d/14969.feature new file mode 100644 index 0000000000..a4680ef9c8 --- /dev/null +++ b/changelog.d/14969.feature @@ -0,0 +1 @@ +Add account data to the command line [user data export tool](https://matrix-org.github.io/synapse/v1.78/usage/administration/admin_faq.html#how-can-i-export-user-data). \ No newline at end of file diff --git a/docs/usage/administration/admin_faq.md b/docs/usage/administration/admin_faq.md index 7a27741199..925e1d175e 100644 --- a/docs/usage/administration/admin_faq.md +++ b/docs/usage/administration/admin_faq.md @@ -71,6 +71,9 @@ output-directory │ ├───invite_state │ └───knock_state └───user_data + ├───account_data + │ ├───global + │ └─── ├───connections ├───devices └───profile diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index fe7afb9475..ad51f33165 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -17,7 +17,7 @@ import logging import os import sys import tempfile -from typing import List, Optional +from typing import List, Mapping, Optional from twisted.internet import defer, task @@ -222,6 +222,19 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(connection_file, "a") as f: print(json.dumps(connection), file=f) + def write_account_data( + self, file_name: str, account_data: Mapping[str, JsonDict] + ) -> None: + account_data_directory = os.path.join( + self.base_directory, "user_data", "account_data" + ) + os.makedirs(account_data_directory, exist_ok=True) + + account_data_file = os.path.join(account_data_directory, file_name) + + with open(account_data_file, "a") as f: + print(json.dumps(account_data), file=f) + def finished(self) -> str: return self.base_directory diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index b03c214b14..8b7760b2cc 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,7 +14,7 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set from synapse.api.constants import Direction, Membership from synapse.events import EventBase @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main + self._store = hs.get_datastores().main self._device_handler = hs.get_device_handler() self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -38,7 +38,7 @@ class AdminHandler: async def get_whois(self, user: UserID) -> JsonDict: connections = [] - sessions = await self.store.get_user_ip_and_agents(user) + sessions = await self._store.get_user_ip_and_agents(user) for session in sessions: connections.append( { @@ -57,7 +57,7 @@ class AdminHandler: async def get_user(self, user: UserID) -> Optional[JsonDict]: """Function to get user details""" - user_info_dict = await self.store.get_user_by_id(user.to_string()) + user_info_dict = await self._store.get_user_by_id(user.to_string()) if user_info_dict is None: return None @@ -89,11 +89,11 @@ class AdminHandler: } # Add additional user metadata - profile = await self.store.get_profileinfo(user.localpart) - threepids = await self.store.user_get_threepids(user.to_string()) + profile = await self._store.get_profileinfo(user.localpart) + threepids = await self._store.user_get_threepids(user.to_string()) external_ids = [ ({"auth_provider": auth_provider, "external_id": external_id}) - for auth_provider, external_id in await self.store.get_external_ids_by_user( + for auth_provider, external_id in await self._store.get_external_ids_by_user( user.to_string() ) ] @@ -101,7 +101,7 @@ class AdminHandler: user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["threepids"] = threepids user_info_dict["external_ids"] = external_ids - user_info_dict["erased"] = await self.store.is_user_erased(user.to_string()) + user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) return user_info_dict @@ -117,7 +117,7 @@ class AdminHandler: The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = await self.store.get_rooms_for_local_user_where_membership_is( + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, membership_list=( Membership.JOIN, @@ -131,7 +131,7 @@ class AdminHandler: # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle # those separately. - rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) + rooms_user_has_been_in = await self._store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): room_id = room.room_id @@ -140,7 +140,7 @@ class AdminHandler: "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) ) - forgotten = await self.store.did_forget(user_id, room_id) + forgotten = await self._store.did_forget(user_id, room_id) if forgotten: logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) continue @@ -152,14 +152,14 @@ class AdminHandler: if room.membership == Membership.INVITE: event_id = room.event_id - invite = await self.store.get_event(event_id, allow_none=True) + invite = await self._store.get_event(event_id, allow_none=True) if invite: invited_state = invite.unsigned["invite_room_state"] writer.write_invite(room_id, invite, invited_state) if room.membership == Membership.KNOCK: event_id = room.event_id - knock = await self.store.get_event(event_id, allow_none=True) + knock = await self._store.get_event(event_id, allow_none=True) if knock: knock_state = knock.unsigned["knock_room_state"] writer.write_knock(room_id, knock, knock_state) @@ -170,7 +170,7 @@ class AdminHandler: # were joined. We estimate that point by looking at the # stream_ordering of the last membership if it wasn't a join. if room.membership == Membership.JOIN: - stream_ordering = self.store.get_room_max_stream_ordering() + stream_ordering = self._store.get_room_max_stream_ordering() else: stream_ordering = room.stream_ordering @@ -197,7 +197,7 @@ class AdminHandler: # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. while True: - events, _ = await self.store.paginate_room_events( + events, _ = await self._store.paginate_room_events( room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS ) if not events: @@ -263,6 +263,13 @@ class AdminHandler: connections["devices"][""]["sessions"][0]["connections"] ) + # Get all account data the user has global and in rooms + global_data = await self._store.get_global_account_data_for_user(user_id) + by_room_data = await self._store.get_room_account_data_for_user(user_id) + writer.write_account_data("global", global_data) + for room_id in by_room_data: + writer.write_account_data(room_id, by_room_data[room_id]) + return writer.finished() @@ -340,6 +347,18 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): """ raise NotImplementedError() + @abc.abstractmethod + def write_account_data( + self, file_name: str, account_data: Mapping[str, JsonDict] + ) -> None: + """Write the account data of a user. + + Args: + file_name: file name to write data + account_data: mapping of global or room account_data + """ + raise NotImplementedError() + @abc.abstractmethod def finished(self) -> Any: """Called when all data has successfully been exported and written. diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 6f300b8e11..1b97aaeed1 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -296,3 +296,30 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(args[0][0]["user_agent"], "user_agent") self.assertGreater(args[0][0]["last_seen"], 0) self.assertNotIn("access_token", args[0][0]) + + def test_account_data(self) -> None: + """Tests that user account data get exported.""" + # add account data + self.get_success( + self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1}) + ) + self.get_success( + self._store.add_account_data_to_room( + self.user2, "test_room", "m.per_room", {"b": 2} + ) + ) + + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + # two calls, one call for user data and one call for room data + writer.write_account_data.assert_called() + + args = writer.write_account_data.call_args_list[0][0] + self.assertEqual(args[0], "global") + self.assertEqual(args[1]["m.global"]["a"], 1) + + args = writer.write_account_data.call_args_list[1][0] + self.assertEqual(args[0], "test_room") + self.assertEqual(args[1]["m.per_room"]["b"], 2) -- cgit 1.5.1 From 1cbc3f197cc1b9732649ffb769b05d90c0e904d7 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 20 Feb 2023 12:00:18 +0000 Subject: Fix a bug introduced in Synapse v1.74.0 where searching with colons when using ICU for search term tokenisation would fail with an error. (#15079) Co-authored-by: David Robertson --- changelog.d/15079.bugfix | 1 + synapse/storage/databases/main/user_directory.py | 24 +++++++-- tests/handlers/test_user_directory.py | 7 +++ tests/storage/test_user_directory.py | 63 +++++++++++++++++++++++- 4 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 changelog.d/15079.bugfix (limited to 'synapse') diff --git a/changelog.d/15079.bugfix b/changelog.d/15079.bugfix new file mode 100644 index 0000000000..907892c1ef --- /dev/null +++ b/changelog.d/15079.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.74.0 where searching with colons when using ICU for search term tokenisation would fail with an error. \ No newline at end of file diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f6a6fd4079..30af4b3b6c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -918,11 +918,19 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: We use this so that we can add prefix matching, which isn't something that is supported by default. """ - results = _parse_words(search_term) + escaped_words = [] + for word in _parse_words(search_term): + # Postgres tsvector and tsquery quoting rules: + # words potentially containing punctuation should be quoted + # and then existing quotes and backslashes should be doubled + # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY + + quoted_word = word.replace("'", "''").replace("\\", "\\\\") + escaped_words.append(f"'{quoted_word}'") - both = " & ".join("(%s:* | %s)" % (result, result) for result in results) - exact = " & ".join("%s" % (result,) for result in results) - prefix = " & ".join("%s:*" % (result,) for result in results) + both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words) + exact = " & ".join("%s" % (word,) for word in escaped_words) + prefix = " & ".join("%s:*" % (word,) for word in escaped_words) return both, exact, prefix @@ -944,6 +952,14 @@ def _parse_words(search_term: str) -> List[str]: if USE_ICU: return _parse_words_with_icu(search_term) + return _parse_words_with_regex(search_term) + + +def _parse_words_with_regex(search_term: str) -> List[str]: + """ + Break down search term into words, when we don't have ICU available. + See: `_parse_words` + """ return re.findall(r"([\w\-]+)", search_term, re.UNICODE) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index f65a68b9c2..a02c1c6227 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -192,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, self.appservice.sender, tok=self.appservice.token) self._check_only_one_user_in_directory(user, room) + def test_search_term_with_colon_in_it_does_not_raise(self) -> None: + """ + Regression test: Test that search terms with colons in them are acceptable. + """ + u1 = self.register_user("user1", "pass") + self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10)) + def test_user_not_in_users_table(self) -> None: """Unclear how it happens, but on matrix.org we've seen join events for users who aren't in the users table. Test that we don't fall over diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index f1ca523d23..2d169684cf 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -25,6 +25,11 @@ from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.background_updates import _BackgroundUpdateHandler +from synapse.storage.databases.main import user_directory +from synapse.storage.databases.main.user_directory import ( + _parse_words_with_icu, + _parse_words_with_regex, +) from synapse.storage.roommember import ProfileInfo from synapse.util import Clock @@ -42,7 +47,7 @@ ALICE = "@alice:a" BOB = "@bob:b" BOBBY = "@bobby:a" # The localpart isn't 'Bela' on purpose so we can test looking up display names. -BELA = "@somenickname:a" +BELA = "@somenickname:example.org" class GetUserDirectoryTables: @@ -423,6 +428,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): class UserDirectoryStoreTestCase(HomeserverTestCase): + use_icu = False + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main @@ -434,6 +441,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None)) self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))) + self._restore_use_icu = user_directory.USE_ICU + user_directory.USE_ICU = self.use_icu + + def tearDown(self) -> None: + user_directory.USE_ICU = self._restore_use_icu + def test_search_user_dir(self) -> None: # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. @@ -478,6 +491,26 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, ) + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_start_of_user_id(self) -> None: + """Tests that a user can look up another user by searching for the start + of their user ID. + """ + r = self.get_success(self.store.search_user_dir(ALICE, "somenickname:exa", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, + ) + + +class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase): + use_icu = True + + if not icu: + skip = "Requires PyICU" + class UserDirectoryICUTestCase(HomeserverTestCase): if not icu: @@ -513,3 +546,31 @@ class UserDirectoryICUTestCase(HomeserverTestCase): r["results"][0], {"user_id": ALICE, "display_name": display_name, "avatar_url": None}, ) + + def test_icu_word_boundary_punctuation(self) -> None: + """ + Tests the behaviour of punctuation with the ICU tokeniser. + + Seems to depend on underlying version of ICU. + """ + + # Note: either tokenisation is fine, because Postgres actually splits + # words itself afterwards. + self.assertIn( + _parse_words_with_icu("lazy'fox jumped:over the.dog"), + ( + # ICU 66 on Ubuntu 20.04 + ["lazy'fox", "jumped", "over", "the", "dog"], + # ICU 70 on Ubuntu 22.04 + ["lazy'fox", "jumped:over", "the.dog"], + ), + ) + + def test_regex_word_boundary_punctuation(self) -> None: + """ + Tests the behaviour of punctuation with the non-ICU tokeniser + """ + self.assertEqual( + _parse_words_with_regex("lazy'fox jumped:over the.dog"), + ["lazy", "fox", "jumped", "over", "the", "dog"], + ) -- cgit 1.5.1 From 490a3675bd7225b5695e505fea225d7c30127551 Mon Sep 17 00:00:00 2001 From: realtyem Date: Mon, 20 Feb 2023 06:23:00 -0600 Subject: Allow health listener resource to load (#15096) * Allow health listener resource to load. * changelog * Update changelog.d/15096.bugfix --- changelog.d/15096.bugfix | 1 + synapse/config/server.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/15096.bugfix (limited to 'synapse') diff --git a/changelog.d/15096.bugfix b/changelog.d/15096.bugfix new file mode 100644 index 0000000000..09b4d861f8 --- /dev/null +++ b/changelog.d/15096.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.76 where workers would fail to start if the `health` listener was configured. diff --git a/synapse/config/server.py b/synapse/config/server.py index ecdaa2d9dd..d4ef9930b0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -177,6 +177,7 @@ KNOWN_RESOURCES = { "client", "consent", "federation", + "health", "keys", "media", "metrics", -- cgit 1.5.1 From e26d7d5ae786df8d9d9a4dbd0f734e5c2f08aafd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 20 Feb 2023 13:35:24 +0000 Subject: Teach portdb about `un_partial_stated_event_stream` (#15108) * Sort BOOLEAN_COLUMNS and APPEND_ONLY_TABLES So I can see if a given table is present in logarithmic time, rather than linear. * Teach portdb about `un_partial_stated_event_streams` * Comments comments comments * Changelog --- changelog.d/15108.bugfix | 1 + synapse/_scripts/synapse_port_db.py | 85 +++++++++++++++++++++++-------------- 2 files changed, 53 insertions(+), 33 deletions(-) create mode 100644 changelog.d/15108.bugfix (limited to 'synapse') diff --git a/changelog.d/15108.bugfix b/changelog.d/15108.bugfix new file mode 100644 index 0000000000..30af8b439d --- /dev/null +++ b/changelog.d/15108.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.75 where the [portdb script](https://matrix-org.github.io/synapse/release-v1.78/postgres.html#porting-from-sqlite) would fail to run after a room had been faster-joined. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 5e137dbbf7..0d35e0af8f 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -94,61 +94,80 @@ reactor = cast(ISynapseReactor, reactor_) logger = logging.getLogger("synapse_port_db") +# SQLite doesn't have a dedicated boolean type (it stores True/False as 1/0). This means +# portdb will read sqlite bools as integers, then try to insert them into postgres +# boolean columns---which fails. Lacking some Python-parseable metaschema, we must +# specify which integer columns should be inserted as booleans into postgres. BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url"], - "rooms": ["is_public", "has_auth_chain_index"], + "access_tokens": ["used"], + "account_validity": ["email_sent"], + "device_lists_changes_in_room": ["converted_to_destinations"], + "device_lists_outbound_pokes": ["sent"], + "devices": ["hidden"], + "e2e_fallback_keys_json": ["used"], + "e2e_room_keys": ["is_verified"], "event_edges": ["is_state"], + "events": ["processed", "outlier", "contains_url"], + "local_media_repository": ["safe_from_quarantine"], "presence_list": ["accepted"], "presence_stream": ["currently_active"], "public_room_list_stream": ["visibility"], - "devices": ["hidden"], - "device_lists_outbound_pokes": ["sent"], - "users_who_share_rooms": ["share_private"], - "e2e_room_keys": ["is_verified"], - "account_validity": ["email_sent"], + "pushers": ["enabled"], "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], - "local_media_repository": ["safe_from_quarantine"], + "rooms": ["is_public", "has_auth_chain_index"], "users": ["shadow_banned", "approved"], - "e2e_fallback_keys_json": ["used"], - "access_tokens": ["used"], - "device_lists_changes_in_room": ["converted_to_destinations"], - "pushers": ["enabled"], + "un_partial_stated_event_stream": ["rejection_status_changed"], + "users_who_share_rooms": ["share_private"], } +# These tables are never deleted from in normal operation [*], so we can resume porting +# over rows from a previous attempt rather than starting from scratch. +# +# [*]: We do delete from many of these tables when purging a room, and +# presumably when purging old events. So we might e.g. +# +# 1. Run portdb and port half of some table. +# 2. Stop portdb. +# 3. Purge something, deleting some of the rows we've ported over. +# 4. Restart portdb. The rows deleted from sqlite are still present in postgres. +# +# But this isn't the end of the world: we should be able to repeat the purge +# on the postgres DB when porting completes. APPEND_ONLY_TABLES = [ + "cache_invalidation_stream_by_instance", + "event_auth", + "event_edges", + "event_json", "event_reference_hashes", + "event_search", + "event_to_state_groups", "events", - "event_json", - "state_events", - "room_memberships", - "topics", - "room_names", - "rooms", + "ex_outlier_stream", "local_media_repository", "local_media_repository_thumbnails", + "presence_stream", + "public_room_list_stream", + "push_rules_stream", + "received_transactions", + "redactions", + "rejections", "remote_media_cache", "remote_media_cache_thumbnails", - "redactions", - "event_edges", - "event_auth", - "received_transactions", + "room_memberships", + "room_names", + "rooms", "sent_transactions", - "transaction_id_to_pdu", - "users", + "state_events", + "state_group_edges", "state_groups", "state_groups_state", - "event_to_state_groups", - "rejections", - "event_search", - "presence_stream", - "push_rules_stream", - "ex_outlier_stream", - "cache_invalidation_stream_by_instance", - "public_room_list_stream", - "state_group_edges", "stream_ordering_to_exterm", + "topics", + "transaction_id_to_pdu", + "un_partial_stated_event_stream", + "users", ] -- cgit 1.5.1 From addd12f16dc35a4f82cb48807719909e7aed9dcb Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 21 Feb 2023 12:26:00 +0000 Subject: Tweak logging for when a worker waits for its view of a replication stream to catch up. (#15120)Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Improve logging messages for the 'wait for repl stream' read-after-write consistency feature * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) * Update synapse/replication/tcp/client.py Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --------- Signed-off-by: Olivier Wilkinson (reivilibre) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/15120.misc | 1 + synapse/replication/tcp/client.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 changelog.d/15120.misc (limited to 'synapse') diff --git a/changelog.d/15120.misc b/changelog.d/15120.misc new file mode 100644 index 0000000000..ebbc0c9027 --- /dev/null +++ b/changelog.d/15120.misc @@ -0,0 +1 @@ +Tweak logging for when a worker waits for its view of a replication stream to catch up. \ No newline at end of file diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index cc0528bd8e..424854efbe 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -370,15 +370,23 @@ class ReplicationDataHandler: # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): logger.info( - "Waiting for repl stream %r to reach %s (%s)", + "Waiting for repl stream %r to reach %s (%s); currently at: %s", stream_name, position, instance_name, + current_position, ) try: await make_deferred_yieldable(deferred) except defer.TimeoutError: - logger.error("Timed out waiting for stream %s", stream_name) + logger.error( + "Timed out waiting for repl stream %r to reach %s (%s)" + "; currently at: %s", + stream_name, + position, + instance_name, + self._streams[stream_name].current_token(instance_name), + ) return logger.info( -- cgit 1.5.1 From 647ff3ef65e7a54b2719755802b4e6f2f45f5eb6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 22 Feb 2023 11:07:28 +0000 Subject: Remove unused `room_alias` field from `/createRoom` response (#15093) * Change `create_room` return type * Don't return room alias from /createRoom * Update other callsites * Fix up mypy complaints It looks like new_room_user_id is None iff new_room_id is None. It's a shame we haven't expressed this in a way that mypy can understand. * Changelog --- changelog.d/15093.bugfix | 1 + synapse/handlers/register.py | 4 +-- synapse/handlers/room.py | 38 ++++++++++++------------ synapse/module_api/__init__.py | 6 ++-- synapse/rest/client/room.py | 4 +-- synapse/server_notices/server_notices_manager.py | 3 +- tests/storage/test_cleanup_extrems.py | 8 ++--- tests/storage/test_event_metrics.py | 3 +- tests/storage/test_receipts.py | 10 ++++--- tests/test_federation.py | 2 +- 10 files changed, 40 insertions(+), 39 deletions(-) create mode 100644 changelog.d/15093.bugfix (limited to 'synapse') diff --git a/changelog.d/15093.bugfix b/changelog.d/15093.bugfix new file mode 100644 index 0000000000..00f1c19391 --- /dev/null +++ b/changelog.d/15093.bugfix @@ -0,0 +1 @@ +Remove the unspecced `room_alias` field from the [`/createRoom`](https://spec.matrix.org/v1.6/client-server-api/#post_matrixclientv3createroom) response. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c611efb760..e4e506e62c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -476,7 +476,7 @@ class RegistrationHandler: # create room expects the localpart of the room alias config["room_alias_name"] = room_alias.localpart - info, _ = await room_creation_handler.create_room( + room_id, _, _ = await room_creation_handler.create_room( fake_requester, config=config, ratelimit=False, @@ -490,7 +490,7 @@ class RegistrationHandler: user_id, authenticated_entity=self._server_name ), target=UserID.from_string(user_id), - room_id=info["room_id"], + room_id=room_id, # Since it was just created, there are no remote hosts. remote_room_hosts=[], action="join", diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 837dabb3b7..37c87c8351 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -690,13 +690,14 @@ class RoomCreationHandler: config: JsonDict, ratelimit: bool = True, creator_join_profile: Optional[JsonDict] = None, - ) -> Tuple[dict, int]: + ) -> Tuple[str, Optional[RoomAlias], int]: """Creates a new room. Args: - requester: - The user who requested the room creation. - config : A dict of configuration options. + requester: The user who requested the room creation. + config: A dict of configuration options. This will be the body of + a /createRoom request; see + https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom ratelimit: set to False to disable the rate limiter creator_join_profile: @@ -707,14 +708,17 @@ class RoomCreationHandler: `avatar_url` and/or `displayname`. Returns: - First, a dict containing the keys `room_id` and, if an alias - was, requested, `room_alias`. Secondly, the stream_id of the - last persisted event. + A 3-tuple containing: + - the room ID; + - if requested, the room alias, otherwise None; and + - the `stream_id` of the last persisted event. Raises: - SynapseError if the room ID couldn't be stored, 3pid invitation config - validation failed, or something went horribly wrong. - ResourceLimitError if server is blocked to some resource being - exceeded + SynapseError: + if the room ID couldn't be stored, 3pid invitation config + validation failed, or something went horribly wrong. + ResourceLimitError: + if server is blocked to some resource being + exceeded """ user_id = requester.user.to_string() @@ -1024,11 +1028,6 @@ class RoomCreationHandler: last_sent_event_id = member_event_id depth += 1 - result = {"room_id": room_id} - - if room_alias: - result["room_alias"] = room_alias.to_string() - # Always wait for room creation to propagate before returning await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), @@ -1036,7 +1035,7 @@ class RoomCreationHandler: last_stream_id, ) - return result, last_stream_id + return room_id, room_alias, last_stream_id async def _send_events_for_new_room( self, @@ -1825,7 +1824,7 @@ class RoomShutdownHandler: new_room_user_id, authenticated_entity=requester_user_id ) - info, stream_id = await self._room_creation_handler.create_room( + new_room_id, _, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": RoomCreationPreset.PUBLIC_CHAT, @@ -1834,7 +1833,6 @@ class RoomShutdownHandler: }, ratelimit=False, ) - new_room_id = info["room_id"] logger.info( "Shutting down room %r, joining to new room: %r", room_id, new_room_id @@ -1887,6 +1885,7 @@ class RoomShutdownHandler: # Join users to new room if new_room_user_id: + assert new_room_id is not None await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, @@ -1919,6 +1918,7 @@ class RoomShutdownHandler: aliases_for_room = await self.store.get_aliases_for_room(room_id) + assert new_room_id is not None await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d22dd19d38..1964276a54 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1576,14 +1576,14 @@ class ModuleApi: ) requester = create_requester(user_id) - room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room( + room_id, room_alias, _ = await self._hs.get_room_creation_handler().create_room( requester=requester, config=config, ratelimit=ratelimit, creator_join_profile=creator_join_profile, ) - - return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None) + room_alias_str = room_alias.to_string() if room_alias else None + return room_id, room_alias_str async def set_displayname( self, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index d0db85cca7..14b04810a1 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -160,11 +160,11 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - info, _ = await self._room_creation_handler.create_room( + room_id, _, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) - return 200, info + return 200, {"room_id": room_id} def get_room_config(self, request: Request) -> JsonDict: user_supplied_config = parse_json_object_from_request(request) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 564e3705c2..9732dbdb6e 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -178,7 +178,7 @@ class ServerNoticesManager: "avatar_url": self._config.servernotices.server_notices_mxid_avatar_url, } - info, _ = await self._room_creation_handler.create_room( + room_id, _, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, @@ -188,7 +188,6 @@ class ServerNoticesManager: ratelimit=False, creator_join_profile=join_profile, ) - room_id = info["room_id"] self.maybe_get_notice_room_for_user.invalidate((user_id,)) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index d570684c99..7de109966d 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -43,8 +43,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = create_requester(self.user) - info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) - self.room_id = info["room_id"] + self.room_id, _, _ = self.get_success( + self.room_creator.create_room(self.requester, {}) + ) def run_background_update(self) -> None: """Re run the background update to clean up the extremities.""" @@ -275,10 +276,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = create_requester(self.user) - info, _ = self.get_success( + self.room_id, _, _ = self.get_success( self.room_creator.create_room(self.requester, {"visibility": "public"}) ) - self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.consent.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a91411168c..6897addbd3 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,8 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info, _ = self.get_success(room_creator.create_room(requester, {})) - room_id = info["room_id"] + room_id, _, _ = self.get_success(room_creator.create_room(requester, {})) last_event = None diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 12c17f1073..1b52eef23f 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -50,12 +50,14 @@ class ReceiptTestCase(HomeserverTestCase): self.otherRequester = create_requester(self.otherUser) # Create a test room - info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) - self.room_id1 = info["room_id"] + self.room_id1, _, _ = self.get_success( + self.room_creator.create_room(self.ourRequester, {}) + ) # Create a second test room - info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) - self.room_id2 = info["room_id"] + self.room_id2, _, _ = self.get_success( + self.room_creator.create_room(self.ourRequester, {}) + ) # Join the second user to the first room memberEvent, memberEventContext = self.get_success( diff --git a/tests/test_federation.py b/tests/test_federation.py index 82dfd88b99..46d2f99eac 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -47,7 +47,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) - )[0]["room_id"] + )[0] self.store = self.hs.get_datastores().main -- cgit 1.5.1 From 6def779a1a7c49cd10e635986fbfa1e422eb20bf Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 22 Feb 2023 20:29:39 +0100 Subject: Use `json.dump` in `FileExfiltrationWriter` (#15095) To directly write to the open file, instead of writing to an in-memory string first. --- changelog.d/15095.misc | 1 + synapse/app/admin_cmd.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) create mode 100644 changelog.d/15095.misc (limited to 'synapse') diff --git a/changelog.d/15095.misc b/changelog.d/15095.misc new file mode 100644 index 0000000000..a2fafe2fff --- /dev/null +++ b/changelog.d/15095.misc @@ -0,0 +1 @@ +Refactor writing json data in `FileExfiltrationWriter`. \ No newline at end of file diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index ad51f33165..5003777f0d 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -149,7 +149,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(events_file, "a") as f: for event in events: - print(json.dumps(event.get_pdu_json()), file=f) + json.dump(event.get_pdu_json(), fp=f) def write_state( self, room_id: str, event_id: str, state: StateMap[EventBase] @@ -162,7 +162,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(event_file, "a") as f: for event in state.values(): - print(json.dumps(event.get_pdu_json()), file=f) + json.dump(event.get_pdu_json(), fp=f) def write_invite( self, room_id: str, event: EventBase, state: StateMap[EventBase] @@ -178,7 +178,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(invite_state, "a") as f: for event in state.values(): - print(json.dumps(event), file=f) + json.dump(event, fp=f) def write_knock( self, room_id: str, event: EventBase, state: StateMap[EventBase] @@ -194,7 +194,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(knock_state, "a") as f: for event in state.values(): - print(json.dumps(event), file=f) + json.dump(event, fp=f) def write_profile(self, profile: JsonDict) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -202,7 +202,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): profile_file = os.path.join(user_directory, "profile") with open(profile_file, "a") as f: - print(json.dumps(profile), file=f) + json.dump(profile, fp=f) def write_devices(self, devices: List[JsonDict]) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -211,7 +211,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): for device in devices: with open(device_file, "a") as f: - print(json.dumps(device), file=f) + json.dump(device, fp=f) def write_connections(self, connections: List[JsonDict]) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -220,7 +220,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): for connection in connections: with open(connection_file, "a") as f: - print(json.dumps(connection), file=f) + json.dump(connection, fp=f) def write_account_data( self, file_name: str, account_data: Mapping[str, JsonDict] @@ -233,7 +233,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): account_data_file = os.path.join(account_data_directory, file_name) with open(account_data_file, "a") as f: - print(json.dumps(account_data), file=f) + json.dump(account_data, fp=f) def finished(self) -> str: return self.base_directory -- cgit 1.5.1 From 4ed08ff72ef8f1abf85ab22de1e51b570f67b27e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 22 Feb 2023 14:37:18 -0500 Subject: Tighten the default rate limit of creating new devices. (#15135) --- changelog.d/15135.misc | 1 + docs/usage/configuration/config_documentation.md | 6 +++--- synapse/config/ratelimiting.py | 13 +++++++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) create mode 100644 changelog.d/15135.misc (limited to 'synapse') diff --git a/changelog.d/15135.misc b/changelog.d/15135.misc new file mode 100644 index 0000000000..25c4dbffe1 --- /dev/null +++ b/changelog.d/15135.misc @@ -0,0 +1 @@ +Tighten the login ratelimit defaults. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 58c6955689..ab1f9f4963 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1518,11 +1518,11 @@ rc_registration_token_validity: This option specifies several limits for login: * `address` ratelimits login requests based on the client's IP - address. Defaults to `per_second: 0.17`, `burst_count: 3`. + address. Defaults to `per_second: 0.003`, `burst_count: 5`. * `account` ratelimits login requests based on the account the - client is attempting to log into. Defaults to `per_second: 0.17`, - `burst_count: 3`. + client is attempting to log into. Defaults to `per_second: 0.03`, + `burst_count: 5`. * `failed_attempts` ratelimits login requests based on the account the client is attempting to log into, based on the amount of failed login diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 5c13fe428a..b733fac617 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -87,9 +87,18 @@ class RatelimitConfig(Config): defaults={"per_second": 0.1, "burst_count": 5}, ) + # It is reasonable to login with a bunch of devices at once (i.e. when + # setting up an account), but it is *not* valid to continually be + # logging into new devices. rc_login_config = config.get("rc_login", {}) - self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {})) - self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {})) + self.rc_login_address = RatelimitSettings( + rc_login_config.get("address", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + self.rc_login_account = RatelimitSettings( + rc_login_config.get("account", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) self.rc_login_failed_attempts = RatelimitSettings( rc_login_config.get("failed_attempts", {}) ) -- cgit 1.5.1 From 9bb2eac71962970d02842bca441f4bcdbbf93a11 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 Feb 2023 15:29:09 -0500 Subject: Bump black from 22.12.0 to 23.1.0 (#15103) --- changelog.d/15103.misc | 1 + poetry.lock | 42 ++++++++++++++-------- stubs/sortedcontainers/sortedlist.pyi | 1 - synapse/_scripts/register_new_matrix_user.py | 2 -- synapse/_scripts/synapse_port_db.py | 1 - synapse/_scripts/synctl.py | 1 - synapse/app/_base.py | 2 +- synapse/app/complement_fork_starter.py | 2 +- synapse/app/generic_worker.py | 1 - synapse/app/homeserver.py | 1 - synapse/config/consent.py | 1 - synapse/config/database.py | 1 - synapse/config/homeserver.py | 1 - synapse/config/ratelimiting.py | 1 - synapse/config/repository.py | 1 - synapse/config/server.py | 1 - synapse/config/tls.py | 1 - synapse/crypto/keyring.py | 2 +- synapse/events/third_party_rules.py | 2 -- synapse/federation/send_queue.py | 4 +-- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 2 -- synapse/handlers/directory.py | 8 +++-- synapse/handlers/e2e_room_keys.py | 1 - synapse/handlers/event_auth.py | 1 - synapse/handlers/initial_sync.py | 1 - synapse/handlers/presence.py | 2 -- synapse/handlers/room.py | 8 +++-- synapse/handlers/room_batch.py | 2 +- synapse/handlers/sync.py | 1 - synapse/logging/opentracing.py | 1 + synapse/metrics/__init__.py | 1 - synapse/metrics/_gc.py | 1 - synapse/push/bulk_push_rule_evaluator.py | 1 - synapse/replication/http/account_data.py | 1 - synapse/replication/http/devices.py | 1 - synapse/replication/tcp/redis.py | 1 - synapse/replication/tcp/streams/events.py | 1 - synapse/rest/admin/rooms.py | 4 --- synapse/rest/admin/users.py | 8 +++-- synapse/rest/client/auth.py | 1 - synapse/rest/client/filter.py | 1 - synapse/rest/client/register.py | 18 ++++++---- synapse/rest/media/v1/_base.py | 1 - synapse/rest/media/v1/thumbnailer.py | 1 - synapse/storage/databases/main/deviceinbox.py | 5 ++- synapse/storage/databases/main/devices.py | 4 +-- synapse/storage/databases/main/e2e_room_keys.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 8 ++--- synapse/storage/databases/main/event_federation.py | 1 - synapse/storage/databases/main/events.py | 1 - .../storage/databases/main/events_bg_updates.py | 4 +-- synapse/storage/databases/main/events_worker.py | 2 +- synapse/storage/databases/main/media_repository.py | 1 - synapse/storage/databases/main/pusher.py | 3 -- synapse/storage/databases/main/receipts.py | 1 - synapse/storage/databases/main/room.py | 1 - synapse/storage/databases/main/search.py | 2 -- synapse/storage/databases/main/state.py | 1 - synapse/storage/databases/main/stats.py | 2 +- synapse/storage/databases/main/stream.py | 1 + synapse/storage/databases/main/transactions.py | 1 - synapse/storage/databases/main/user_directory.py | 1 - synapse/storage/databases/state/bg_updates.py | 1 - synapse/storage/databases/state/store.py | 7 ++-- synapse/storage/prepare_database.py | 4 +-- synapse/types/state.py | 2 +- synapse/util/caches/__init__.py | 1 - synapse/util/check_dependencies.py | 2 +- synapse/util/patch_inline_callbacks.py | 1 - synmark/__main__.py | 2 -- synmark/suites/logging.py | 1 - tests/federation/test_complexity.py | 4 --- tests/federation/test_federation_server.py | 1 - tests/handlers/test_sso.py | 1 - tests/handlers/test_stats.py | 1 - tests/http/federation/test_srv_resolver.py | 1 - tests/http/test_client.py | 2 +- tests/push/test_bulk_push_rule_evaluator.py | 1 - tests/push/test_email.py | 2 -- tests/replication/slave/storage/test_events.py | 1 - tests/rest/admin/test_device.py | 3 -- tests/rest/admin/test_media.py | 5 --- tests/rest/admin/test_room.py | 1 - tests/rest/admin/test_server_notice.py | 1 - tests/rest/client/test_account.py | 4 --- tests/rest/client/test_auth.py | 2 -- tests/rest/client/test_capabilities.py | 1 - tests/rest/client/test_consent.py | 1 - tests/rest/client/test_directory.py | 1 - tests/rest/client/test_ephemeral_message.py | 1 - tests/rest/client/test_events.py | 3 -- tests/rest/client/test_filter.py | 1 - tests/rest/client/test_login.py | 2 -- tests/rest/client/test_login_token_request.py | 1 - tests/rest/client/test_presence.py | 1 - tests/rest/client/test_profile.py | 3 -- tests/rest/client/test_register.py | 4 --- tests/rest/client/test_rendezvous.py | 1 - tests/rest/client/test_rooms.py | 14 ++------ tests/rest/client/test_sync.py | 3 -- tests/rest/client/test_third_party_rules.py | 3 ++ tests/rest/media/test_media_retention.py | 1 - tests/rest/media/v1/test_media_storage.py | 3 -- tests/rest/media/v1/test_url_preview.py | 3 -- tests/server_notices/test_consent.py | 2 -- tests/storage/databases/main/test_deviceinbox.py | 1 - tests/storage/databases/main/test_receipts.py | 2 +- tests/storage/databases/main/test_room.py | 1 - tests/storage/test_client_ips.py | 1 - tests/storage/test_event_chain.py | 2 -- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_purge.py | 1 - tests/storage/test_roommember.py | 3 -- tests/storage/test_state.py | 30 ++++++++-------- tests/test_mau.py | 1 - 117 files changed, 108 insertions(+), 218 deletions(-) create mode 100644 changelog.d/15103.misc (limited to 'synapse') diff --git a/changelog.d/15103.misc b/changelog.d/15103.misc new file mode 100644 index 0000000000..65322498c9 --- /dev/null +++ b/changelog.d/15103.misc @@ -0,0 +1 @@ +Bump black from 22.12.0 to 23.1.0. diff --git a/poetry.lock b/poetry.lock index 4d724ab782..8ffdab7a22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -90,32 +90,46 @@ typecheck = ["mypy"] [[package]] name = "black" -version = "22.12.0" +version = "23.1.0" description = "The uncompromising code formatter." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, - {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, - {file = "black-22.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d30b212bffeb1e252b31dd269dfae69dd17e06d92b87ad26e23890f3efea366f"}, - {file = "black-22.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:7412e75863aa5c5411886804678b7d083c7c28421210180d67dfd8cf1221e1f4"}, - {file = "black-22.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c116eed0efb9ff870ded8b62fe9f28dd61ef6e9ddd28d83d7d264a38417dcee2"}, - {file = "black-22.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1f58cbe16dfe8c12b7434e50ff889fa479072096d79f0a7f25e4ab8e94cd8350"}, - {file = "black-22.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77d86c9f3db9b1bf6761244bc0b3572a546f5fe37917a044e02f3166d5aafa7d"}, - {file = "black-22.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:82d9fe8fee3401e02e79767016b4907820a7dc28d70d137eb397b92ef3cc5bfc"}, - {file = "black-22.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101c69b23df9b44247bd88e1d7e90154336ac4992502d4197bdac35dd7ee3320"}, - {file = "black-22.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:559c7a1ba9a006226f09e4916060982fd27334ae1998e7a38b3f33a37f7a2148"}, - {file = "black-22.12.0-py3-none-any.whl", hash = "sha256:436cc9167dd28040ad90d3b404aec22cedf24a6e4d7de221bec2730ec0c97bcf"}, - {file = "black-22.12.0.tar.gz", hash = "sha256:229351e5a18ca30f447bf724d007f890f97e13af070bb6ad4c0a441cd7596a2f"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:b6a92a41ee34b883b359998f0c8e6eb8e99803aa8bf3123bf2b2e6fec505a221"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:57c18c5165c1dbe291d5306e53fb3988122890e57bd9b3dcb75f967f13411a26"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:9880d7d419bb7e709b37e28deb5e68a49227713b623c72b2b931028ea65f619b"}, + {file = "black-23.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6663f91b6feca5d06f2ccd49a10f254f9298cc1f7f49c46e498a0771b507104"}, + {file = "black-23.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9afd3f493666a0cd8f8df9a0200c6359ac53940cbde049dcb1a7eb6ee2dd7074"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:bfffba28dc52a58f04492181392ee380e95262af14ee01d4bc7bb1b1c6ca8d27"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c1c476bc7b7d021321e7d93dc2cbd78ce103b84d5a4cf97ed535fbc0d6660648"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:382998821f58e5c8238d3166c492139573325287820963d2f7de4d518bd76958"}, + {file = "black-23.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bf649fda611c8550ca9d7592b69f0637218c2369b7744694c5e4902873b2f3a"}, + {file = "black-23.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:121ca7f10b4a01fd99951234abdbd97728e1240be89fde18480ffac16503d481"}, + {file = "black-23.1.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:a8471939da5e824b891b25751955be52ee7f8a30a916d570a5ba8e0f2eb2ecad"}, + {file = "black-23.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8178318cb74f98bc571eef19068f6ab5613b3e59d4f47771582f04e175570ed8"}, + {file = "black-23.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a436e7881d33acaf2536c46a454bb964a50eff59b21b51c6ccf5a40601fbef24"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:a59db0a2094d2259c554676403fa2fac3473ccf1354c1c63eccf7ae65aac8ab6"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:0052dba51dec07ed029ed61b18183942043e00008ec65d5028814afaab9a22fd"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:49f7b39e30f326a34b5c9a4213213a6b221d7ae9d58ec70df1c4a307cf2a1580"}, + {file = "black-23.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:162e37d49e93bd6eb6f1afc3e17a3d23a823042530c37c3c42eeeaf026f38468"}, + {file = "black-23.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b70eb40a78dfac24842458476135f9b99ab952dd3f2dab738c1881a9b38b753"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:a29650759a6a0944e7cca036674655c2f0f63806ddecc45ed40b7b8aa314b651"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:bb460c8561c8c1bec7824ecbc3ce085eb50005883a6203dcfb0122e95797ee06"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c91dfc2c2a4e50df0026f88d2215e166616e0c80e86004d0003ece0488db2739"}, + {file = "black-23.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a951cc83ab535d248c89f300eccbd625e80ab880fbcfb5ac8afb5f01a258ac9"}, + {file = "black-23.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:0680d4380db3719ebcfb2613f34e86c8e6d15ffeabcf8ec59355c5e7b85bb555"}, + {file = "black-23.1.0-py3-none-any.whl", hash = "sha256:7a0f701d314cfa0896b9001df70a530eb2472babb76086344e688829efd97d32"}, + {file = "black-23.1.0.tar.gz", hash = "sha256:b0bd97bea8903f5a2ba7219257a44e3f1f9d00073d6cc1add68f0beec69692ac"}, ] [package.dependencies] click = ">=8.0.0" mypy-extensions = ">=0.4.3" +packaging = ">=22.0" pathspec = ">=0.9.0" platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""} typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi index 1fe1a136f1..0e745c0a79 100644 --- a/stubs/sortedcontainers/sortedlist.pyi +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -29,7 +29,6 @@ _Repr = Callable[[], str] def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ... class SortedList(MutableSequence[_T]): - DEFAULT_LOAD_FACTOR: int = ... def __init__( self, diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 2b74a40166..19ca399d44 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -47,7 +47,6 @@ def request_registration( _print: Callable[[str], None] = print, exit: Callable[[int], None] = sys.exit, ) -> None: - url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) # Get the nonce @@ -154,7 +153,6 @@ def register_new_user( def main() -> None: - logging.captureWarnings(True) parser = argparse.ArgumentParser( diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 0d35e0af8f..2c9cbf8b27 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -1205,7 +1205,6 @@ class CursesProgress(Progress): if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) else: - if self.total_processed > 0: left = float(self.total_remaining) / self.total_processed diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py index b4c96ad7f3..077b90935e 100755 --- a/synapse/_scripts/synctl.py +++ b/synapse/_scripts/synctl.py @@ -167,7 +167,6 @@ Worker = collections.namedtuple( def main() -> None: - parser = argparse.ArgumentParser() parser.add_argument( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index a5aa2185a2..28062dd69d 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -213,7 +213,7 @@ def handle_startup_exception(e: Exception) -> NoReturn: def redirect_stdio_to_logs() -> None: streams = [("stdout", LogLevel.info), ("stderr", LogLevel.error)] - for (stream, level) in streams: + for stream, level in streams: oldStream = getattr(sys, stream) loggingFile = LoggingFile( logger=twisted.logger.Logger(namespace=stream), diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index 920538f44d..c8dc3f9d76 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -219,7 +219,7 @@ def main() -> None: # memory space and don't need to repeat the work of loading the code! # Instead of using fork() directly, we use the multiprocessing library, # which uses fork() on Unix platforms. - for (func, worker_args) in zip(worker_functions, args_by_worker): + for func, worker_args in zip(worker_functions, args_by_worker): process = multiprocessing.Process( target=_worker_entrypoint, args=(func, proxy_reactor, worker_args) ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 946f3a3807..0dec24369a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -157,7 +157,6 @@ class GenericWorkerServer(HomeServer): DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore def _listen_http(self, listener_config: ListenerConfig) -> None: - assert listener_config.http_options is not None # We always include a health resource. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6176a70eb2..b8830b1a9c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -321,7 +321,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer: and not config.registration.registrations_require_3pid and not config.registration.registration_requires_token ): - raise ConfigError( "You have enabled open registration without any verification. This is a known vector for " "spam and abuse. If you would like to allow public registration, please consider adding email, " diff --git a/synapse/config/consent.py b/synapse/config/consent.py index be74609dc4..5bfd0cbb71 100644 --- a/synapse/config/consent.py +++ b/synapse/config/consent.py @@ -22,7 +22,6 @@ from ._base import Config class ConsentConfig(Config): - section = "consent" def __init__(self, *args: Any): diff --git a/synapse/config/database.py b/synapse/config/database.py index 928fec8dfe..596d8769fe 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -154,7 +154,6 @@ class DatabaseConfig(Config): logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) def set_databasepath(self, database_path: str) -> None: - if database_path != ":memory:": database_path = self.abspath(database_path) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 4d2b298a70..c205a78039 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -56,7 +56,6 @@ from .workers import WorkerConfig class HomeServerConfig(RootConfig): - config_classes = [ ModulesConfig, ServerConfig, diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index b733fac617..a5514e70a2 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -46,7 +46,6 @@ class RatelimitConfig(Config): section = "ratelimiting" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: diff --git a/synapse/config/repository.py b/synapse/config/repository.py index e4759711ed..2da40c09f0 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -116,7 +116,6 @@ class ContentRepositoryConfig(Config): section = "media" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( diff --git a/synapse/config/server.py b/synapse/config/server.py index d4ef9930b0..0e46b849cf 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -735,7 +735,6 @@ class ServerConfig(Config): listeners: Optional[List[dict]], **kwargs: Any, ) -> str: - _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: unsecure_port = bind_port - 400 diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 336fe3e0da..318270ebb8 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -30,7 +30,6 @@ class TlsConfig(Config): section = "tls" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.tls_certificate_file = self.abspath(config.get("tls_certificate_path")) self.tls_private_key_file = self.abspath(config.get("tls_private_key_path")) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 86cd4af9bd..d710607c63 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -399,7 +399,7 @@ class Keyring: # We now convert the returned list of results into a map from server # name to key ID to FetchKeyResult, to return. to_return: Dict[str, Dict[str, FetchKeyResult]] = {} - for (request, results) in zip(deduped_requests, results_per_request): + for request, results in zip(deduped_requests, results_per_request): to_return_by_server = to_return.setdefault(request.server_name, {}) for key_id, key_result in results.items(): existing = to_return_by_server.get(key_id) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 97c61cc258..9a25ed419b 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -78,7 +78,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: # correctly, we need to await its result. Therefore it doesn't make a lot of # sense to make it go through the run() wrapper. if f.__name__ == "check_event_allowed": - # We need to wrap check_event_allowed because its old form would return either # a boolean or a dict, but now we want to return the dict separately from the # boolean. @@ -100,7 +99,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: return wrap_check_event_allowed if f.__name__ == "on_create_room": - # We need to wrap on_create_room because its old form would return a boolean # if the room creation is denied, but now we just want it to raise an # exception. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index d720b5fd3f..3063df7990 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -314,7 +314,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): # stream position. keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} - for ((destination, edu_key), pos) in keyed_edus.items(): + for (destination, edu_key), pos in keyed_edus.items(): rows.append( ( pos, @@ -329,7 +329,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): j = self.edus.bisect_right(to_token) + 1 edus = self.edus.items()[i:j] - for (pos, edu) in edus: + for pos, edu in edus: rows.append((pos, EduRow(edu))) # Sort rows based on pos diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5d1d21cdc8..ec3ab968e9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -737,7 +737,7 @@ class ApplicationServicesHandler: ) ret = [] - for (success, result) in results: + for success, result in results: if success: ret.extend(result) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index cf12b55d21..b12bc4c9a3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -815,7 +815,6 @@ class AuthHandler: now_ms = self._clock.time_msec() if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms: - raise SynapseError( HTTPStatus.FORBIDDEN, "The supplied refresh token has expired", @@ -2259,7 +2258,6 @@ class PasswordAuthProvider: async def on_logged_out( self, user_id: str, device_id: Optional[str], access_token: str ) -> None: - # call all of the on_logged_out callbacks for callback in self.on_logged_out_callbacks: try: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a5798e9483..1fb23cc9bf 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -497,9 +497,11 @@ class DirectoryHandler: raise SynapseError(403, "Not allowed to publish room") # Check if publishing is blocked by a third party module - allowed_by_third_party_rules = await ( - self.third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility + allowed_by_third_party_rules = ( + await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 83f53ceb88..50317ec753 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -188,7 +188,6 @@ class E2eRoomKeysHandler: # XXX: perhaps we should use a finer grained lock here? async with self._upload_linearizer.queue(user_id): - # Check that the version we're trying to upload is the current version try: version_info = await self.store.get_e2e_room_keys_version_info(user_id) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 46dd63c3f0..c508861b6a 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -236,7 +236,6 @@ class EventAuthHandler: # in any of them. allowed_rooms = await self.get_rooms_that_allow_join(state_ids) if not await self.is_user_in_rooms(allowed_rooms, user_id): - # If this is a remote request, the user might be in an allowed room # that we do not know about. if get_domain_from_id(user_id) != self._server_name: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 1a29abde98..aead0b44b9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -124,7 +124,6 @@ class InitialSyncHandler: as_client_event: bool = True, include_archived: bool = False, ) -> JsonDict: - memberships = [Membership.INVITE, Membership.JOIN] if include_archived: memberships.append(Membership.LEAVE) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 87af31aa27..4ad2233573 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler): ) if self.unpersisted_users_changes: - await self.store.update_presence( [ self.user_to_current_state[user_id] @@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler): now = self.clock.time_msec() with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 37c87c8351..a26ec02284 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -868,9 +868,11 @@ class RoomCreationHandler: ) # Check whether this visibility value is blocked by a third party module - allowed_by_third_party_rules = await ( - self.third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility + allowed_by_third_party_rules = ( + await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index c73d2adaad..5d4ca0e2d2 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -374,7 +374,7 @@ class RoomBatchHandler: # correct stream_ordering as they are backfilled (which decrements). # Events are sorted by (topological_ordering, stream_ordering) # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): + for event, context in reversed(events_to_persist): # This call can't raise `PartialStateConflictError` since we forbid # use of the historical batch API during partial state await self.event_creation_handler.handle_new_client_event( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4e4595312c..fd6d946c37 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1297,7 +1297,6 @@ class SyncHandler: return RoomNotifCounts.empty() with Measure(self.clock, "unread_notifs_for_room_id"): - return await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 5aed71262f..c70eee649c 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -524,6 +524,7 @@ def whitelisted_homeserver(destination: str) -> bool: # Start spans and scopes + # Could use kwargs but I want these to be explicit def start_active_span( operation_name: str, diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index b01372565d..8ce5887229 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -87,7 +87,6 @@ class LaterGauge(Collector): ] def collect(self) -> Iterable[Metric]: - g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) try: diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py index b7d47ce3e7..a22c4e5bbd 100644 --- a/synapse/metrics/_gc.py +++ b/synapse/metrics/_gc.py @@ -139,7 +139,6 @@ def install_gc_manager() -> None: class PyPyGCStats(Collector): def collect(self) -> Iterable[Metric]: - # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB'). diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 5fc38431ba..8f834be774 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -330,7 +330,6 @@ class BulkPushRuleEvaluator: context: EventContext, event_id_to_event: Mapping[str, EventBase], ) -> None: - if ( not event.internal_metadata.is_notifiable() or event.internal_metadata.is_historical() diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 2374f810c9..111ec07e64 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -265,7 +265,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] - return {} async def _handle_request( # type: ignore[override] diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index ecea6fc915..cc3929dcf5 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -195,7 +195,6 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): async def _serialize_payload( # type: ignore[override] user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: - return { "user_id": user_id, "device_id": device_id, diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index fd1c0ec6af..dfc061eb5e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -328,7 +328,6 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): outbound_redis_connection: txredisapi.ConnectionHandler, channel_names: List[str], ): - super().__init__( hs, uuid="subscriber", diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 14b6705862..ad9b760713 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -139,7 +139,6 @@ class EventsStream(Stream): current_token: Token, target_row_count: int, ) -> StreamUpdateResult: - # the events stream merges together three separate sources: # * new events # * current_state changes diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1d6e4982d7..4de56bf13f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -75,7 +75,6 @@ class RoomRestV2Servlet(RestServlet): async def on_DELETE( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) await assert_user_is_admin(self._auth, requester) @@ -144,7 +143,6 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) if not RoomID.is_valid(room_id): @@ -181,7 +179,6 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, delete_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) delete_status = self._pagination_handler.get_delete_status(delete_id) @@ -438,7 +435,6 @@ class RoomStateRestServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 0c0bf540b9..7cc4db20d6 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -683,8 +683,12 @@ class AccountValidityRenewServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if self.account_activity_handler.on_legacy_admin_request_callback: - expiration_ts = await ( - self.account_activity_handler.on_legacy_admin_request_callback(request) + expiration_ts = ( + await ( + self.account_activity_handler.on_legacy_admin_request_callback( + request + ) + ) ) else: body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index eb77337044..276a1b405d 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -97,7 +97,6 @@ class AuthRestServlet(RestServlet): return None async def on_POST(self, request: Request, stagetype: str) -> None: - session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index cc1c2f9731..236199897c 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -79,7 +79,6 @@ class CreateFilterRestServlet(RestServlet): async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: - target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 3cb1e7e375..bce806f2bb 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -628,10 +628,12 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = await ( - self.password_auth_provider.get_username_for_registration( - auth_result, - params, + desired_username = ( + await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, + ) ) ) @@ -682,9 +684,11 @@ class RegisterRestServlet(RestServlet): session_id ) - display_name = await ( - self.password_auth_provider.get_displayname_for_registration( - auth_result, params + display_name = ( + await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) ) ) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 6e035afcce..ef8334ae25 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -270,7 +270,6 @@ async def respond_with_responder( logger.debug("Responding to media request with responder %s", responder) add_file_headers(request, media_type, file_size, upload_name) try: - await responder.write_to_consumer(request) except Exception as e: # The majority of the time this will be due to the client having gone diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 9480cc5763..f909a4fb9a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -38,7 +38,6 @@ class ThumbnailError(Exception): class Thumbnailer: - FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} @staticmethod diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 8e61aba454..0d75d9739a 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - for (user_id, messages_by_device) in edu["messages"].items(): - for (device_id, msg) in messages_by_device.items(): + for user_id, messages_by_device in edu["messages"].items(): + for device_id, msg in messages_by_device.items(): with start_active_span("store_outgoing_to_device_message"): set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) @@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): def _remove_dead_devices_from_device_inbox_txn( txn: LoggingTransaction, ) -> Tuple[int, bool]: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1ca66d57d4..0dd15f16ff 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -512,7 +512,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): results.append(("org.matrix.signing_key_update", result)) if issue_8631_logger.isEnabledFor(logging.DEBUG): - for (user_id, edu) in results: + for user_id, edu in results: issue_8631_logger.debug( "device update to %s for %s from %s to %s: %s", destination, @@ -1316,7 +1316,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) """ count = 0 - for (destination, user_id, stream_id, device_id) in rows: + for destination, user_id, stream_id, device_id in rows: txn.execute( delete_sql, (destination, user_id, stream_id, stream_id, device_id) ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 6240f9a75e..9f8d2e4bea 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No backup with that version exists") values = [] - for (room_id, session_id, room_key) in room_keys: + for room_id, session_id, room_key in room_keys: values.append( ( user_id, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2c2d145666..b9c39b1718 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) # add each cross-signing signature to the correct device in the result dict. - for (user_id, key_id, device_id, signature) in cross_sigs_result: + for user_id, key_id, device_id, signature in cross_sigs_result: target_device_result = result[user_id][device_id] # We've only looked up cross-signatures for non-deleted devices with key # data. @@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker # devices. user_list = [] user_device_list = [] - for (user_id, device_id) in query_list: + for user_id, device_id in query_list: if device_id is None: user_list.append(user_id) else: @@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker txn.execute(sql, query_params) - for (user_id, device_id, display_name, key_json) in txn: + for user_id, device_id, display_name, key_json in txn: assert device_id is not None if include_deleted_devices: deleted_devices.remove((user_id, device_id)) @@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signature_query_clauses = [] signature_query_params = [] - for (user_id, device_id) in device_query: + for user_id, device_id in device_query: signature_query_clauses.append( "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ca780cca36..ff3edeb716 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas latest_events: List[str], limit: int, ) -> List[str]: - seen_events = set(earliest_events) front = set(latest_events) - seen_events event_results: List[str] = [] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7996cbb557..73b8aea16c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -469,7 +469,6 @@ class PersistEventsStore: txn: LoggingTransaction, events: List[EventBase], ) -> None: - # We only care about state events, so this if there are no state events. if not any(e.is_state() for e in events): return diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 584536111d..0a275e6ce6 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): nbrows = 0 last_row_event_id = "" - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) @@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): results = list(txn) # (event_id, parent_id, rel_type) for each relation relations_to_insert: List[Tuple[str, str, str]] = [] - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) except Exception as e: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6d0ef10258..b7e7498125 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1493,7 +1493,7 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(redactions_sql + clause, args) - for (redacter, redacted) in txn: + for redacter, redacted in txn: d = event_dict.get(redacted) if d: d.redactions.append(redacter) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index b202c5eb87..fa8be214ce 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[Dict[str, Any]], int]: - # Set ordering order_by_column = MediaSortOrder(order_by).value diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index df53e726e6..fddbc07afa 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -344,7 +344,6 @@ class PusherWorkerStore(SQLBaseStore): last_user = progress.get("last_user", "") def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT name FROM users WHERE deactivated = ? and name > ? @@ -392,7 +391,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, access_token FROM pushers AS p LEFT JOIN access_tokens AS a ON (p.access_token = a.id) @@ -449,7 +447,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, p.user_name, p.app_id, p.pushkey FROM pushers AS p diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dddf49c2d5..92a82240ab 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): def _populate_receipt_event_stream_ordering_txn( txn: LoggingTransaction, ) -> bool: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 644bbb8878..39f89291b2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2168,7 +2168,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): def _get_event_report_txn( txn: LoggingTransaction, report_id: int ) -> Optional[Dict[str, Any]]: - sql = """ SELECT er.id, diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 3fe433f66c..a7aae661d8 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore): class SearchBackgroundUpdateStore(SearchWorkerStore): - EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" @@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index ba325d390b..ebb2ae964f 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d7b7d0c3c9..d3393d8e49 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore): insert_cols = [] qargs = [] - for (key, val) in chain( + for key, val in chain( keyvalues.items(), absolutes.items(), additive_relatives.items() ): insert_cols.append(key) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 818c46182e..ac5fbf6b86 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" + # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 6b33d809b6..6d72bd9f67 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def get_destination_rooms_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: - if direction == Direction.BACKWARDS: order = "DESC" else: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 30af4b3b6c..c3f2b61bd5 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -98,7 +98,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): async def _populate_user_directory_createtables( self, progress: JsonDict, batch_size: int ) -> int: - # Get all the rooms that we want to process. def _make_staging_area(txn: LoggingTransaction) -> None: sql = ( diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index d743282f13..097dea5182 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 1a7232b276..89b1faa6c8 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -257,14 +257,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = self._get_state_for_groups_using_cache( + non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( + member_state, incomplete_groups_m = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6c335a9315..2a1c6fa31b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -563,7 +563,7 @@ def _apply_module_schemas( """ # This is the old way for password_auth_provider modules to make changes # to the database. This should instead be done using the module API - for (mod, _config) in config.authproviders.password_providers: + for mod, _config in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) @@ -591,7 +591,7 @@ def _apply_module_schema_files( (modname,), ) applied_deltas = {d for d, in cur} - for (name, stream) in names_and_streams: + for name, stream in names_and_streams: if name in applied_deltas: continue diff --git a/synapse/types/state.py b/synapse/types/state.py index 743a4f9217..4b3071acce 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -120,7 +120,7 @@ class StateFilter: def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: """The inverse to `from_types`.""" - for (event_type, state_keys) in self.types.items(): + for event_type, state_keys in self.types.items(): if state_keys is None: yield event_type, None else: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 9387632d0d..6ffa56217e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -98,7 +98,6 @@ class EvictionReason(Enum): @attr.s(slots=True, auto_attribs=True) class CacheMetric: - _cache: Sized _cache_type: str _cache_name: str diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 3b1e205700..1c0fde4966 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -183,7 +183,7 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled = [] errors = [] - for (requirement, must_be_installed) in dependencies: + for requirement, must_be_installed in dependencies: try: dist: metadata.Distribution = metadata.distribution(requirement.name) except metadata.PackageNotFoundError: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index f97f98a057..d00d34e652 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -211,7 +211,6 @@ def _check_yield_points( result = Failure() if current_context() != expected_context: - # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. the # deferred we waited on didn't follow the rules, or we forgot to diff --git a/synmark/__main__.py b/synmark/__main__.py index 35a59e347a..19de639187 100644 --- a/synmark/__main__.py +++ b/synmark/__main__.py @@ -34,12 +34,10 @@ def make_test(main): """ def _main(loops): - reactor = make_reactor() file_out = StringIO() with redirect_stderr(file_out): - d = Deferred() d.addCallback(lambda _: ensureDeferred(main(reactor, loops))) diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index 9419892e95..8beb077e0a 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -30,7 +30,6 @@ from synapse.util import Clock class LineCounter(LineOnlyReceiver): - delimiter = b"\n" def __init__(self, *args, **kwargs): diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 35dd9a20df..33af8770fd 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -24,7 +24,6 @@ from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, @@ -37,7 +36,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): return config def test_complexity_simple(self) -> None: - u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") @@ -71,7 +69,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertEqual(complexity, 1.23) def test_join_too_large(self) -> None: - u1 = self.register_user("u1", "pass") handler = self.hs.get_room_member_handler() @@ -131,7 +128,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_join_too_large_once_joined(self) -> None: - u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index bba6469b55..6c7738d810 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -34,7 +34,6 @@ from tests.unittest import override_config class FederationServerTests(unittest.FederatingHomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py index 137deab138..d6f43a98fc 100644 --- a/tests/handlers/test_sso.py +++ b/tests/handlers/test_sso.py @@ -113,7 +113,6 @@ async def mock_get_file( headers: Optional[RawHeaders] = None, is_allowed_content_type: Optional[Callable[[str], bool]] = None, ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: - fake_response = FakeResponse(code=404) if url == "http://my.server/me.png": fake_response = FakeResponse( diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index f1a50c5bcb..d11ded6c5b 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -31,7 +31,6 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 class StatsRoomTests(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 7748f56ee6..6ab13357f9 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -46,7 +46,6 @@ class SrvResolverTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup() -> Generator["Deferred[object]", object, List[Server]]: - with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) result: List[Server] diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 9cfe1ad0de..f6d6684985 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -149,7 +149,7 @@ class BlacklistingAgentTest(TestCase): self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" # Configure the reactor's DNS resolver. - for (domain, ip) in ( + for domain, ip in ( (self.safe_domain, self.safe_ip), (self.unsafe_domain, self.unsafe_ip), (self.allowed_domain, self.allowed_ip), diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 199e3d7b70..dce6899e78 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -33,7 +33,6 @@ from tests.unittest import HomeserverTestCase, override_config class TestBulkPushRuleEvaluator(HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7563f33fdc..0a3aca5c50 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -39,7 +39,6 @@ class _User: class EmailPusherTests(HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -48,7 +47,6 @@ class EmailPusherTests(HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["email"] = { "enable_notifs": True, diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index ddca9d696c..57c781a0c3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -64,7 +64,6 @@ def patch__eq__(cls: object) -> Callable[[], None]: class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): - STORE_TYPE = EventsWorkerStore def setUp(self) -> None: diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 03f2112b07..aaa488bced 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -28,7 +28,6 @@ from tests import unittest class DeviceRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): class DevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index db77a45ae3..f41319a5b6 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -594,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -724,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -821,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 453a6e979c..9dbb778679 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): class JoinAliasRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, room.register_servlets, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index f71ff46d87..28b999573e 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -28,7 +28,6 @@ from tests.unittest import override_config class ServerNoticeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e2ee1a1766..2b05dffc7d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -40,7 +40,6 @@ from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): class DeactivateTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase): class WhoamiTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index a144610078..0d8fe77b88 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -52,7 +52,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): class FallbackAuthTests(unittest.HomeserverTestCase): - servlets = [ auth.register_servlets, register.register_servlets, @@ -60,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = True diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index d1751e1557..c16e8d43f4 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -26,7 +26,6 @@ from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, capabilities.register_servlets, diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index b1ca81a911..bb845179d3 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["form_secret"] = "123abc" diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index 7a88aa2cda..6490e883bf 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -28,7 +28,6 @@ from tests.unittest import override_config class DirectoryTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, directory.register_servlets, diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 9fa1f82dfe..f31ebc8021 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -26,7 +26,6 @@ from tests import unittest class EphemeralMessageTestCase(unittest.HomeserverTestCase): - user_id = "@user:test" servlets = [ diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index a9b7db9db2..54df2a252c 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = False config["enable_registration"] = True @@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") @@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 830762fd53..91678abf13 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" hijack_auth = True EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index ff5baa9f0a..62acf4f44e 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [ class LoginRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): class CASTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, ] diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index 6aedc1a11c..b8187db982 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token" class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, admin.register_servlets, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 67e16880e6..dcbb125a3b 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -35,7 +35,6 @@ class PresenceTestCase(unittest.HomeserverTestCase): servlets = [presence.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.presence_handler = Mock(spec=PresenceHandler) self.presence_handler.set_state.return_value = make_awaitable(None) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 8de5a342ae..27c93ad761 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -30,7 +30,6 @@ from tests import unittest class ProfileTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase): class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 4c561f9525..b228dba861 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -40,7 +40,6 @@ from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, register.register_servlets, @@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): class AccountValidityTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index c0eb5d01a6..8dbd64be55 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" class RendezvousServletTestCase(unittest.HomeserverTestCase): - servlets = [ rendezvous.register_servlets, ] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index cfad182b2f..4dd763096d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase): servlets = [room.register_servlets, room.register_deprecated_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.hs = self.setup_test_homeserver( "red", federation_http_client=None, @@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase): rmcreator_id = "@notme:red" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase): class RoomJoinTestCase(RoomBase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): hijack_auth = False def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # Register the user who does the searching self.user_id2 = self.register_user("user", "pass") self.access_token = self.login("user", "pass") @@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = b"/_matrix/client/r0/publicRooms" config = self.default_config() @@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["allow_public_rooms_without_auth"] = True self.hs = self.setup_test_homeserver(config=config) @@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase): class ContextTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): class ThreepidInviteTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -3438,7 +3427,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): """ Test allowing/blocking threepid invites with a spam-check module. - In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index b9047194dd..9c876c7a32 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -41,7 +41,6 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): class SyncTypingTests(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): class ExcludeRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 5fa3440691..c0f93f898a 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. """ + # patch the rules module with a Mock which will return False for some event # types async def check( @@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_modify_event(self) -> None: """The module can return a modified version of the event""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] @@ -275,6 +277,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index 23f227aed6..b59d9dfd4d 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -31,7 +31,6 @@ from tests.utils import MockClock class MediaRetentionTestCase(unittest.HomeserverTestCase): - ONE_DAY_IN_MS = 24 * 60 * 60 * 1000 THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 17a3b06a8e..8ed27179c4 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -52,7 +52,6 @@ from tests.utils import default_config class MediaStorageTests(unittest.HomeserverTestCase): - needs_threadpool = True def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -207,7 +206,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): user_id = "@test:user" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches: List[ Tuple[ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", @@ -268,7 +266,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 2c321f8d04..6fcf60ce19 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["url_preview_enabled"] = True config["max_spider_size"] = 9999999 @@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b"preview_url"] @@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): addressTypes: Optional[Sequence[Type[IAddress]]] = None, transportSemantics: str = "TCP", ) -> IResolutionReceiver: - resolution = HostResolution(hostName) resolutionReceiver.resolutionBegan(resolution) if hostName not in self.lookups: diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index 6540ed53f1..3fdf5a6d52 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -25,7 +25,6 @@ from tests import unittest class ConsentNoticesTests(unittest.HomeserverTestCase): - servlets = [ sync.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -34,7 +33,6 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - tmpdir = self.mktemp() os.mkdir(tmpdir) self.consent_notice_message = "consent %(consent_uri)s" diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 373707b275..b6d5c474b0 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, devices.register_servlets, diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index ac77aec003..71db47405e 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -26,7 +26,6 @@ from tests.unittest import HomeserverTestCase class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, @@ -62,6 +61,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): keys and expected receipt key-values after duplicate receipts have been removed. """ + # First, undo the background update. def drop_receipts_unique_index(txn: LoggingTransaction) -> None: txn.execute(f"DROP INDEX IF EXISTS {index_name}") diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 3108ca3444..dbd8f3a85e 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -27,7 +27,6 @@ from tests.unittest import HomeserverTestCase class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 7f7f4ef892..cd0079871c 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -656,7 +656,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index a10e5fa8b1..73d11e7786 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -417,7 +417,6 @@ class EventChainStoreTestCase(HomeserverTestCase): def fetch_chains( self, events: List[EventBase] ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: - # Fetch the map from event ID -> (chain ID, sequence number) rows = self.get_success( self.store.db_pool.simple_select_many_batch( @@ -492,7 +491,6 @@ class LinkMapTestCase(unittest.TestCase): class EventChainBackgroundUpdateTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 8fc7936ab0..3e1984c15c 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -672,7 +672,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): complete_event_dict_map: Dict[str, JsonDict] = {} stream_ordering = 0 - for (event_id, prev_event_ids) in event_graph.items(): + for event_id, prev_event_ids in event_graph.items(): depth = depth_map[event_id] complete_event_dict_map[event_id] = { diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 76c06a9d1e..aa19c3bd30 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -774,7 +774,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual(r, 3) # add a bunch of dummy events to the events table - for (stream_ordering, ts) in ( + for stream_ordering, ts in ( (3, 110), (4, 120), (5, 120), diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index d8f42c5d05..857e2caf2e 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase class PurgeTests(HomeserverTestCase): - user_id = "@red:server" servlets = [room.register_servlets] diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 8794401823..f4c4661aaf 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -27,7 +27,6 @@ from tests.test_utils import event_injection class RoomMemberStoreTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, register_servlets_for_client_rest_resource, @@ -35,7 +34,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override] - # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastores().main @@ -48,7 +46,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.u_charlie = UserID.from_string("@charlie:elsewhere") def test_one_member(self) -> None: - # Alice creates the room, and is automatically joined self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f730b888f7..e82c03f597 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -242,7 +242,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -259,7 +259,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -272,7 +272,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -289,7 +289,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -309,7 +309,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -327,7 +327,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -341,7 +341,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -392,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -404,7 +404,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertDictEqual({}, state_dict) room_id = self.room.to_string() - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -417,7 +417,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -428,7 +428,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -447,7 +447,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -459,7 +459,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -473,7 +473,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -485,7 +485,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( diff --git a/tests/test_mau.py b/tests/test_mau.py index 4e7665a22b..ff21098a59 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -32,7 +32,6 @@ from tests.utils import default_config class TestMauLimit(unittest.HomeserverTestCase): - servlets = [register.register_servlets, sync.register_servlets] def default_config(self) -> JsonDict: -- cgit 1.5.1 From a068ad7dd4910c81bb0886fbf986dde126eeb4ee Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 23 Feb 2023 19:14:17 +0100 Subject: Add information on uploaded media to user export command. (#15107) --- changelog.d/15107.feature | 1 + docs/usage/administration/admin_faq.md | 74 ++++++++++++++++++++++++++-------- synapse/app/admin_cmd.py | 10 +++++ synapse/handlers/admin.py | 38 +++++++++++++++++ tests/handlers/test_admin.py | 29 +++++++++++++ 5 files changed, 136 insertions(+), 16 deletions(-) create mode 100644 changelog.d/15107.feature (limited to 'synapse') diff --git a/changelog.d/15107.feature b/changelog.d/15107.feature new file mode 100644 index 0000000000..2bdb6a29fc --- /dev/null +++ b/changelog.d/15107.feature @@ -0,0 +1 @@ +Add media information to the command line [user data export tool](https://matrix-org.github.io/synapse/v1.79/usage/administration/admin_faq.html#how-can-i-export-user-data). \ No newline at end of file diff --git a/docs/usage/administration/admin_faq.md b/docs/usage/administration/admin_faq.md index 925e1d175e..28c3dd53a5 100644 --- a/docs/usage/administration/admin_faq.md +++ b/docs/usage/administration/admin_faq.md @@ -70,13 +70,55 @@ output-directory │ ├───state │ ├───invite_state │ └───knock_state -└───user_data - ├───account_data - │ ├───global - │ └─── - ├───connections - ├───devices - └───profile +├───user_data +│ ├───account_data +│ │ ├───global +│ │ └─── +│ ├───connections +│ ├───devices +│ └───profile +└───media_ids + └─── +``` + +The `media_ids` folder contains only the metadata of the media uploaded by the user. +It does not contain the media itself. +Furthermore, only the `media_ids` that Synapse manages itself are exported. +If another media repository (e.g. [matrix-media-repo](https://github.com/turt2live/matrix-media-repo)) +is used, the data must be exported separately. + +With the `media_ids` the media files can be downloaded. +Media that have been sent in encrypted rooms are only retrieved in encrypted form. +The following script can help with download the media files: + +```bash +#!/usr/bin/env bash + +# Parameters +# +# source_directory: Directory which contains the export with the media_ids. +# target_directory: Directory into which all files are to be downloaded. +# repository_url: Address of the media repository resp. media worker. +# serverName: Name of the server (`server_name` from homeserver.yaml). +# +# Example: +# ./download_media.sh /tmp/export_data/media_ids/ /tmp/export_data/media_files/ http://localhost:8008 matrix.example.com + +source_directory=$1 +target_directory=$2 +repository_url=$3 +serverName=$4 + +mkdir -p $target_directory + +for file in $source_directory/*; do + filename=$(basename ${file}) + url=$repository_url/_matrix/media/v3/download/$serverName/$filename + echo "Downloading $filename - $url" + if ! wget -o /dev/null -P $target_directory $url; then + echo "Could not download $filename" + fi +done ``` Manually resetting passwords @@ -87,7 +129,7 @@ can reset a user's password using the [admin API](../../admin_api/user_admin_api I have a problem with my server. Can I just delete my database and start again? --- -Deleting your database is unlikely to make anything better. +Deleting your database is unlikely to make anything better. It's easy to make the mistake of thinking that you can start again from a clean slate by dropping your database, but things don't work like that in a federated @@ -102,7 +144,7 @@ Come and seek help in https://matrix.to/#/#synapse:matrix.org. There are two exceptions when it might be sensible to delete your database and start again: * You have *never* joined any rooms which are federated with other servers. For -instance, a local deployment which the outside world can't talk to. +instance, a local deployment which the outside world can't talk to. * You are changing the `server_name` in the homeserver configuration. In effect this makes your server a completely new one from the point of view of the network, so in this case it makes sense to start with a clean database. @@ -115,7 +157,7 @@ Using the following curl command: curl -H 'Authorization: Bearer ' -X DELETE https://matrix.org/_matrix/client/r0/directory/room/ ``` `` - can be obtained in riot by looking in the riot settings, down the bottom is: -Access Token:\ +Access Token:\ `` - the room alias, eg. #my_room:matrix.org this possibly needs to be URL encoded also, for example %23my_room%3Amatrix.org @@ -152,13 +194,13 @@ What are the biggest rooms on my server? --- ```sql -SELECT s.canonical_alias, g.room_id, count(*) AS num_rows -FROM - state_groups_state AS g, - room_stats_state AS s -WHERE g.room_id = s.room_id +SELECT s.canonical_alias, g.room_id, count(*) AS num_rows +FROM + state_groups_state AS g, + room_stats_state AS s +WHERE g.room_id = s.room_id GROUP BY s.canonical_alias, g.room_id -ORDER BY num_rows desc +ORDER BY num_rows desc LIMIT 10; ``` diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 5003777f0d..b05fe2c589 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -44,6 +44,7 @@ from synapse.storage.databases.main.event_push_actions import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.profile import ProfileWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore @@ -86,6 +87,7 @@ class AdminCmdSlavedStore( RegistrationWorkerStore, RoomWorkerStore, ProfileWorkerStore, + MediaRepositoryStore, ): def __init__( self, @@ -235,6 +237,14 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(account_data_file, "a") as f: json.dump(account_data, fp=f) + def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + file_directory = os.path.join(self.base_directory, "media_ids") + os.makedirs(file_directory, exist_ok=True) + media_id_file = os.path.join(file_directory, media_id) + + with open(media_id_file, "w") as f: + json.dump(media_metadata, fp=f) + def finished(self) -> str: return self.base_directory diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 8b7760b2cc..b06f25b03c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -252,16 +252,19 @@ class AdminHandler: profile = await self.get_user(UserID.from_string(user_id)) if profile is not None: writer.write_profile(profile) + logger.info("[%s] Written profile", user_id) # Get all devices the user has devices = await self._device_handler.get_devices_by_user(user_id) writer.write_devices(devices) + logger.info("[%s] Written %s devices", user_id, len(devices)) # Get all connections the user has connections = await self.get_whois(UserID.from_string(user_id)) writer.write_connections( connections["devices"][""]["sessions"][0]["connections"] ) + logger.info("[%s] Written %s connections", user_id, len(connections)) # Get all account data the user has global and in rooms global_data = await self._store.get_global_account_data_for_user(user_id) @@ -269,6 +272,29 @@ class AdminHandler: writer.write_account_data("global", global_data) for room_id in by_room_data: writer.write_account_data(room_id, by_room_data[room_id]) + logger.info( + "[%s] Written account data for %s rooms", user_id, len(by_room_data) + ) + + # Get all media ids the user has + limit = 100 + start = 0 + while True: + media_ids, total = await self._store.get_local_media_by_user_paginate( + start, limit, user_id + ) + for media in media_ids: + writer.write_media_id(media["media_id"], media) + + logger.info( + "[%s] Written %d media_ids of %s", + user_id, + (start + len(media_ids)), + total, + ) + if (start + limit) >= total: + break + start += limit return writer.finished() @@ -359,6 +385,18 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): """ raise NotImplementedError() + @abc.abstractmethod + def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + """Write the media's metadata of a user. + Exports only the metadata, as this can be fetched from the database via + read only. In order to access the files, a connection to the correct + media repository would be required. + + Args: + media_id: ID of the media. + media_metadata: Metadata of one media file. + """ + @abc.abstractmethod def finished(self) -> Any: """Called when all data has successfully been exported and written. diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 1b97aaeed1..5569ccef8a 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, JoinRules from synapse.api.room_versions import RoomVersions from synapse.rest.client import knock, login, room from synapse.server import HomeServer +from synapse.types import UserID from synapse.util import Clock from tests import unittest @@ -323,3 +324,31 @@ class ExfiltrateData(unittest.HomeserverTestCase): args = writer.write_account_data.call_args_list[1][0] self.assertEqual(args[0], "test_room") self.assertEqual(args[1]["m.per_room"]["b"], 2) + + def test_media_ids(self) -> None: + """Tests that media's metadata get exported.""" + + self.get_success( + self._store.store_local_media( + media_id="media_1", + media_type="image/png", + time_now_ms=self.clock.time_msec(), + upload_name=None, + media_length=50, + user_id=UserID.from_string(self.user2), + ) + ) + + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + writer.write_media_id.assert_called_once() + + args = writer.write_media_id.call_args[0] + self.assertEqual(args[0], "media_1") + self.assertEqual(args[1]["media_id"], "media_1") + self.assertEqual(args[1]["media_length"], 50) + self.assertGreater(args[1]["created_ts"], 0) + self.assertIsNone(args[1]["upload_name"]) + self.assertIsNone(args[1]["last_access_ts"]) -- cgit 1.5.1 From ec79870f1422be47e8d6e85f315799888278969b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Feb 2023 16:06:42 -0500 Subject: Fix a typo in MSC3873 config option. (#15138) Previously the experimental configuration option referred to the wrong MSC number. --- changelog.d/15138.misc | 1 + synapse/config/experimental.py | 4 ++-- synapse/push/bulk_push_rule_evaluator.py | 12 ++++++------ tests/push/test_push_rule_evaluator.py | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/15138.misc (limited to 'synapse') diff --git a/changelog.d/15138.misc b/changelog.d/15138.misc new file mode 100644 index 0000000000..fb706b27f2 --- /dev/null +++ b/changelog.d/15138.misc @@ -0,0 +1 @@ +Fix a typo in an experimental config setting. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 54c91953e1..bc38fae0b6 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -175,8 +175,8 @@ class ExperimentalConfig(Config): ) # MSC3873: Disambiguate event_match keys. - self.msc3783_escape_event_match_key = experimental.get( - "msc3783_escape_event_match_key", False + self.msc3873_escape_event_match_key = experimental.get( + "msc3873_escape_event_match_key", False ) # MSC3952: Intentional mentions, this depends on MSC3758. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8f834be774..3c4a152d6b 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -276,7 +276,7 @@ class BulkPushRuleEvaluator: if related_event is not None: related_events[relation_type] = _flatten_dict( related_event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ) reply_event_id = ( @@ -294,7 +294,7 @@ class BulkPushRuleEvaluator: if related_event is not None: related_events["m.in_reply_to"] = _flatten_dict( related_event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ) # indicate that this is from a fallback relation. @@ -412,7 +412,7 @@ class BulkPushRuleEvaluator: evaluator = PushRuleEvaluator( _flatten_dict( event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ), has_mentions, user_mentions, @@ -507,7 +507,7 @@ def _flatten_dict( prefix: Optional[List[str]] = None, result: Optional[Dict[str, JsonValue]] = None, *, - msc3783_escape_event_match_key: bool = False, + msc3873_escape_event_match_key: bool = False, ) -> Dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, @@ -536,7 +536,7 @@ def _flatten_dict( if result is None: result = {} for key, value in d.items(): - if msc3783_escape_event_match_key: + if msc3873_escape_event_match_key: # Escape periods in the key with a backslash (and backslashes with an # extra backslash). This is since a period is used as a separator between # nested fields. @@ -552,7 +552,7 @@ def _flatten_dict( value, prefix=(prefix + [key]), result=result, - msc3783_escape_event_match_key=msc3783_escape_event_match_key, + msc3873_escape_event_match_key=msc3873_escape_event_match_key, ) # `room_version` should only ever be set when looking at the top level of an event diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index d320a12f96..4e858fd16f 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -54,7 +54,7 @@ class FlattenDictTestCase(unittest.TestCase): self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input)) self.assertEqual( {"m\\.foo.b\\\\ar": "abc"}, - _flatten_dict(input, msc3783_escape_event_match_key=True), + _flatten_dict(input, msc3873_escape_event_match_key=True), ) def test_non_string(self) -> None: -- cgit 1.5.1 From f8a584ed0259cbb3c3a51726d1008d04c26b4bd8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Feb 2023 16:07:46 -0500 Subject: Stop parsing the unspecced type parameter on thumbnail requests. (#15137) Ideally we would replace this with parsing of the Accept header or something else, but for now just make Synapse spec compliant by ignoring the unspecced parameter. It does not seem that this is ever sent by a client, and even if it is there's a reasonable fallback. --- changelog.d/15137.removal | 1 + synapse/rest/media/v1/thumbnail_resource.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/15137.removal (limited to 'synapse') diff --git a/changelog.d/15137.removal b/changelog.d/15137.removal new file mode 100644 index 0000000000..c533b0c9dd --- /dev/null +++ b/changelog.d/15137.removal @@ -0,0 +1 @@ +Remove the undocumented and unspecced `type` parameter to the `/thumbnail` endpoint. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 5f725c7600..3e720018b3 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -69,7 +69,8 @@ class ThumbnailResource(DirectServeJsonResource): width = parse_integer(request, "width", required=True) height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") - m_type = parse_string(request, "type", "image/png") + # TODO Parse the Accept header to get an prioritised list of thumbnail types. + m_type = "image/png" if server_name == self.server_name: if self.dynamic_thumbnails: -- cgit 1.5.1 From 682151a464f688768d5bd8308e16bd4024ad2e57 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Feb 2023 16:08:53 -0500 Subject: Do not fail completely if oEmbed autodiscovery fails. (#15092) Previously if an autodiscovered oEmbed request failed (e.g. the oEmbed endpoint is down or does not exist) then the entire URL preview would fail. Instead we now return everything we can, even if this additional request fails. --- changelog.d/15092.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 33 ++++++++++++++------ tests/rest/media/v1/test_url_preview.py | 44 +++++++++++++++++++++++++-- 3 files changed, 65 insertions(+), 13 deletions(-) create mode 100644 changelog.d/15092.bugfix (limited to 'synapse') diff --git a/changelog.d/15092.bugfix b/changelog.d/15092.bugfix new file mode 100644 index 0000000000..67509c5c69 --- /dev/null +++ b/changelog.d/15092.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a URL preview would break if the discovered oEmbed failed to download. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a8f6fd6b35..4a594ab9d8 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -163,6 +163,10 @@ class PreviewUrlResource(DirectServeJsonResource): 7. Stores the result in the database cache. 4. Returns the result. + If any additional requests (e.g. from oEmbed autodiscovery, step 5.3 or + image thumbnailing, step 5.4 or 6.4) fails then the URL preview as a whole + does not fail. As much information as possible is returned. + The in-memory cache expires after 1 hour. Expired entries in the database cache (and their associated media files) are @@ -364,16 +368,25 @@ class PreviewUrlResource(DirectServeJsonResource): oembed_url = self._oembed.autodiscover_from_html(tree) og_from_oembed: JsonDict = {} if oembed_url: - oembed_info = await self._handle_url( - oembed_url, user, allow_data_urls=True - ) - ( - og_from_oembed, - author_name, - expiration_ms, - ) = await self._handle_oembed_response( - url, oembed_info, expiration_ms - ) + try: + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) + except Exception as e: + # Fetching the oEmbed info failed, don't block the entire URL preview. + logger.warning( + "oEmbed fetch failed during URL preview: %s errored with %s", + oembed_url, + e, + ) + else: + ( + og_from_oembed, + author_name, + expiration_ms, + ) = await self._handle_oembed_response( + url, oembed_info, expiration_ms + ) # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 6fcf60ce19..2acfccec61 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -657,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """If the preview image doesn't exist, ensure some data is returned.""" self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - end_content = ( + result = ( b"""""" ) @@ -678,8 +678,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) - % (len(end_content),) - + end_content + % (len(result),) + + result ) self.pump() @@ -688,6 +688,44 @@ class URLPreviewTests(unittest.HomeserverTestCase): # The image should not be in the result. self.assertNotIn("og:image", channel.json_body) + def test_oembed_failure(self) -> None: + """If the autodiscovered oEmbed URL fails, ensure some data is returned.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + result = b""" + oEmbed Autodiscovery Fail + + """ + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + + self.pump() + self.assertEqual(channel.code, 200) + + # The image should not be in the result. + self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail") + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. -- cgit 1.5.1 From 335f52d595c2c32e4b512b97e2851bc98b819ca7 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 24 Feb 2023 13:39:45 +0000 Subject: Improve handling of non-ASCII characters in user directory search (#15143) * Fix a long-standing bug where non-ASCII characters in search terms, including accented letters, would not match characters in a different case. * Fix a long-standing bug where search terms using combining accents would not match display names using precomposed accents and vice versa. To fully take effect, the user directory must be rebuilt after this change. Fixes #14630. Signed-off-by: Sean Quah --- changelog.d/15143.misc | 1 + synapse/storage/databases/main/user_directory.py | 52 ++++++++- tests/storage/test_user_directory.py | 133 +++++++++++++++++++++++ 3 files changed, 184 insertions(+), 2 deletions(-) create mode 100644 changelog.d/15143.misc (limited to 'synapse') diff --git a/changelog.d/15143.misc b/changelog.d/15143.misc new file mode 100644 index 0000000000..cff4518811 --- /dev/null +++ b/changelog.d/15143.misc @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory search was not case-insensitive for accented characters. diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index c3f2b61bd5..f16a509ac4 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -14,6 +14,7 @@ import logging import re +import unicodedata from typing import ( TYPE_CHECKING, Iterable, @@ -490,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): values={"display_name": display_name, "avatar_url": avatar_url}, ) + # The display name that goes into the database index. + index_display_name = display_name + if index_display_name is not None: + index_display_name = _filter_text_for_index(index_display_name) + if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name @@ -507,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, + index_display_name, ), ) elif isinstance(self.database_engine, Sqlite3Engine): - value = "%s %s" % (user_id, display_name) if display_name else user_id + value = ( + "%s %s" % (user_id, index_display_name) + if index_display_name + else user_id + ) self.db_pool.simple_upsert_txn( txn, table="user_directory_search", @@ -896,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return {"limited": limited, "results": results[0:limit]} +def _filter_text_for_index(text: str) -> str: + """Transforms text before it is inserted into the user directory index, or searched + for in the user directory index. + + Note that the user directory search table needs to be rebuilt whenever this function + changes. + """ + # Lowercase the text, to make searches case-insensitive. + # This is necessary for both PostgreSQL and SQLite. PostgreSQL's + # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using + # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at + # all. + text = text.lower() + + # Normalize the text. NFKC normalization has two effects: + # 1. It canonicalizes the text, ie. maps all visually identical strings to the same + # string. For example, ["e", "◌́"] is mapped to ["é"]. + # 2. It maps strings that are roughly equivalent to the same string. + # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to + # ["i", "9"]. + text = unicodedata.normalize("NFKC", text) + + # Note that nothing is done to make searches accent-insensitive. + # That could be achieved by converting to NFKD form instead (with combining accents + # split out) and filtering out combining accents using `unicodedata.combining(c)`. + # The downside of this may be noisier search results, since search terms with + # explicit accents will match characters with no accents, or completely different + # accents. + # + # text = unicodedata.normalize("NFKD", text) + # text = "".join([c for c in text if not unicodedata.combining(c)]) + + return text + + def _parse_query_sqlite(search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. @@ -905,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str: We specifically add both a prefix and non prefix matching term so that exact matches get ranked higher. """ + search_term = _filter_text_for_index(search_term) # Pull out the individual words, discarding any non-word characters. results = _parse_words(search_term) @@ -917,6 +963,8 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: We use this so that we can add prefix matching, which isn't something that is supported by default. """ + search_term = _filter_text_for_index(search_term) + escaped_words = [] for word in _parse_words(search_term): # Postgres tsvector and tsquery quoting rules: diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 2d169684cf..43b724c4dd 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -504,6 +504,139 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, ) + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_ascii_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case. + """ + CHARLIE = "@someuser:example.org" + self.get_success( + self.store.update_profile_in_user_dir(CHARLIE, "Charlie", None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, "cHARLIE", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": CHARLIE, "display_name": "Charlie", "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_unicode_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case. + """ + IVAN = "@someuser:example.org" + self.get_success(self.store.update_profile_in_user_dir(IVAN, "Иван", None)) + + r = self.get_success(self.store.search_user_dir(ALICE, "иВАН", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": IVAN, "display_name": "Иван", "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_dotted_dotless_i_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case, when their name contains dotted or dotless "i"s. + + Some languages have dotted and dotless versions of "i", which are considered to + be different letters: i <-> İ, ı <-> I. To make things difficult, they reuse the + ASCII "i" and "I" code points, despite having different lowercase / uppercase + forms. + """ + USER = "@someuser:example.org" + + expected_matches = [ + # (search_term, display_name) + # A search for "i" should match "İ". + ("iiiii", "İİİİİ"), + # A search for "I" should match "ı". + ("IIIII", "ııııı"), + # A search for "ı" should match "I". + ("ııııı", "IIIII"), + # A search for "İ" should match "i". + ("İİİİİ", "iiiii"), + ] + + for search_term, display_name in expected_matches: + self.get_success( + self.store.update_profile_in_user_dir(USER, display_name, None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10)) + self.assertFalse(r["limited"]) + self.assertEqual( + 1, + len(r["results"]), + f"searching for {search_term!r} did not match {display_name!r}", + ) + self.assertDictEqual( + r["results"][0], + {"user_id": USER, "display_name": display_name, "avatar_url": None}, + ) + + # We don't test for negative matches, to allow implementations that consider all + # the i variants to be the same. + + test_search_user_dir_dotted_dotless_i_case_insensitivity.skip = "not supported" # type: ignore + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_unicode_normalization(self) -> None: + """Tests that a user can look up another user by searching for their name with + either composed or decomposed accents. + """ + AMELIE = "@someuser:example.org" + + expected_matches = [ + # (search_term, display_name) + ("Ame\u0301lie", "Amélie"), + ("Amélie", "Ame\u0301lie"), + ] + + for search_term, display_name in expected_matches: + self.get_success( + self.store.update_profile_in_user_dir(AMELIE, display_name, None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10)) + self.assertFalse(r["limited"]) + self.assertEqual( + 1, + len(r["results"]), + f"searching for {search_term!r} did not match {display_name!r}", + ) + self.assertDictEqual( + r["results"][0], + {"user_id": AMELIE, "display_name": display_name, "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_accent_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name + without any accents. + """ + AMELIE = "@someuser:example.org" + self.get_success(self.store.update_profile_in_user_dir(AMELIE, "Amélie", None)) + + r = self.get_success(self.store.search_user_dir(ALICE, "amelie", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": AMELIE, "display_name": "Amélie", "avatar_url": None}, + ) + + # It may be desirable for "é"s in search terms to not match plain "e"s and we + # really don't want "é"s in search terms to match "e"s with different accents. + # But we don't test for this to allow implementations that consider all + # "e"-lookalikes to be the same. + + test_search_user_dir_accent_insensitivity.skip = "not supported yet" # type: ignore + class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase): use_icu = True -- cgit 1.5.1 From b2357a898cdd1f4a2222609abfe471801ea88dcd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 24 Feb 2023 14:39:50 +0000 Subject: Fix bug where 5s delays would occasionally happen. (#15150) This only affects deployments using workers. --- changelog.d/15150.bugfix | 1 + synapse/replication/tcp/resource.py | 18 +++++++++++ tests/replication/tcp/test_handler.py | 61 +++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 changelog.d/15150.bugfix (limited to 'synapse') diff --git a/changelog.d/15150.bugfix b/changelog.d/15150.bugfix new file mode 100644 index 0000000000..8668bc587f --- /dev/null +++ b/changelog.d/15150.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.76 where 5s delays would occasionally occur in deployments using workers. diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 9d17eff714..347467d863 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -238,6 +238,24 @@ class ReplicationStreamer: except Exception: logger.exception("Failed to replicate") + # The last token we send may not match the current + # token, in which case we want to send out a `POSITION` + # to tell other workers the actual current position. + if updates[-1][0] < current_token: + logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( + stream.NAME, + self._instance_name, + updates[-1][0], + current_token, + ) + ) + logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index bf927beb6a..bab77b2df7 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -141,3 +141,64 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): self.get_success(ctx_worker1.__aexit__(None, None, None)) self.assertTrue(d.called) + + def test_wait_for_stream_position_rdata(self) -> None: + """Check that wait for stream position correctly waits for an update + from the correct instance, when RDATA is sent. + """ + store = self.hs.get_datastores().main + cmd_handler = self.hs.get_replication_command_handler() + data_handler = self.hs.get_replication_data_handler() + + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + + cache_id_gen = worker1.get_datastores().main._cache_id_gen + assert cache_id_gen is not None + + self.replicate() + + # First, make sure the master knows that `worker1` exists. + initial_token = cache_id_gen.get_current_token() + cmd_handler.send_command( + PositionCommand("caches", "worker1", initial_token, initial_token) + ) + self.replicate() + + # `wait_for_stream_position` should only return once master receives a + # notification that `next_token2` has persisted. + ctx_worker1 = cache_id_gen.get_next_mult(2) + next_token1, next_token2 = self.get_success(ctx_worker1.__aenter__()) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token2) + ) + self.assertFalse(d.called) + + # Insert an entry into the cache stream with token `next_token1`, but + # not `next_token2`. + self.get_success( + store.db_pool.simple_insert( + table="cache_invalidation_stream_by_instance", + values={ + "stream_id": next_token1, + "instance_name": "worker1", + "cache_func": "foo", + "keys": [], + "invalidation_ts": 0, + }, + ) + ) + + # Finish the context manager, triggering the data to be sent to master. + self.get_success(ctx_worker1.__aexit__(None, None, None)) + + # Master should get told about `next_token2`, so the deferred should + # resolve. + self.assertTrue(d.called) -- cgit 1.5.1 From 1c95ddd09bbc46046a3412e7bb03a87aa3b6f65a Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 24 Feb 2023 13:15:29 -0800 Subject: Batch up storing state groups when creating new room (#14918) --- changelog.d/14918.misc | 1 + synapse/events/snapshot.py | 49 +++++++++++ synapse/handlers/message.py | 16 ++-- synapse/handlers/room.py | 37 ++++---- synapse/handlers/room_batch.py | 4 +- synapse/handlers/room_member.py | 13 ++- synapse/storage/databases/state/store.py | 119 ++++++++++++++++++++++++++ tests/handlers/test_message.py | 25 ++++-- tests/handlers/test_register.py | 3 +- tests/push/test_bulk_push_rule_evaluator.py | 13 +-- tests/rest/client/test_rooms.py | 4 +- tests/storage/test_event_chain.py | 6 +- tests/storage/test_state.py | 126 ++++++++++++++++++++++++++++ tests/unittest.py | 4 +- 14 files changed, 371 insertions(+), 49 deletions(-) create mode 100644 changelog.d/14918.misc (limited to 'synapse') diff --git a/changelog.d/14918.misc b/changelog.d/14918.misc new file mode 100644 index 0000000000..828794354a --- /dev/null +++ b/changelog.d/14918.misc @@ -0,0 +1 @@ +Batch up storing state groups when creating a new room. \ No newline at end of file diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e0d82ad81c..a91a5d1e3c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers + from synapse.storage.databases import StateGroupDataStore from synapse.storage.databases.main import DataStore from synapse.types.state import StateFilter @@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase): partial_state: bool state_map_before_event: Optional[StateMap[str]] = None + @classmethod + async def batch_persist_unpersisted_contexts( + cls, + events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]], + room_id: str, + last_known_state_group: int, + datastore: "StateGroupDataStore", + ) -> List[Tuple[EventBase, EventContext]]: + """ + Takes a list of events and their associated unpersisted contexts and persists + the unpersisted contexts, returning a list of events and persisted contexts. + Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: A list of events and their unpersisted contexts + room_id: the room_id for the events + last_known_state_group: the last persisted state group + datastore: a state datastore + """ + amended_events_and_context = await datastore.store_state_deltas_for_batched( + events_and_context, room_id, last_known_state_group + ) + + events_and_persisted_context = [] + for event, unpersisted_context in amended_events_and_context: + if event.is_state(): + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.state_group_before_event, + delta_ids=unpersisted_context.state_delta_due_to_event, + ) + else: + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.prev_group_for_state_group_before_event, + delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, + ) + events_and_persisted_context.append((event, context)) + return events_and_persisted_context + async def get_prev_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> StateMap[str]: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index aa90d0000d..e433d6b01f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -574,7 +574,7 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """ Given a dict from a client, create a new event. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -721,8 +721,6 @@ class EventCreationHandler: current_state_group=current_state_group, ) - context = await unpersisted_context.persist(event) - # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new @@ -739,7 +737,7 @@ class EventCreationHandler: assert state_map is not None prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: - prev_state_ids = await context.get_prev_state_ids( + prev_state_ids = await unpersisted_context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) @@ -764,8 +762,7 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) - - return event, context + return event, unpersisted_context async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1005,7 +1002,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1016,6 +1013,7 @@ class EventCreationHandler: historical=historical, depth=depth, ) + context = await unpersisted_context.persist(event) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( event.sender, @@ -1190,7 +1188,6 @@ class EventCreationHandler: if for_batch: assert prev_event_ids is not None assert state_map is not None - assert current_state_group is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth @@ -2046,7 +2043,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, { "type": EventTypes.Dummy, @@ -2055,6 +2052,7 @@ class EventCreationHandler: "sender": user_id, }, ) + context = await unpersisted_context.persist(event) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a26ec02284..b1784638f4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,6 +51,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM @@ -211,7 +212,7 @@ class RoomCreationHandler: # the required power level to send the tombstone event. ( tombstone_event, - tombstone_context, + tombstone_unpersisted_context, ) = await self.event_creation_handler.create_event( requester, { @@ -225,6 +226,9 @@ class RoomCreationHandler: }, }, ) + tombstone_context = await tombstone_unpersisted_context.persist( + tombstone_event + ) validate_event_for_room_version(tombstone_event) await self._event_auth_handler.check_auth_rules_from_context( tombstone_event @@ -1092,7 +1096,7 @@ class RoomCreationHandler: content: JsonDict, for_batch: bool, **kwargs: Any, - ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]: """ Creates an event and associated event context. Args: @@ -1111,20 +1115,23 @@ class RoomCreationHandler: event_dict = create_event_dict(etype, content, **kwargs) - new_event, new_context = await self.event_creation_handler.create_event( + ( + new_event, + new_unpersisted_context, + ) = await self.event_creation_handler.create_event( creator, event_dict, prev_event_ids=prev_event, depth=depth, state_map=state_map, for_batch=for_batch, - current_state_group=current_state_group, ) + depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - return new_event, new_context + return new_event, new_unpersisted_context try: config = self._presets_dict[preset_config] @@ -1134,10 +1141,10 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - creation_event, creation_context = await create_event( + creation_event, unpersisted_creation_context = await create_event( EventTypes.Create, creation_content, False ) - + creation_context = await unpersisted_creation_context.persist(creation_event) logger.debug("Sending %s in new room", EventTypes.Member) ev = await self.event_creation_handler.handle_new_client_event( requester=creator, @@ -1181,7 +1188,6 @@ class RoomCreationHandler: power_event, power_context = await create_event( EventTypes.PowerLevels, pl_content, True ) - current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { @@ -1230,14 +1236,12 @@ class RoomCreationHandler: power_level_content, True, ) - current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) - current_state_group = room_alias_context._state_group events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: @@ -1246,7 +1250,6 @@ class RoomCreationHandler: {"join_rule": config["join_rules"]}, True, ) - current_state_group = join_rules_context._state_group events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: @@ -1255,7 +1258,6 @@ class RoomCreationHandler: {"history_visibility": config["history_visibility"]}, True, ) - current_state_group = visibility_context._state_group events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: @@ -1265,14 +1267,12 @@ class RoomCreationHandler: {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, True, ) - current_state_group = guest_access_context._state_group events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): event, context = await create_event( etype, content, True, state_key=state_key ) - current_state_group = context._state_group events_to_send.append((event, context)) if config["encrypted"]: @@ -1284,9 +1284,16 @@ class RoomCreationHandler: ) events_to_send.append((encryption_event, encryption_context)) + datastore = self.hs.get_datastores().state + events_and_context = ( + await UnpersistedEventContext.batch_persist_unpersisted_contexts( + events_to_send, room_id, current_state_group, datastore + ) + ) + last_event = await self.event_creation_handler.handle_new_client_event( creator, - events_to_send, + events_and_context, ignore_shadow_ban=True, ratelimit=False, ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 5d4ca0e2d2..bf9df60218 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -327,7 +327,7 @@ class RoomBatchHandler: # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - event, context = await self.event_creation_handler.create_event( + event, unpersisted_context = await self.event_creation_handler.create_event( await self.create_requester_for_user_id_from_app_service( ev["sender"], app_service_requester.app_service ), @@ -345,7 +345,7 @@ class RoomBatchHandler: historical=True, depth=inherited_depth, ) - + context = await unpersisted_context.persist(event) assert context._state_group # Normally this is done when persisting the event but we have to diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a965c7ec76..de7476f300 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier=outlier, historical=historical, ) - + context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, event_dict, txn_id=txn_id, @@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): auth_event_ids=auth_event_ids, outlier=True, ) + context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True result_event = ( diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 89b1faa6c8..bf4cdfdf29 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se import attr from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) + async def store_state_deltas_for_batched( + self, + events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]], + room_id: str, + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state deltas for a group of events and contexts created to be + batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: the events to generate and store a state groups for + and their associated contexts + room_id: the id of the room the events were created for + prev_group: the state group of the last event persisted before the batched events + were created + """ + + def insert_deltas_group_txn( + txn: LoggingTransaction, + events_and_context: List[Tuple[EventBase, UnpersistedEventContext]], + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state groups for the provided events and contexts. + + Requires that we have the state as a delta from the last persisted state group. + + Returns: + A list of state groups + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + num_state_groups = sum( + 1 for event, _ in events_and_context if event.is_state() + ) + + state_groups = self._state_group_seq_gen.get_next_mult_txn( + txn, num_state_groups + ) + + sg_before = prev_group + state_group_iter = iter(state_groups) + for event, context in events_and_context: + if not event.is_state(): + context.state_group_after_event = sg_before + context.state_group_before_event = sg_before + continue + + sg_after = next(state_group_iter) + context.state_group_after_event = sg_after + context.state_group_before_event = sg_before + context.state_delta_due_to_event = { + (event.type, event.state_key): event.event_id + } + sg_before = sg_after + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups", + keys=("id", "room_id", "event_id"), + values=[ + (context.state_group_after_event, room_id, event.event_id) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_group_edges", + keys=("state_group", "prev_state_group"), + values=[ + ( + context.state_group_after_event, + context.state_group_before_event, + ) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + ( + context.state_group_after_event, + room_id, + key[0], + key[1], + state_id, + ) + for event, context in events_and_context + if context.state_delta_due_to_event is not None + for key, state_id in context.state_delta_due_to_event.items() + ], + ) + return events_and_context + + return await self.db_pool.runInteraction( + "store_state_deltas_for_batched.insert_deltas_group", + insert_deltas_group_txn, + events_and_context, + prev_group, + ) + async def store_state_group( self, event_id: str, diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 69d384442f..9691d66b48 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): return memberEvent, memberEventContext - def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: + def _create_duplicate_event( + self, txn_id: str + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. """ @@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_suitably_random" - event1, context = self._create_duplicate_event(txn_id) + event1, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event1)) ret_event1 = self.get_success( self.handler.handle_new_client_event( @@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(event1.event_id, ret_event1.event_id) - event2, context = self._create_duplicate_event(txn_id) + event2, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event2)) # We want to test that the deduplication at the persit event end works, # so we want to make sure we test with different events. @@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_event` directly also does the right # thing. - event3, context = self._create_duplicate_event(txn_id) + event3, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event3)) + self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( @@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_events` directly also does the right # thing. - event4, context = self._create_duplicate_event(txn_id) + event4, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event4)) self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( @@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_else_suitably_random" # Create two duplicate events to persist at the same time - event1, context1 = self._create_duplicate_event(txn_id) - event2, context2 = self._create_duplicate_event(txn_id) + event1, unpersisted_context1 = self._create_duplicate_event(txn_id) + context1 = self.get_success(unpersisted_context1.persist(event1)) + event2, unpersisted_context2 = self._create_duplicate_event(txn_id) + context2 = self.get_success(unpersisted_context2.persist(event2)) # Ensure their event IDs are different to start with self.assertNotEqual(event1.event_id, event2.event_id) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1db99b3c00..aff1ec4758 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Lower the permissions of the inviter. event_creation_handler = self.hs.get_event_creation_handler() requester = create_requester(inviter) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_event( requester, { @@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_creation_handler.handle_new_client_event( requester, events_and_context=[(event, context)] diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index dce6899e78..1458076a90 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Create a new message event, and try to evaluate it under the dodgy # power level event. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): prev_event_ids=[pl_event_id], ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise @@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): """Ensure that push rules are not calculated when disabled in the config""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Mock the method which calculates push rules -- we do this instead of @@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) - + context = self.get_success(unpersisted_context.persist(event)) # Execute the push rule machinery. self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) @@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Create & persist an event to use as the parent of the relation. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self.event_creation_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4dd763096d..a4900703c4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(30, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(36, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 73d11e7786..e39b63edac 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_prev_events_for_room(room_id) ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] @@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): assert state_ids1 is not None state1 = set(state_ids1.values()) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index e82c03f597..62aed6af0a 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + def test_batched_state_group_storing(self) -> None: + creation_event = self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, "", {} + ) + state_to_event = self.get_success( + self.storage.state.get_state_groups( + self.room.to_string(), [creation_event.event_id] + ) + ) + current_state_group = list(state_to_event.keys())[0] + + # create some unpersisted events and event contexts to store against room + events_and_context = [] + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Name, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"name": "first rename of room"}, + }, + ) + + event1, unpersisted_context1 = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + events_and_context.append((event1, unpersisted_context1)) + + builder2 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "private"}, + }, + ) + + event2, unpersisted_context2 = self.get_success( + self.event_creation_handler.create_new_client_event(builder2) + ) + events_and_context.append((event2, unpersisted_context2)) + + builder3 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Message, + "sender": self.u_alice.to_string(), + "room_id": self.room.to_string(), + "content": {"body": "hello from event 3", "msgtype": "m.text"}, + }, + ) + + event3, unpersisted_context3 = self.get_success( + self.event_creation_handler.create_new_client_event(builder3) + ) + events_and_context.append((event3, unpersisted_context3)) + + builder4 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "public"}, + }, + ) + + event4, unpersisted_context4 = self.get_success( + self.event_creation_handler.create_new_client_event(builder4) + ) + events_and_context.append((event4, unpersisted_context4)) + + processed_events_and_context = self.get_success( + self.hs.get_datastores().state.store_state_deltas_for_batched( + events_and_context, self.room.to_string(), current_state_group + ) + ) + + # check that only state events are in state_groups, and all state events are in state_groups + res = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ) + + events = [] + for result in res: + self.assertNotIn(event3.event_id, result) + events.append(result.get("event_id")) + + for event, _ in processed_events_and_context: + if event.is_state(): + self.assertIn(event.event_id, events) + + # check that each unique state has state group in state_groups_state and that the + # type/state key is correct, and check that each state event's state group + # has an entry and prev event in state_group_edges + for event, context in processed_events_and_context: + if event.is_state(): + state = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ) + self.assertEqual(event.type, state[0].get("type")) + self.assertEqual(event.state_key, state[0].get("state_key")) + + groups = self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={"state_group": str(context.state_group_after_event)}, + retcols=("*",), + ) + ) + self.assertEqual( + context.state_group_before_event, groups[0].get("prev_state_group") + ) diff --git a/tests/unittest.py b/tests/unittest.py index b21e7f1221..f9160faa1d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase): event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creator.create_event( requester, { @@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase): prev_event_ids=prev_event_ids, ) ) - + context = self.get_success(unpersisted_context.persist(event)) if soft_failed: event.internal_metadata.soft_failed = True -- cgit 1.5.1 From 3f2ef205e228282a8a744db59115caa4b17da9a1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 27 Feb 2023 13:03:22 +0000 Subject: Small fixes to `MatrixFederationHttpClient` docstrings (#15148) --- changelog.d/15148.doc | 1 + synapse/http/matrixfederationclient.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 changelog.d/15148.doc (limited to 'synapse') diff --git a/changelog.d/15148.doc b/changelog.d/15148.doc new file mode 100644 index 0000000000..4e9e163306 --- /dev/null +++ b/changelog.d/15148.doc @@ -0,0 +1 @@ +Correct small documentation errors in some `MatrixFederationHttpClient` methods. \ No newline at end of file diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 312aab4dcc..3302d4e48a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -440,7 +440,7 @@ class MatrixFederationHttpClient: Args: request: details of request to be sent - retry_on_dns_fail: true if the request should be retied on DNS failures + retry_on_dns_fail: true if the request should be retried on DNS failures timeout: number of milliseconds to wait for the response headers (including connecting to the server), *for each attempt*. @@ -475,7 +475,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -871,7 +871,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -958,7 +958,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1036,6 +1036,8 @@ class MatrixFederationHttpClient: args: A dictionary used to create query strings, defaults to None. + retry_on_dns_fail: true if the request should be retried on DNS failures + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. @@ -1063,7 +1065,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1141,7 +1143,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1197,7 +1199,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. -- cgit 1.5.1 From 4fc8875876374ec8f97a3b3cc344a4e3abcf769f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Feb 2023 08:26:05 -0500 Subject: Refactor media modules. (#15146) * Removes the `v1` directory from `test.rest.media.v1`. * Moves the non-REST code from `synapse.rest.media.v1` to `synapse.media`. * Flatten the `v1` directory from `synapse.rest.media`, but leave compatiblity with 3rd party media repositories and spam checkers. --- changelog.d/15146.misc | 1 + synapse/_scripts/move_remote_media_to_new_store.py | 2 +- synapse/config/repository.py | 12 +- synapse/events/spamcheck.py | 4 +- synapse/media/_base.py | 479 ++++++++ synapse/media/filepath.py | 410 +++++++ synapse/media/media_repository.py | 1038 ++++++++++++++++ synapse/media/media_storage.py | 374 ++++++ synapse/media/oembed.py | 265 +++++ synapse/media/preview_html.py | 501 ++++++++ synapse/media/storage_provider.py | 181 +++ synapse/media/thumbnailer.py | 221 ++++ synapse/rest/media/config_resource.py | 41 + synapse/rest/media/download_resource.py | 75 ++ synapse/rest/media/media_repository_resource.py | 93 ++ synapse/rest/media/preview_url_resource.py | 869 ++++++++++++++ synapse/rest/media/thumbnail_resource.py | 554 +++++++++ synapse/rest/media/upload_resource.py | 108 ++ synapse/rest/media/v1/_base.py | 470 +------- synapse/rest/media/v1/config_resource.py | 41 - synapse/rest/media/v1/download_resource.py | 76 -- synapse/rest/media/v1/filepath.py | 410 ------- synapse/rest/media/v1/media_repository.py | 1112 ------------------ synapse/rest/media/v1/media_storage.py | 365 +----- synapse/rest/media/v1/oembed.py | 265 ----- synapse/rest/media/v1/preview_html.py | 501 -------- synapse/rest/media/v1/preview_url_resource.py | 871 -------------- synapse/rest/media/v1/storage_provider.py | 172 +-- synapse/rest/media/v1/thumbnail_resource.py | 555 --------- synapse/rest/media/v1/thumbnailer.py | 221 ---- synapse/rest/media/v1/upload_resource.py | 108 -- synapse/server.py | 6 +- tests/media/__init__.py | 13 + tests/media/test_base.py | 38 + tests/media/test_filepath.py | 595 ++++++++++ tests/media/test_html_preview.py | 542 +++++++++ tests/media/test_media_storage.py | 792 +++++++++++++ tests/media/test_oembed.py | 162 +++ tests/rest/admin/test_media.py | 2 +- tests/rest/admin/test_user.py | 2 +- tests/rest/media/test_url_preview.py | 1234 ++++++++++++++++++++ tests/rest/media/v1/__init__.py | 13 - tests/rest/media/v1/test_base.py | 38 - tests/rest/media/v1/test_filepath.py | 595 ---------- tests/rest/media/v1/test_html_preview.py | 542 --------- tests/rest/media/v1/test_media_storage.py | 792 ------------- tests/rest/media/v1/test_oembed.py | 162 --- tests/rest/media/v1/test_url_preview.py | 1234 -------------------- 48 files changed, 8612 insertions(+), 8545 deletions(-) create mode 100644 changelog.d/15146.misc create mode 100644 synapse/media/_base.py create mode 100644 synapse/media/filepath.py create mode 100644 synapse/media/media_repository.py create mode 100644 synapse/media/media_storage.py create mode 100644 synapse/media/oembed.py create mode 100644 synapse/media/preview_html.py create mode 100644 synapse/media/storage_provider.py create mode 100644 synapse/media/thumbnailer.py create mode 100644 synapse/rest/media/config_resource.py create mode 100644 synapse/rest/media/download_resource.py create mode 100644 synapse/rest/media/media_repository_resource.py create mode 100644 synapse/rest/media/preview_url_resource.py create mode 100644 synapse/rest/media/thumbnail_resource.py create mode 100644 synapse/rest/media/upload_resource.py delete mode 100644 synapse/rest/media/v1/config_resource.py delete mode 100644 synapse/rest/media/v1/download_resource.py delete mode 100644 synapse/rest/media/v1/filepath.py delete mode 100644 synapse/rest/media/v1/media_repository.py delete mode 100644 synapse/rest/media/v1/oembed.py delete mode 100644 synapse/rest/media/v1/preview_html.py delete mode 100644 synapse/rest/media/v1/preview_url_resource.py delete mode 100644 synapse/rest/media/v1/thumbnail_resource.py delete mode 100644 synapse/rest/media/v1/thumbnailer.py delete mode 100644 synapse/rest/media/v1/upload_resource.py create mode 100644 tests/media/__init__.py create mode 100644 tests/media/test_base.py create mode 100644 tests/media/test_filepath.py create mode 100644 tests/media/test_html_preview.py create mode 100644 tests/media/test_media_storage.py create mode 100644 tests/media/test_oembed.py create mode 100644 tests/rest/media/test_url_preview.py delete mode 100644 tests/rest/media/v1/__init__.py delete mode 100644 tests/rest/media/v1/test_base.py delete mode 100644 tests/rest/media/v1/test_filepath.py delete mode 100644 tests/rest/media/v1/test_html_preview.py delete mode 100644 tests/rest/media/v1/test_media_storage.py delete mode 100644 tests/rest/media/v1/test_oembed.py delete mode 100644 tests/rest/media/v1/test_url_preview.py (limited to 'synapse') diff --git a/changelog.d/15146.misc b/changelog.d/15146.misc new file mode 100644 index 0000000000..8de5f95239 --- /dev/null +++ b/changelog.d/15146.misc @@ -0,0 +1 @@ +Refactor the media modules. diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py index 819afaaca6..0dd36bee20 100755 --- a/synapse/_scripts/move_remote_media_to_new_store.py +++ b/synapse/_scripts/move_remote_media_to_new_store.py @@ -37,7 +37,7 @@ import os import shutil import sys -from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.media.filepath import MediaFilePaths logger = logging.getLogger() diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2da40c09f0..ecb3edbe3a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -178,11 +178,13 @@ class ContentRepositoryConfig(Config): for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to # expose FileStorageProviderBackend - if provider_config["module"] == "file_system": - provider_config["module"] = ( - "synapse.rest.media.v1.storage_provider" - ".FileStorageProviderBackend" - ) + if ( + provider_config["module"] == "file_system" + or provider_config["module"] == "synapse.rest.media.v1.storage_provider" + ): + provider_config[ + "module" + ] = "synapse.media.storage_provider.FileStorageProviderBackend" provider_class, parsed_config = load_module( provider_config, ("media_storage_providers", "" % i) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 623a2c71ea..765c15bb51 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -33,8 +33,8 @@ from typing_extensions import Literal import synapse from synapse.api.errors import Codes from synapse.logging.opentracing import trace -from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.media_storage import ReadableFileWrapper +from synapse.media._base import FileInfo +from synapse.media.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import JsonDict, RoomAlias, UserProfile from synapse.util.async_helpers import delay_cancellation, maybe_awaitable diff --git a/synapse/media/_base.py b/synapse/media/_base.py new file mode 100644 index 0000000000..ef8334ae25 --- /dev/null +++ b/synapse/media/_base.py @@ -0,0 +1,479 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import urllib +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type + +import attr + +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender +from twisted.web.server import Request + +from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.http.server import finish_request, respond_with_json +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable +from synapse.util.stringutils import is_ascii, parse_and_validate_server_name + +logger = logging.getLogger(__name__) + +# list all text content types that will have the charset default to UTF-8 when +# none is given +TEXT_CONTENT_TYPES = [ + "text/css", + "text/csv", + "text/html", + "text/calendar", + "text/plain", + "text/javascript", + "application/json", + "application/ld+json", + "application/rtf", + "image/svg+xml", + "text/xml", +] + + +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: + """Parses the server name, media ID and optional file name from the request URI + + Also performs some rough validation on the server name. + + Args: + request: The `Request`. + + Returns: + A tuple containing the parsed server name, media ID and optional file name. + + Raises: + SynapseError(404): if parsing or validation fail for any reason + """ + try: + # The type on postpath seems incorrect in Twisted 21.2.0. + postpath: List[bytes] = request.postpath # type: ignore + assert postpath + + # This allows users to append e.g. /test.png to the URL. Useful for + # clients that parse the URL to see content type. + server_name_bytes, media_id_bytes = postpath[:2] + server_name = server_name_bytes.decode("utf-8") + media_id = media_id_bytes.decode("utf8") + + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + + file_name = None + if len(postpath) > 2: + try: + file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) + except UnicodeDecodeError: + pass + return server_name, media_id, file_name + except Exception: + raise SynapseError( + 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN + ) + + +def respond_404(request: SynapseRequest) -> None: + respond_with_json( + request, + 404, + cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), + send_cors=True, + ) + + +async def respond_with_file( + request: SynapseRequest, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: + logger.debug("Responding with %r", file_path) + + if os.path.isfile(file_path): + if file_size is None: + stat = os.stat(file_path) + file_size = stat.st_size + + add_file_headers(request, media_type, file_size, upload_name) + + with open(file_path, "rb") as f: + await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + + finish_request(request) + else: + respond_404(request) + + +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: + """Adds the correct response headers in preparation for responding with the + media. + + Args: + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. + """ + + def _quote(x: str) -> str: + return urllib.parse.quote(x.encode("utf-8")) + + # Default to a UTF-8 charset for text content types. + # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' + if media_type.lower() in TEXT_CONTENT_TYPES: + content_type = media_type + "; charset=UTF-8" + else: + content_type = media_type + + request.setHeader(b"Content-Type", content_type.encode("UTF-8")) + if upload_name: + # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. + # + # `filename` is defined to be a `value`, which is defined by RFC2616 + # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` + # is (essentially) a single US-ASCII word, and a `quoted-string` is a + # US-ASCII string surrounded by double-quotes, using backslash as an + # escape character. Note that %-encoding is *not* permitted. + # + # `filename*` is defined to be an `ext-value`, which is defined in + # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, + # where `value-chars` is essentially a %-encoded string in the given charset. + # + # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 + # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 + # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 + + # We avoid the quoted-string version of `filename`, because (a) synapse didn't + # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we + # may as well just do the filename* version. + if _can_encode_filename_as_token(upload_name): + disposition = "inline; filename=%s" % (upload_name,) + else: + disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) + + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) + + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") + + +# separators as defined in RFC2616. SP and HT are handled separately. +# see _can_encode_filename_as_token. +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} + + +def _can_encode_filename_as_token(x: str) -> bool: + for c in x: + # from RFC2616: + # + # token = 1* + # + # separators = "(" | ")" | "<" | ">" | "@" + # | "," | ";" | ":" | "\" | <"> + # | "/" | "[" | "]" | "?" | "=" + # | "{" | "}" | SP | HT + # + # CHAR = + # + # CTL = + # + if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: + return False + return True + + +async def respond_with_responder( + request: SynapseRequest, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: + """Responds to the request with given responder. If responder is None then + returns 404. + + Args: + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. + """ + if not responder: + respond_404(request) + return + + # If we have a responder we *must* use it as a context manager. + with responder: + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + + logger.debug("Responding to media request with responder %s", responder) + add_file_headers(request, media_type, file_size, upload_name) + try: + await responder.write_to_consumer(request) + except Exception as e: + # The majority of the time this will be due to the client having gone + # away. Unfortunately, Twisted simply throws a generic exception at us + # in that case. + logger.warning("Failed to write to consumer: %s %s", type(e), e) + + # Unregister the producer, if it has one, so Twisted doesn't complain + if request.producer: + request.unregisterProducer() + + finish_request(request) + + +class Responder(ABC): + """Represents a response that can be streamed to the requester. + + Responder is a context manager which *must* be used, so that any resources + held can be cleaned up. + """ + + @abstractmethod + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: + """Stream response into consumer + + Args: + consumer: The consumer to stream into. + + Returns: + Resolves once the response has finished being written + """ + raise NotImplementedError() + + def __enter__(self) -> None: # noqa: B027 + pass + + def __exit__( # noqa: B027 + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + pass + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThumbnailInfo: + """Details about a generated thumbnail.""" + + width: int + height: int + method: str + # Content type of thumbnail, e.g. image/png + type: str + # The size of the media file, in bytes. + length: Optional[int] = None + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FileInfo: + """Details about a requested/uploaded file.""" + + # The server name where the media originated from, or None if local. + server_name: Optional[str] + # The local ID of the file. For local files this is the same as the media_id + file_id: str + # If the file is for the url preview cache + url_cache: bool = False + # Whether the file is a thumbnail or not. + thumbnail: Optional[ThumbnailInfo] = None + + # The below properties exist to maintain compatibility with third-party modules. + @property + def thumbnail_width(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.width + + @property + def thumbnail_height(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.height + + @property + def thumbnail_method(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.method + + @property + def thumbnail_type(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.type + + @property + def thumbnail_length(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.length + + +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: + """ + Get the filename of the downloaded file by inspecting the + Content-Disposition HTTP header. + + Args: + headers: The HTTP request headers. + + Returns: + The filename, or None. + """ + content_disposition = headers.get(b"Content-Disposition", [b""]) + + # No header, bail out. + if not content_disposition[0]: + return None + + _, params = _parse_header(content_disposition[0]) + + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get(b"filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith(b"utf-8''"): + upload_name_utf8 = upload_name_utf8[7:] + # We have a filename*= section. This MUST be ASCII, and any UTF-8 + # bytes are %-quoted. + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get(b"filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii.decode("ascii") + + # This may be None here, indicating we did not find a matching name. + return upload_name + + +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: + """Parse a Content-type like header. + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + line: header to be parsed + + Returns: + The main content-type, followed by the parameter dictionary + """ + parts = _parseparam(b";" + line) + key = next(parts) + pdict = {} + for p in parts: + i = p.find(b"=") + if i >= 0: + name = p[:i].strip().lower() + value = p[i + 1 :].strip() + + # strip double-quotes + if len(value) >= 2 and value[0:1] == value[-1:] == b'"': + value = value[1:-1] + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') + pdict[name] = value + + return key, pdict + + +def _parseparam(s: bytes) -> Generator[bytes, None, None]: + """Generator which splits the input on ;, respecting double-quoted sequences + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + s: header to be parsed + + Returns: + The split input + """ + while s[:1] == b";": + s = s[1:] + + # look for the next ; + end = s.find(b";") + + # if there is an odd number of " marks between here and the next ;, skip to the + # next ; instead + while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: + end = s.find(b";", end + 1) + + if end < 0: + end = len(s) + f = s[:end] + yield f.strip() + s = s[end:] diff --git a/synapse/media/filepath.py b/synapse/media/filepath.py new file mode 100644 index 0000000000..1f6441c412 --- /dev/null +++ b/synapse/media/filepath.py @@ -0,0 +1,410 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +import re +import string +from typing import Any, Callable, List, TypeVar, Union, cast + +NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") + + +F = TypeVar("F", bound=Callable[..., str]) + + +def _wrap_in_base_path(func: F) -> F: + """Takes a function that returns a relative path and turns it into an + absolute path based on the location of the primary media store + """ + + @functools.wraps(func) + def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str: + path = func(self, *args, **kwargs) + return os.path.join(self.base_path, path) + + return cast(F, _wrapped) + + +GetPathMethod = TypeVar( + "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]] +) + + +def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: + """Wraps a path-returning method to check that the returned path(s) do not escape + the media store directory. + + The path-returning method may return either a single path, or a list of paths. + + The check is not expected to ever fail, unless `func` is missing a call to + `_validate_path_component`, or `_validate_path_component` is buggy. + + Args: + relative: A boolean indicating whether the wrapped method returns paths relative + to the media store directory. + + Returns: + A method which will wrap a path-returning method, adding a check to ensure that + the returned path(s) lie within the media store directory. The check will raise + a `ValueError` if it fails. + """ + + def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # Construct the path that will ultimately be used. + # We cannot guess whether `path` is relative to the media store + # directory, since the media store directory may itself be a relative + # path. + if relative: + path = os.path.join(self.base_path, path) + normalized_path = os.path.normpath(path) + + # Now that `normpath` has eliminated `../`s and `./`s from the path, + # `os.path.commonpath` can be used to check whether it lies within the + # media store directory. + if ( + os.path.commonpath([normalized_path, self.normalized_base_path]) + != self.normalized_base_path + ): + # The path resolves to outside the media store directory, + # or `self.base_path` is `.`, which is an unlikely configuration. + raise ValueError(f"Invalid media store path: {path!r}") + + # Note that `os.path.normpath`/`abspath` has a subtle caveat: + # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a + # different path if `a/b/c` is a symlink. That is, the check above is + # not perfect and may allow a certain restricted subset of untrustworthy + # paths through. Since the check above is secondary to the main + # `_validate_path_component` checks, it's less important for it to be + # perfect. + # + # As an alternative, `os.path.realpath` will resolve symlinks, but + # proves problematic if there are symlinks inside the media store. + # eg. if `url_store/` is symlinked to elsewhere, its canonical path + # won't match that of the main media store directory. + + return path_or_paths + + return cast(GetPathMethod, _wrapped) + + return _wrap_with_jail_check_inner + + +ALLOWED_CHARACTERS = set( + string.ascii_letters + + string.digits + + "_-" + + ".[]:" # Domain names, IPv6 addresses and ports in server names +) +FORBIDDEN_NAMES = { + "", + os.path.curdir, # "." for the current platform + os.path.pardir, # ".." for the current platform +} + + +def _validate_path_component(name: str) -> str: + """Checks that the given string can be safely used as a path component + + Args: + name: The path component to check. + + Returns: + The path component if valid. + + Raises: + ValueError: If `name` cannot be safely used as a path component. + """ + if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES: + raise ValueError(f"Invalid path component: {name!r}") + + return name + + +class MediaFilePaths: + """Describes where files are stored on disk. + + Most of the functions have a `*_rel` variant which returns a file path that + is relative to the base media store path. This is mainly used when we want + to write to the backup media store (when one is configured) + """ + + def __init__(self, primary_base_path: str): + self.base_path = primary_base_path + self.normalized_base_path = os.path.normpath(self.base_path) + + # Refuse to initialize if paths cannot be validated correctly for the current + # platform. + assert os.path.sep not in ALLOWED_CHARACTERS + assert os.path.altsep not in ALLOWED_CHARACTERS + # On Windows, paths have all sorts of weirdness which `_validate_path_component` + # does not consider. In any case, the remote media store can't work correctly + # for certain homeservers there, since ":"s aren't allowed in paths. + assert os.name == "posix" + + @_wrap_with_jail_check(relative=True) + def local_media_filepath_rel(self, media_id: str) -> str: + return os.path.join( + "local_content", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + + @_wrap_with_jail_check(relative=True) + def local_media_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + return os.path.join( + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), + ) + + local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + + @_wrap_with_jail_check(relative=False) + def local_media_thumbnail_dir(self, media_id: str) -> str: + """ + Retrieve the local store path of thumbnails of a given media_id + + Args: + media_id: The media ID to query. + Returns: + Path of local_thumbnails from media_id + """ + return os.path.join( + self.base_path, + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + @_wrap_with_jail_check(relative=True) + def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: + return os.path.join( + "remote_content", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + ) + + remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + + @_wrap_with_jail_check(relative=True) + def remote_media_thumbnail_rel( + self, + server_name: str, + file_id: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + return os.path.join( + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), + ) + + remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + + # Legacy path that was used to store thumbnails previously. + # Should be removed after some time, when most of the thumbnails are stored + # using the new path. + @_wrap_with_jail_check(relative=True) + def remote_media_thumbnail_rel_legacy( + self, server_name: str, file_id: str, width: int, height: int, content_type: str + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + return os.path.join( + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), + ) + + @_wrap_with_jail_check(relative=False) + def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: + return os.path.join( + self.base_path, + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + ) + + @_wrap_with_jail_check(relative=True) + def url_cache_filepath_rel(self, media_id: str) -> str: + if NEW_FORMAT_ID_RE.match(media_id): + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + return os.path.join( + "url_cache", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) + else: + return os.path.join( + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + + @_wrap_with_jail_check(relative=False) + def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: + "The dirs to try and remove if we delete the media_id file" + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[:10]) + ) + ] + else: + return [ + os.path.join( + self.base_path, + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[0:2]) + ), + ] + + @_wrap_with_jail_check(relative=True) + def url_cache_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + _validate_path_component(file_name), + ) + else: + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), + ) + + url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + + @_wrap_with_jail_check(relative=True) + def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) + else: + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + url_cache_thumbnail_directory = _wrap_in_base_path( + url_cache_thumbnail_directory_rel + ) + + @_wrap_with_jail_check(relative=False) + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: + "The dirs to try and remove if we delete the media_id thumbnails" + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + ), + ] + else: + return [ + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + ), + ] diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py new file mode 100644 index 0000000000..b81e3c2b0c --- /dev/null +++ b/synapse/media/media_repository.py @@ -0,0 +1,1038 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import errno +import logging +import os +import shutil +from io import BytesIO +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +from matrix_common.types.mxc_uri import MXCUri + +import twisted.internet.error +import twisted.web.http +from twisted.internet.defer import Deferred + +from synapse.api.errors import ( + FederationDeniedError, + HttpResponseException, + NotFoundError, + RequestSendFailed, + SynapseError, +) +from synapse.config.repository import ThumbnailRequirement +from synapse.http.site import SynapseRequest +from synapse.logging.context import defer_to_thread +from synapse.media._base import ( + FileInfo, + Responder, + ThumbnailInfo, + get_filename_from_headers, + respond_404, + respond_with_responder, +) +from synapse.media.filepath import MediaFilePaths +from synapse.media.media_storage import MediaStorage +from synapse.media.storage_provider import StorageProviderWrapper +from synapse.media.thumbnailer import Thumbnailer, ThumbnailError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID +from synapse.util.async_helpers import Linearizer +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# How often to run the background job to update the "recently accessed" +# attribute of local and remote media. +UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute +# How often to run the background job to check for local and remote media +# that should be purged according to the configured media retention settings. +MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour + + +class MediaRepository: + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.client = hs.get_federation_http_client() + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.store = hs.get_datastores().main + self.max_upload_size = hs.config.media.max_upload_size + self.max_image_pixels = hs.config.media.max_image_pixels + + Thumbnailer.set_limits(self.max_image_pixels) + + self.primary_base_path: str = hs.config.media.media_store_path + self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) + + self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails + self.thumbnail_requirements = hs.config.media.thumbnail_requirements + + self.remote_media_linearizer = Linearizer(name="media_remote") + + self.recently_accessed_remotes: Set[Tuple[str, str]] = set() + self.recently_accessed_locals: Set[str] = set() + + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) + + # List of StorageProviders where we should search for media and + # potentially upload to. + storage_providers = [] + + for ( + clz, + provider_config, + wrapper_config, + ) in hs.config.media.media_storage_providers: + backend = clz(hs, provider_config) + provider = StorageProviderWrapper( + backend, + store_local=wrapper_config.store_local, + store_remote=wrapper_config.store_remote, + store_synchronous=wrapper_config.store_synchronous, + ) + storage_providers.append(provider) + + self.media_storage = MediaStorage( + self.hs, self.primary_base_path, self.filepaths, storage_providers + ) + + self.clock.looping_call( + self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS + ) + + # Media retention configuration options + self._media_retention_local_media_lifetime_ms = ( + hs.config.media.media_retention_local_media_lifetime_ms + ) + self._media_retention_remote_media_lifetime_ms = ( + hs.config.media.media_retention_remote_media_lifetime_ms + ) + + # Check whether local or remote media retention is configured + if ( + hs.config.media.media_retention_local_media_lifetime_ms is not None + or hs.config.media.media_retention_remote_media_lifetime_ms is not None + ): + # Run the background job to apply media retention rules routinely, + # with the duration between runs dictated by the homeserver config. + self.clock.looping_call( + self._start_apply_media_retention_rules, + MEDIA_RETENTION_CHECK_PERIOD_MS, + ) + + def _start_update_recently_accessed(self) -> Deferred: + return run_as_background_process( + "update_recently_accessed_media", self._update_recently_accessed + ) + + def _start_apply_media_retention_rules(self) -> Deferred: + return run_as_background_process( + "apply_media_retention_rules", self._apply_media_retention_rules + ) + + async def _update_recently_accessed(self) -> None: + remote_media = self.recently_accessed_remotes + self.recently_accessed_remotes = set() + + local_media = self.recently_accessed_locals + self.recently_accessed_locals = set() + + await self.store.update_cached_last_access_time( + local_media, remote_media, self.clock.time_msec() + ) + + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: + """Mark the given media as recently accessed. + + Args: + server_name: Origin server of media, or None if local + media_id: The media ID of the content + """ + if server_name: + self.recently_accessed_remotes.add((server_name, media_id)) + else: + self.recently_accessed_locals.add(media_id) + + async def create_content( + self, + media_type: str, + upload_name: Optional[str], + content: IO, + content_length: int, + auth_user: UserID, + ) -> MXCUri: + """Store uploaded content for a local user and return the mxc URL + + Args: + media_type: The content type of the file. + upload_name: The name of the file, if provided. + content: A file like object that is the content to store + content_length: The length of the content + auth_user: The user_id of the uploader + + Returns: + The mxc url of the stored content + """ + + media_id = random_string(24) + + file_info = FileInfo(server_name=None, file_id=media_id) + + fname = await self.media_storage.store_file(content, file_info) + + logger.info("Stored local media in file %r", fname) + + await self.store.store_local_media( + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + ) + + await self._generate_thumbnails(None, media_id, media_id, media_type) + + return MXCUri(self.server_name, media_id) + + async def get_local_media( + self, request: SynapseRequest, media_id: str, name: Optional[str] + ) -> None: + """Responds to requests for local media, if exists, or returns 404. + + Args: + request: The incoming request. + media_id: The media ID of the content. (This is the same as + the file_id for local content.) + name: Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Resolves once a response has successfully been written to request + """ + media_info = await self.store.get_local_media(media_id) + if not media_info or media_info["quarantined_by"]: + respond_404(request) + return + + self.mark_recently_accessed(None, media_id) + + media_type = media_info["media_type"] + if not media_type: + media_type = "application/octet-stream" + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + url_cache = media_info["url_cache"] + + file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) + + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder( + request, responder, media_type, media_length, upload_name + ) + + async def get_remote_media( + self, + request: SynapseRequest, + server_name: str, + media_id: str, + name: Optional[str], + ) -> None: + """Respond to requests for remote media. + + Args: + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Resolves once a response has successfully been written to request + """ + if ( + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + self.mark_recently_accessed(server_name, media_id) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + async with self.remote_media_linearizer.queue(key): + responder, media_info = await self._get_remote_media_impl( + server_name, media_id + ) + + # We deliberately stream the file outside the lock + if responder: + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + await respond_with_responder( + request, responder, media_type, media_length, upload_name + ) + else: + respond_404(request) + + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: + """Gets the media info associated with the remote file, downloading + if necessary. + + Args: + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + + Returns: + The media info of the file + """ + if ( + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + async with self.remote_media_linearizer.queue(key): + responder, media_info = await self._get_remote_media_impl( + server_name, media_id + ) + + # Ensure we actually use the responder so that it releases resources + if responder: + with responder: + pass + + return media_info + + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: + """Looks for media in local cache, if not there then attempt to + download from remote server. + + Args: + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the + remote server). + + Returns: + A tuple of responder and the media info of the file. + """ + media_info = await self.store.get_cached_remote_media(server_name, media_id) + + # file_id is the ID we use to track the file locally. If we've already + # seen the file then reuse the existing ID, otherwise generate a new + # one. + + # If we have an entry in the DB, try and look for it + if media_info: + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + + if media_info["quarantined_by"]: + logger.info("Media is quarantined") + raise NotFoundError() + + if not media_info["media_type"]: + media_info["media_type"] = "application/octet-stream" + + responder = await self.media_storage.fetch_media(file_info) + if responder: + return responder, media_info + + # Failed to find the file anywhere, lets download it. + + try: + media_info = await self._download_remote_file( + server_name, + media_id, + ) + except SynapseError: + raise + except Exception as e: + # An exception may be because we downloaded media in another + # process, so let's check if we magically have the media. + media_info = await self.store.get_cached_remote_media(server_name, media_id) + if not media_info: + raise e + + file_id = media_info["filesystem_id"] + if not media_info["media_type"]: + media_info["media_type"] = "application/octet-stream" + file_info = FileInfo(server_name, file_id) + + # We generate thumbnails even if another process downloaded the media + # as a) it's conceivable that the other download request dies before it + # generates thumbnails, but mainly b) we want to be sure the thumbnails + # have finished being generated before responding to the client, + # otherwise they'll request thumbnails and get a 404 if they're not + # ready yet. + await self._generate_thumbnails( + server_name, media_id, file_id, media_info["media_type"] + ) + + responder = await self.media_storage.fetch_media(file_info) + return responder, media_info + + async def _download_remote_file( + self, + server_name: str, + media_id: str, + ) -> dict: + """Attempt to download the remote file from the given server name, + using the given file_id as the local id. + + Args: + server_name: Originating server + media_id: The media ID of the content (as defined by the + remote server). This is different than the file_id, which is + locally generated. + file_id: Local file ID + + Returns: + The media info of the file. + """ + + file_id = random_string(24) + + file_info = FileInfo(server_name=server_name, file_id=file_id) + + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + request_path = "/".join( + ("/_matrix/media/r0/download", server_name, media_id) + ) + try: + length, headers = await self.client.get_file( + server_name, + request_path, + output_stream=f, + max_size=self.max_upload_size, + args={ + # tell the remote server to 404 if it doesn't + # recognise the server_name, to make sure we don't + # end up with a routing loop. + "allow_remote": "false" + }, + ) + except RequestSendFailed as e: + logger.warning( + "Request failed fetching remote media %s/%s: %r", + server_name, + media_id, + e, + ) + raise SynapseError(502, "Failed to fetch remote media") + + except HttpResponseException as e: + logger.warning( + "HTTP error fetching remote media %s/%s: %s", + server_name, + media_id, + e.response, + ) + if e.code == twisted.web.http.NOT_FOUND: + raise e.to_synapse_error() + raise SynapseError(502, "Failed to fetch remote media") + + except SynapseError: + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) + raise + except NotRetryingDestination: + logger.warning("Not retrying destination %r", server_name) + raise SynapseError(502, "Failed to fetch remote media") + except Exception: + logger.exception( + "Failed to fetch remote media %s/%s", server_name, media_id + ) + raise SynapseError(502, "Failed to fetch remote media") + + await finish() + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + upload_name = get_filename_from_headers(headers) + time_now_ms = self.clock.time_msec() + + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk (as we're in the ctx manager). + # + # However: we've already called `finish()` so we may have also + # written to the storage providers. This is preferable to the + # alternative where we call `finish()` *after* this, where we could + # end up having an entry in the DB but fail to write the files to + # the storage providers. + await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) + + logger.info("Stored remote media in file %r", fname) + + media_info = { + "media_type": media_type, + "media_length": length, + "upload_name": upload_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + } + + return media_info + + def _get_thumbnail_requirements( + self, media_type: str + ) -> Tuple[ThumbnailRequirement, ...]: + scpos = media_type.find(";") + if scpos > 0: + media_type = media_type[:scpos] + return self.thumbnail_requirements.get(media_type, ()) + + def _generate_thumbnail( + self, + thumbnailer: Thumbnailer, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[BytesIO]: + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, + ) + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = thumbnailer.transpose() + + if t_method == "crop": + return thumbnailer.crop(t_width, t_height, t_type) + elif t_method == "scale": + t_width, t_height = thumbnailer.aspect(t_width, t_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + return thumbnailer.scale(t_width, t_height, t_type) + + return None + + async def generate_local_exact_thumbnail( + self, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + url_cache: bool, + ) -> Optional[str]: + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(None, media_id, url_cache=url_cache) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", + media_id, + t_method, + t_type, + e, + ) + return None + + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=None, + file_id=media_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + output_path = await self.media_storage.store_file( + t_byte_source, file_info + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) + + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) + + return output_path + + # Could not generate thumbnail. + return None + + async def generate_remote_exact_thumbnail( + self, + server_name: str, + file_id: str, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[str]: + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + t_method, + t_type, + e, + ) + return None + + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + output_path = await self.media_storage.store_file( + t_byte_source, file_info + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) + + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + + return output_path + + # Could not generate thumbnail. + return None + + async def _generate_thumbnails( + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: + """Generate and store thumbnails for an image. + + Args: + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as + the file_id for local content) + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, + used exclusively by the url previewer + + Returns: + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. + """ + requirements = self._get_thumbnail_requirements(media_type) + if not requirements: + return None + + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=url_cache) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate thumbnails for remote media %s from %s of type %s: %s", + media_id, + server_name, + media_type, + e, + ) + return None + + with thumbnailer: + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, + ) + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = await defer_to_thread( + self.hs.get_reactor(), thumbnailer.transpose + ) + + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. + thumbnails: Dict[Tuple[int, int, str], str] = {} + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.items(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.crop, + t_width, + t_height, + t_type, + ) + elif t_method == "scale": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.scale, + t_width, + t_height, + t_type, + ) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + with self.media_storage.store_into_file(file_info) as ( + f, + fname, + finish, + ): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = ( + await self.store.get_remote_media_thumbnail( + server_name, + media_id, + t_width, + t_height, + t_type, + ) + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) + + return {"width": m_width, "height": m_height} + + async def _apply_media_retention_rules(self) -> None: + """ + Purge old local and remote media according to the media retention rules + defined in the homeserver config. + """ + # Purge remote media + if self._media_retention_remote_media_lifetime_ms is not None: + # Calculate a threshold timestamp derived from the configured lifetime. Any + # media that has not been accessed since this timestamp will be removed. + remote_media_threshold_timestamp_ms = ( + self.clock.time_msec() - self._media_retention_remote_media_lifetime_ms + ) + + logger.info( + "Purging remote media last accessed before" + f" {remote_media_threshold_timestamp_ms}" + ) + + await self.delete_old_remote_media( + before_ts=remote_media_threshold_timestamp_ms + ) + + # And now do the same for local media + if self._media_retention_local_media_lifetime_ms is not None: + # This works the same as the remote media threshold + local_media_threshold_timestamp_ms = ( + self.clock.time_msec() - self._media_retention_local_media_lifetime_ms + ) + + logger.info( + "Purging local media last accessed before" + f" {local_media_threshold_timestamp_ms}" + ) + + await self.delete_old_local_media( + before_ts=local_media_threshold_timestamp_ms, + keep_profiles=True, + delete_quarantined_media=False, + delete_protected_media=False, + ) + + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: + old_media = await self.store.get_remote_media_ids( + before_ts, include_quarantined_media=False + ) + + deleted = 0 + + for media in old_media: + origin = media["media_origin"] + media_id = media["media_id"] + file_id = media["filesystem_id"] + key = (origin, media_id) + + logger.info("Deleting: %r", key) + + # TODO: Should we delete from the backup store + + async with self.remote_media_linearizer.queue(key): + full_path = self.filepaths.remote_media_filepath(origin, file_id) + try: + os.remove(full_path) + except OSError as e: + logger.warning("Failed to remove file: %r", full_path) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( + origin, file_id + ) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + await self.store.delete_remote_media(origin, media_id) + deleted += 1 + + return {"deleted": deleted} + + async def delete_local_media_ids( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + """ + Delete the given local or remote media ID from this server + + Args: + media_id: The media ID to delete. + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + return await self._remove_local_media_from_disk(media_ids) + + async def delete_old_local_media( + self, + before_ts: int, + size_gt: int = 0, + keep_profiles: bool = True, + delete_quarantined_media: bool = False, + delete_protected_media: bool = False, + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server by size and timestamp. Removes + media files, any thumbnails and cached URLs. + + Args: + before_ts: Unix timestamp in ms. + Files that were last used before this timestamp will be deleted. + size_gt: Size of the media in bytes. Files that are larger will be deleted. + keep_profiles: Switch to delete also files that are still used in image data + (e.g user profile, room avatar). If false these files will be deleted. + delete_quarantined_media: If True, media marked as quarantined will be deleted. + delete_protected_media: If True, media marked as protected will be deleted. + + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + old_media = await self.store.get_local_media_ids( + before_ts, + size_gt, + keep_profiles, + include_quarantined_media=delete_quarantined_media, + include_protected_media=delete_protected_media, + ) + return await self._remove_local_media_from_disk(old_media) + + async def _remove_local_media_from_disk( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server. Removes media files, + any thumbnails and cached URLs. + + Args: + media_ids: List of media_id to delete + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + removed_media = [] + for media_id in media_ids: + logger.info("Deleting media with ID '%s'", media_id) + full_path = self.filepaths.local_media_filepath(media_id) + try: + os.remove(full_path) + except OSError as e: + logger.warning("Failed to remove file: %r: %s", full_path, e) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + await self.store.delete_remote_media(self.server_name, media_id) + + await self.store.delete_url_cache((media_id,)) + await self.store.delete_url_cache_media((media_id,)) + + removed_media.append(media_id) + + return removed_media, len(removed_media) diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py new file mode 100644 index 0000000000..a7e22a91e1 --- /dev/null +++ b/synapse/media/media_storage.py @@ -0,0 +1,374 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import os +import shutil +from types import TracebackType +from typing import ( + IO, + TYPE_CHECKING, + Any, + Awaitable, + BinaryIO, + Callable, + Generator, + Optional, + Sequence, + Tuple, + Type, +) + +import attr + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender + +import synapse +from synapse.api.errors import NotFoundError +from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock +from synapse.util.file_consumer import BackgroundFileConsumer + +from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.media.storage_provider import StorageProvider + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class MediaStorage: + """Responsible for storing/fetching files from local sources. + + Args: + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. + """ + + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): + self.hs = hs + self.reactor = hs.get_reactor() + self.local_media_directory = local_media_directory + self.filepaths = filepaths + self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() + + async def store_file(self, source: IO, file_info: FileInfo) -> str: + """Write `source` to the on disk media store, and also any other + configured storage providers + + Args: + source: A file like object that should be written + file_info: Info about the file to store + + Returns: + the file path written to in the primary media store + """ + + with self.store_into_file(file_info) as (f, fname, finish_cb): + # Write to the main repository + await self.write_to_file(source, f) + await finish_cb() + + return fname + + async def write_to_file(self, source: IO, output: IO) -> None: + """Asynchronously write the `source` to `output`.""" + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + + @contextlib.contextmanager + def store_into_file( + self, file_info: FileInfo + ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: + """Context manager used to get a file like object to write into, as + described by file_info. + + Actually yields a 3-tuple (file, fname, finish_cb), where file is a file + like object that can be written to, fname is the absolute path of file + on disk, and finish_cb is a function that returns an awaitable. + + fname can be used to read the contents from after upload, e.g. to + generate thumbnails. + + finish_cb must be called and waited on after the file has been + successfully been written to. Should not be called if there was an + error. + + Args: + file_info: Info about the file to store + + Example: + + with media_storage.store_into_file(info) as (f, fname, finish_cb): + # .. write into f ... + await finish_cb() + """ + + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + os.makedirs(dirname, exist_ok=True) + + finished_called = [False] + + try: + with open(fname, "wb") as f: + + async def finish() -> None: + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + spam_check = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam_check != synapse.module_api.NOT_SPAM: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + # We currently ignore any additional field returned by + # the spam-check API. + raise SpamMediaException(errcode=spam_check[0]) + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + + yield f, fname, finish + except Exception as e: + try: + os.remove(fname) + except Exception: + pass + + raise e from None + + if not finished_called: + raise Exception("Finished callback not called") + + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: + """Attempts to fetch media described by file_info from the local cache + and configured storage providers. + + Args: + file_info + + Returns: + Returns a Responder if the file was found, otherwise None. + """ + paths = [self._file_info_to_path(file_info)] + + # fallback for remote thumbnails with no method in the filename + if file_info.thumbnail and file_info.server_name: + paths.append( + self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + ) + + for path in paths: + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + logger.debug("responding with local file %s", local_path) + return FileResponder(open(local_path, "rb")) + logger.debug("local file %s did not exist", local_path) + + for provider in self.storage_providers: + for path in paths: + res: Any = await provider.fetch(path, file_info) + if res: + logger.debug("Streaming %s from %s", path, provider) + return res + logger.debug("%s not found on %s", path, provider) + + return None + + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: + """Ensures that the given file is in the local cache. Attempts to + download it from storage providers if it isn't. + + Args: + file_info + + Returns: + Full path to local file + """ + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + return local_path + + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return legacy_local_path + + dirname = os.path.dirname(local_path) + os.makedirs(dirname, exist_ok=True) + + for provider in self.storage_providers: + res: Any = await provider.fetch(path, file_info) + if res: + with res: + consumer = BackgroundFileConsumer( + open(local_path, "wb"), self.reactor + ) + await res.write_to_consumer(consumer) + await consumer.wait() + return local_path + + raise NotFoundError() + + def _file_info_to_path(self, file_info: FileInfo) -> str: + """Converts file_info into a relative path. + + The path is suitable for storing files under a directory, e.g. used to + store files on local FS under the base media repository directory. + """ + if file_info.url_cache: + if file_info.thumbnail: + return self.filepaths.url_cache_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.url_cache_filepath_rel(file_info.file_id) + + if file_info.server_name: + if file_info.thumbnail: + return self.filepaths.remote_media_thumbnail_rel( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.remote_media_filepath_rel( + file_info.server_name, file_info.file_id + ) + + if file_info.thumbnail: + return self.filepaths.local_media_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.local_media_filepath_rel(file_info.file_id) + + +def _write_file_synchronously(source: IO, dest: IO) -> None: + """Write `source` to the file like `dest` synchronously. Should be called + from a thread. + + Args: + source: A file like object that's to be written + dest: A file like object to be written to + """ + source.seek(0) # Ensure we read from the start of the file + shutil.copyfileobj(source, dest) + + +class FileResponder(Responder): + """Wraps an open file that can be sent to a request. + + Args: + open_file: A file like object to be streamed ot the client, + is closed when finished streaming. + """ + + def __init__(self, open_file: IO): + self.open_file = open_file + + def write_to_consumer(self, consumer: IConsumer) -> Deferred: + return make_deferred_yieldable( + FileSender().beginFileTransfer(self.open_file, consumer) + ) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True, auto_attribs=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2**14 + + clock: Clock + path: str + + async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: + """Reads the file in chunks and calls the callback with each chunk.""" + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py new file mode 100644 index 0000000000..c0eaf04be5 --- /dev/null +++ b/synapse/media/oembed.py @@ -0,0 +1,265 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import html +import logging +import urllib.parse +from typing import TYPE_CHECKING, List, Optional + +import attr + +from synapse.media.preview_html import parse_html_description +from synapse.types import JsonDict +from synapse.util import json_decoder + +if TYPE_CHECKING: + from lxml import etree + + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class OEmbedResult: + # The Open Graph result (converted from the oEmbed result). + open_graph_result: JsonDict + # The author_name of the oEmbed result + author_name: Optional[str] + # Number of milliseconds to cache the content, according to the oEmbed response. + # + # This will be None if no cache-age is provided in the oEmbed response (or + # if the oEmbed response cannot be turned into an Open Graph response). + cache_age: Optional[int] + + +class OEmbedProvider: + """ + A helper for accessing oEmbed content. + + It can be used to check if a URL should be accessed via oEmbed and for + requesting/parsing oEmbed content. + """ + + def __init__(self, hs: "HomeServer"): + self._oembed_patterns = {} + for oembed_endpoint in hs.config.oembed.oembed_patterns: + api_endpoint = oembed_endpoint.api_endpoint + + # Only JSON is supported at the moment. This could be declared in + # the formats field. Otherwise, if the endpoint ends in .xml assume + # it doesn't support JSON. + if ( + oembed_endpoint.formats is not None + and "json" not in oembed_endpoint.formats + ) or api_endpoint.endswith(".xml"): + logger.info( + "Ignoring oEmbed endpoint due to not supporting JSON: %s", + api_endpoint, + ) + continue + + # Iterate through each URL pattern and point it to the endpoint. + for pattern in oembed_endpoint.url_patterns: + self._oembed_patterns[pattern] = api_endpoint + + def get_oembed_url(self, url: str) -> Optional[str]: + """ + Check whether the URL should be downloaded as oEmbed content instead. + + Args: + url: The URL to check. + + Returns: + A URL to use instead or None if the original URL should be used. + """ + for url_pattern, endpoint in self._oembed_patterns.items(): + if url_pattern.fullmatch(url): + # TODO Specify max height / width. + + # Note that only the JSON format is supported, some endpoints want + # this in the URL, others want it as an argument. + endpoint = endpoint.replace("{format}", "json") + + args = {"url": url, "format": "json"} + query_str = urllib.parse.urlencode(args, True) + return f"{endpoint}?{query_str}" + + # No match. + return None + + def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]: + """ + Search an HTML document for oEmbed autodiscovery information. + + Args: + tree: The parsed HTML body. + + Returns: + The URL to use for oEmbed information, or None if no URL was found. + """ + # Search for link elements with the proper rel and type attributes. + for tag in tree.xpath( + "//link[@rel='alternate'][@type='application/json+oembed']" + ): + if "href" in tag.attrib: + return tag.attrib["href"] + + # Some providers (e.g. Flickr) use alternative instead of alternate. + for tag in tree.xpath( + "//link[@rel='alternative'][@type='application/json+oembed']" + ): + if "href" in tag.attrib: + return tag.attrib["href"] + + return None + + def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: + """ + Parse the oEmbed response into an Open Graph response. + + Args: + url: The URL which is being previewed (not the one which was + requested). + raw_body: The oEmbed response as JSON encoded as bytes. + + Returns: + json-encoded Open Graph data + """ + + try: + # oEmbed responses *must* be UTF-8 according to the spec. + oembed = json_decoder.decode(raw_body.decode("utf-8")) + except ValueError: + return OEmbedResult({}, None, None) + + # The version is a required string field, but not always provided, + # or sometimes provided as a float. Be lenient. + oembed_version = oembed.get("version", "1.0") + if oembed_version != "1.0" and oembed_version != 1: + return OEmbedResult({}, None, None) + + # Attempt to parse the cache age, if possible. + try: + cache_age = int(oembed.get("cache_age")) * 1000 + except (TypeError, ValueError): + # If the cache age cannot be parsed (e.g. wrong type or invalid + # string), ignore it. + cache_age = None + + # The oEmbed response converted to Open Graph. + open_graph_response: JsonDict = {"og:url": url} + + title = oembed.get("title") + if title and isinstance(title, str): + # A common WordPress plug-in seems to incorrectly escape entities + # in the oEmbed response. + open_graph_response["og:title"] = html.unescape(title) + + author_name = oembed.get("author_name") + if not isinstance(author_name, str): + author_name = None + + # Use the provider name and as the site. + provider_name = oembed.get("provider_name") + if provider_name and isinstance(provider_name, str): + open_graph_response["og:site_name"] = provider_name + + # If a thumbnail exists, use it. Note that dimensions will be calculated later. + thumbnail_url = oembed.get("thumbnail_url") + if thumbnail_url and isinstance(thumbnail_url, str): + open_graph_response["og:image"] = thumbnail_url + + # Process each type separately. + oembed_type = oembed.get("type") + if oembed_type == "rich": + html_str = oembed.get("html") + if isinstance(html_str, str): + calc_description_and_urls(open_graph_response, html_str) + + elif oembed_type == "photo": + # If this is a photo, use the full image, not the thumbnail. + url = oembed.get("url") + if url and isinstance(url, str): + open_graph_response["og:image"] = url + + elif oembed_type == "video": + open_graph_response["og:type"] = "video.other" + html_str = oembed.get("html") + if html_str and isinstance(html_str, str): + calc_description_and_urls(open_graph_response, oembed["html"]) + for size in ("width", "height"): + val = oembed.get(size) + if type(val) is int: + open_graph_response[f"og:video:{size}"] = val + + elif oembed_type == "link": + open_graph_response["og:type"] = "website" + + else: + logger.warning("Unknown oEmbed type: %s", oembed_type) + + return OEmbedResult(open_graph_response, author_name, cache_age) + + +def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]: + results = [] + for tag in tree.xpath("//*/" + tag_name): + if "src" in tag.attrib: + results.append(tag.attrib["src"]) + return results + + +def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None: + """ + Calculate description for an HTML document. + + This uses lxml to convert the HTML document into plaintext. If errors + occur during processing of the document, an empty response is returned. + + Args: + open_graph_response: The current Open Graph summary. This is updated with additional fields. + html_body: The HTML document, as bytes. + + Returns: + The summary + """ + # If there's no body, nothing useful is going to be found. + if not html_body: + return + + from lxml import etree + + # Create an HTML parser. If this fails, log and return no metadata. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(html_body, parser) + + # The data was successfully parsed, but no tree was found. + if tree is None: + return + + # Attempt to find interesting URLs (images, videos, embeds). + if "og:image" not in open_graph_response: + image_urls = _fetch_urls(tree, "img") + if image_urls: + open_graph_response["og:image"] = image_urls[0] + + video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed") + if video_urls: + open_graph_response["og:video"] = video_urls[0] + + description = parse_html_description(tree) + if description: + open_graph_response["og:description"] = description diff --git a/synapse/media/preview_html.py b/synapse/media/preview_html.py new file mode 100644 index 0000000000..516d0434f0 --- /dev/null +++ b/synapse/media/preview_html.py @@ -0,0 +1,501 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import codecs +import logging +import re +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Union, +) + +if TYPE_CHECKING: + from lxml import etree + +logger = logging.getLogger(__name__) + +_charset_match = re.compile( + rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I +) +_xml_encoding_match = re.compile( + rb'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I +) +_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) + +# Certain elements aren't meant for display. +ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} + + +def _normalise_encoding(encoding: str) -> Optional[str]: + """Use the Python codec's name as the normalised entry.""" + try: + return codecs.lookup(encoding).name + except LookupError: + return None + + +def _get_html_media_encodings( + body: bytes, content_type: Optional[str] +) -> Iterable[str]: + """ + Get potential encoding of the body based on the (presumably) HTML body or the content-type header. + + The precedence used for finding a character encoding is: + + 1. tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to utf-8. + 5. Fallback to windows-1252. + + This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. + + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encoding of the body, as a string. + """ + # There's no point in returning an encoding more than once. + attempted_encodings: Set[str] = set() + + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Check if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding: + attempted_encodings.add(encoding) + yield encoding + + # TODO Support + + # Check if it has an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Check the HTTP Content-Type header for a character set. + if content_type: + content_match = _content_type_match.match(content_type) + if content_match: + encoding = _normalise_encoding(content_match.group(1)) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Finally, fallback to UTF-8, then windows-1252. + for fallback in ("utf-8", "cp1252"): + if fallback not in attempted_encodings: + yield fallback + + +def decode_body( + body: bytes, uri: str, content_type: Optional[str] = None +) -> Optional["etree.Element"]: + """ + This uses lxml to parse the HTML document. + + Args: + body: The HTML document, as bytes. + uri: The URI used to download the body. + content_type: The Content-Type header. + + Returns: + The parsed HTML body, or None if an error occurred during processed. + """ + # If there's no body, nothing useful is going to be found. + if not body: + return None + + # The idea here is that multiple encodings are tried until one works. + # Unfortunately the result is never used and then LXML will decode the string + # again with the found encoding. + for encoding in _get_html_media_encodings(body, content_type): + try: + body.decode(encoding) + except Exception: + pass + else: + break + else: + logger.warning("Unable to decode HTML body for %s", uri) + return None + + from lxml import etree + + # Create an HTML parser. + parser = etree.HTMLParser(recover=True, encoding=encoding) + + # Attempt to parse the body. Returns None if the body was successfully + # parsed, but no tree was found. + return etree.fromstring(body, parser) + + +def _get_meta_tags( + tree: "etree.Element", + property: str, + prefix: str, + property_mapper: Optional[Callable[[str], Optional[str]]] = None, +) -> Dict[str, Optional[str]]: + """ + Search for meta tags prefixed with a particular string. + + Args: + tree: The parsed HTML document. + property: The name of the property which contains the tag name, e.g. + "property" for Open Graph. + prefix: The prefix on the property to search for, e.g. "og" for Open Graph. + property_mapper: An optional callable to map the property to the Open Graph + form. Can return None for a key to ignore that key. + + Returns: + A map of tag name to value. + """ + results: Dict[str, Optional[str]] = {} + for tag in tree.xpath( + f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]" + ): + # if we've got more than 50 tags, someone is taking the piss + if len(results) >= 50: + logger.warning( + "Skipping parsing of Open Graph for page with too many '%s:' tags", + prefix, + ) + return {} + + key = tag.attrib[property] + if property_mapper: + key = property_mapper(key) + # None is a special value used to ignore a value. + if key is None: + continue + + results[key] = tag.attrib["content"] + + return results + + +def _map_twitter_to_open_graph(key: str) -> Optional[str]: + """ + Map a Twitter card property to the analogous Open Graph property. + + Args: + key: The Twitter card property (starts with "twitter:"). + + Returns: + The Open Graph property (starts with "og:") or None to have this property + be ignored. + """ + # Twitter card properties with no analogous Open Graph property. + if key == "twitter:card" or key == "twitter:creator": + return None + if key == "twitter:site": + return "og:site_name" + # Otherwise, swap twitter to og. + return "og" + key[7:] + + +def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: + """ + Parse the HTML document into an Open Graph response. + + This uses lxml to search the HTML document for Open Graph data (or + synthesizes it from the document). + + Args: + tree: The parsed HTML document. + + Returns: + The Open Graph response as a dictionary. + """ + + # Search for Open Graph (og:) meta tags, e.g.: + # + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : "Fun stuff happening here", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", + + og = _get_meta_tags(tree, "property", "og") + + # TODO: Search for properties specific to the different Open Graph types, + # such as article: meta tags, e.g.: + # + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> + + # Search for Twitter Card (twitter:) meta tags, e.g.: + # + # "twitter:site" : "@matrixdotorg" + # "twitter:creator" : "@matrixdotorg" + # + # Twitter cards tags also duplicate Open Graph tags. + # + # See https://developer.twitter.com/en/docs/twitter-for-websites/cards/guides/getting-started + twitter = _get_meta_tags(tree, "name", "twitter", _map_twitter_to_open_graph) + # Merge the Twitter values with the Open Graph values, but do not overwrite + # information from Open Graph tags. + for key, value in twitter.items(): + if key not in og: + og[key] = value + + if "og:title" not in og: + # Attempt to find a title from the title tag, or the biggest header on the page. + title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") + if title: + og["og:title"] = title[0].strip() + else: + og["og:title"] = None + + if "og:image" not in og: + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" + ) + # If a meta image is found, use it. + if meta_image: + og["og:image"] = meta_image[0] + else: + # Try to find images which are larger than 10px by 10px. + # + # TODO: consider inlined CSS styles as well as width & height attribs + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) + # If no images were found, try to find *any* images. + if not images: + images = tree.xpath("//img[@src][1]") + if images: + og["og:image"] = images[0].attrib["src"] + + # Finally, fallback to the favicon if nothing else. + else: + favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") + if favicons: + og["og:image"] = favicons[0] + + if "og:description" not in og: + # Check the first meta description tag for content. + meta_description = tree.xpath( + "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" + ) + # If a meta description is found with content, use it. + if meta_description: + og["og:description"] = meta_description[0] + else: + og["og:description"] = parse_html_description(tree) + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) + og["og:description"] = summarize_paragraphs([og["og:description"]]) + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + return og + + +def parse_html_description(tree: "etree.Element") -> Optional[str]: + """ + Calculate a text description based on an HTML document. + + Grabs any text nodes which are inside the tag, unless they are within + an HTML5 semantic markup tag (
,